Softmax

The softmax operator, defined by

$$ \sigma: \mathbb{R}^K \rightarrow (0,1)^K, $$

where $K > 1$, takes a tuple

$$ \mathbf{z} = (z_1, \ldots, z_K) \in \mathbb{R}^K $$

and computes each component of vector $\sigma(\mathbf{z}) \in (0,1)^K$ with

$$ \sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}. $$

We refer to $N = \sum_j e^{z_j}$ as the normalising factor. After applying softmax, each component lies in the interval $(0,1)$ and the components sum to $1$, so they can be interpreted as probabilities.

To avoid numerical instability, a variant called "safe softmax" is used in place of the original. It subtracts the maximum value of the input vector from each element before exponentiating. This simple trick doesn't change the output, but it significantly improves numerical stability:

$$ \sigma(\mathbf{z})_i = \frac{e^{z_i - z_{max}}}{\sum_{j=1}^{K} e^{z_j - z_{max}}}. $$
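The safe variant can be sketched in a few lines of plain floating-point Rust (the function name `safe_softmax` is ours, for illustration only):

```rust
// Minimal sketch of "safe" softmax: shift every logit by z_max so that
// each exponent is <= 0 and exp() cannot overflow.
fn safe_softmax(z: &[f64]) -> Vec<f64> {
    let z_max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|&zi| (zi - z_max).exp()).collect();
    let n: f64 = exps.iter().sum(); // the normalising factor N
    exps.iter().map(|e| e / n).collect()
}

fn main() {
    // A naive e^{z_i} would overflow f64 for these logits; the shifted
    // version works because the largest exponent becomes 0.
    let s = safe_softmax(&[1000.0, 1001.0, 1002.0]);
    println!("{:?}", s);
    assert!((s.iter().sum::<f64>() - 1.0).abs() < 1e-12);
}
```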

The base used in softmax need not be $e$; it is chosen for convenience, and other bases such as $2$ or $3$ can also be used, since changing the base merely rescales the logits. That is, for any base $b > 1$, $\sigma_b(\mathbf{z})_i = \frac{b^{z_i}}{\sum_{j} b^{z_j}} = \frac{e^{(\ln b)\,z_i}}{\sum_{j} e^{(\ln b)\,z_j}} = \sigma\left( (\ln b)\,\mathbf{z} \right)_i$. To maximise the usable domain of our integer types, we pick $b = 2$. Because softmax is translation-invariant (adding the same constant to every logit leaves the output unchanged), the safe shift by $z_{max}$ carries over to base $2$, and this is what we want to prove: $\sigma_2(\mathbf{z})_i = \frac{2^{z_i - z_{max}}}{\sum_j 2^{z_j - z_{max}}}.$
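The base-change identity above is easy to check numerically. The sketch below (function name `softmax_base` is ours) verifies that base-2 softmax on $\mathbf{z}$ agrees with base-$e$ softmax on $(\ln 2)\,\mathbf{z}$:

```rust
// Softmax in an arbitrary base b > 1, with the safe z_max shift applied.
fn softmax_base(z: &[f64], b: f64) -> Vec<f64> {
    let z_max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let pows: Vec<f64> = z.iter().map(|&zi| b.powf(zi - z_max)).collect();
    let n: f64 = pows.iter().sum();
    pows.iter().map(|p| p / n).collect()
}

fn main() {
    let z = [0.5, 2.0, -1.0];
    // sigma_2(z) ...
    let base2 = softmax_base(&z, 2.0);
    // ... equals sigma((ln 2) * z) in base e.
    let scaled: Vec<f64> = z.iter().map(|zi| zi * 2f64.ln()).collect();
    let base_e = softmax_base(&scaled, std::f64::consts::E);
    for (a, b) in base2.iter().zip(base_e.iter()) {
        assert!((a - b).abs() < 1e-12);
    }
    println!("{:?}", base2);
}
```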

Main constraints

  • Both domain and range of this operator are vectors. Unless we choose to restrict vectors to a small length, a single lookup won't suffice even for quantised inputs.

  • The normalising factor $N = \sum e^{z_j}$ needs to be known before any individual $\sigma(\mathbf{z})_i$ can be computed. This means that at least a two-step instruction is required.

  • The sum $H = \sum_{j\in[n]} e^{z_j - z_{max}}$ hints at the possibility of a sum-check. However, to prove that $H = \sum_{x\in\{0,1\}^{\log n}} \widetilde{g}(x)$ for $\widetilde{g}(X) = \sum_{j\in\{0,1\}^{\log n}} e^{z[j] - z_{max}} \cdot \widetilde{eq}(X, j)$, we need to show that $\vec{z}$ is related to $\widetilde{g}$, and this requires introducing an MLE $\widetilde{z}$ of $z$. Since the sum-check involves an evaluation of $\widetilde{z}$ at an arbitrary field element $r$, the exponential $e^{\widetilde{z}(r)}$ can't be computed.

  • We are constrained by the types of the lookup operations, in particular u32 and u64. If $b=2$, any element $z_i \geq 64$ will overflow (i.e., $2^{z_i}$ must not exceed u64::MAX).

  • $2^{z_i} / N$ will be $\approx 0$ for any $z_i$ such that $z_{max} - z_i > k$, where $k$ is a very small number (like $8$).
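The last two constraints can be checked directly with u64 arithmetic. The sketch below (our own, with the widening to u128 only so the $2^8$ quantisation scale doesn't itself overflow) shows that the exponent $64$ already overflows a u64 shift, and that an element more than $k = 8$ below $z_{max}$ quantises to $0$:

```rust
fn main() {
    // 2^63 fits in a u64; 2^64 does not (u64::MAX = 2^64 - 1).
    assert!(1u64.checked_shl(63).is_some());
    assert!(1u64.checked_shl(64).is_none());

    // With z_max = 63, an element z_i contributes 2^{z_i} to N. After
    // scaling probabilities by 2^8, anything more than 8 below the max
    // rounds down to 0: here 54 = 63 - 9.
    let z = [63u32, 60, 54];
    let pows: Vec<u64> = z.iter().map(|&zi| 1u64 << zi).collect();
    let n: u128 = pows.iter().map(|&p| p as u128).sum();
    let quantised: Vec<u128> = pows
        .iter()
        .map(|&p| ((p as u128) << 8) / n) // floor(2^8 * 2^{z_i} / N)
        .collect();
    println!("{:?}", quantised);
    assert_eq!(quantised[2], 0); // 2^{54}/N vanishes at 8-bit precision
}
```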

Implementation

I suggest a virtual instruction consisting of the following steps:

  1. Take the maximum element of $[z_1, \ldots, z_n]$ and call it $z_{max}$. We can apply the max instruction recursively: $z_{max} = \max(\ldots\max(\max(z_1, z_2), z_3), \ldots, z_n)$.

  2. Multiply each element by $63$ and divide it by $z_{max}$, that is, $z'_i = z_i \cdot 63 / z_{max}$, so that the maximum element is now $63$ and thus no element $2^{z'_i}$ overflows. The sum $\sum 2^{z'_i}$ must not overflow either. This is a more restrictive form of quantisation.

  3. Run the "power-of-two" lookup for each $z'_i$ (i.e., $2^{z'_i}$). This returns a vector $\vec{a} = [2^{z'_1},..., 2^{z'_n}]$.

  4. Run the sum-check to prove that $\sum a_i = N$, where $N$ is the normalisation factor.

  5. Run a division lookup $a_i / N$ for each element in $\vec{a}$. Since each value lies in $(0,1)$, we multiply by $2^8$ for quantisation.

  6. Concatenate the results.
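The six steps above can be mirrored as plain integer arithmetic. This is only a computational sketch (the function name is ours); in the actual virtual instruction each step would be proved with the corresponding lookup or sum-check:

```rust
// Steps 1-6 of the proposed virtual instruction, as unverified arithmetic.
// Assumes at least one nonzero input and small enough z_i that z_i * 63
// does not overflow.
fn softmax_virtual_sketch(z: &[u64]) -> Vec<u64> {
    // Step 1: fold max over the inputs (the recursive max instruction).
    let z_max = z.iter().copied().fold(0u64, u64::max);

    // Step 2: rescale so the largest exponent is 63 and 2^{z'_i} fits in u64.
    let z_scaled: Vec<u64> = z.iter().map(|&zi| zi * 63 / z_max).collect();

    // Step 3: the "power-of-two" lookup, here just a shift.
    let a: Vec<u64> = z_scaled.iter().map(|&zi| 1u64 << zi).collect();

    // Step 4: the normalising factor N (proved by sum-check in the protocol).
    let n: u128 = a.iter().map(|&ai| ai as u128).sum();

    // Steps 5-6: division lookup scaled by 2^8, results concatenated.
    a.iter().map(|&ai| (((ai as u128) << 8) / n) as u64).collect()
}

fn main() {
    let probs = softmax_virtual_sketch(&[10, 9, 3]);
    println!("{:?}", probs);
    // The quantised components sum to at most 2^8 (floor division loses
    // at most 1 per component).
    assert!(probs.iter().sum::<u64>() <= 256);
}
```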
