Softmax
The softmax operator, defined by
$$ \sigma: \mathbb{R}^K \rightarrow (0,1)^K, $$
where $K > 1$, takes a tuple
$$ \mathbf{z} = (z_1, \ldots, z_K) \in \mathbb{R}^K $$
and computes each component of vector $\sigma(\mathbf{z}) \in (0,1)^K$ with
$$ \sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}. $$
We refer to $N = \sum_{j=1}^{K} e^{z_j}$ as the normalising factor. After applying softmax, each component lies in the interval $(0,1)$ and the components sum to 1, so they can be interpreted as probabilities.
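As a quick illustration of this definition (not part of the instruction sequence described later), a minimal NumPy sketch might look as follows; the function name and test vector are purely for demonstration:

```python
import numpy as np

def softmax(z):
    """Plain softmax: sigma(z)_i = exp(z_i) / sum_j exp(z_j)."""
    exps = np.exp(z)       # e^{z_i} for each component
    N = exps.sum()         # normalising factor N = sum_j e^{z_j}
    return exps / N

z = np.array([1.0, 2.0, 3.0])
p = softmax(z)
print(p)                   # every component lies in (0, 1)
print(p.sum())             # components add up to 1.0
```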
To avoid numerical overflow, a variant called "safe softmax" is used in place of the original formula: it subtracts the maximum value $z_{\max}$ of the input vector from each element before applying the exponential. Since softmax is translation-invariant (adding the same constant to every logit leaves the output unchanged), this does not alter the result, but it keeps every exponent non-positive and so significantly improves numerical stability.
$$ \sigma(\mathbf{z})_i = \frac{e^{z_i - z_{\max}}}{\sum_{j=1}^{K} e^{z_j - z_{\max}}}. $$
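A hedged sketch of the safe variant, using an input whose naive exponentials would overflow in double precision (the values are chosen purely for illustration):

```python
import numpy as np

def safe_softmax(z):
    """Numerically stable softmax: subtract z_max before exponentiating."""
    shifted = z - z.max()  # z_i - z_max <= 0, so exp never overflows
    exps = np.exp(shifted)
    return exps / exps.sum()

z = np.array([1000.0, 1001.0, 1002.0])
# np.exp(z) would overflow to inf here, but the shifted version is fine:
print(safe_softmax(z))     # ~[0.090, 0.245, 0.665]
```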
The choice of base in softmax is not essential. $e$ is used by convention, but other bases such as 2 or 3 can also be used, since switching to base $b$ merely rescales the logits by $\ln b$: for any base $b > 1$,
$$ \sigma_b(\mathbf{z})_i = \frac{b^{z_i}}{\sum_{j=1}^{K} b^{z_j}} = \frac{e^{(\ln b)\,z_i}}{\sum_{j=1}^{K} e^{(\ln b)\,z_j}} = \sigma\big((\ln b)\,\mathbf{z}\big)_i. $$
To maximise the domain, we pick $b = 2$. Combined with the max-subtraction trick above, this is the form we want to prove:
$$ \sigma_2(\mathbf{z})_i = \frac{2^{z_i - z_{\max}}}{\sum_{j=1}^{K} 2^{z_j - z_{\max}}}. $$
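The identity $\sigma_b(\mathbf{z}) = \sigma\big((\ln b)\,\mathbf{z}\big)$ can be checked numerically with a small sketch like the following (the helper `softmax_base` is illustrative and not part of the implementation):

```python
import numpy as np

def softmax_base(z, b):
    """Softmax with base b, using the safe (max-subtracted) form."""
    shifted = z - z.max()
    powers = b ** shifted      # b^{z_i - z_max}
    return powers / powers.sum()

z = np.array([0.5, -1.0, 2.0])
# Base-b softmax equals the base-e softmax of the rescaled logits (ln b) * z.
lhs = softmax_base(z, 2.0)
rhs = softmax_base(np.log(2.0) * z, np.e)
print(np.allclose(lhs, rhs))   # True
```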
The normalising factor $N = \sum_{j=1}^{K} e^{z_j}$ needs to be known before any individual $\sigma(\mathbf{z})_i$ can be computed. Softmax therefore cannot be evaluated in a single elementwise pass: at least a two-step instruction sequence is required, a reduction to obtain $N$ followed by an elementwise normalisation.
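A minimal sketch of this two-pass structure (helper name and input are illustrative only):

```python
import numpy as np

def two_pass_softmax(z):
    # Pass 1 (reduction): the normalising factor must be known first.
    z_max = z.max()
    N = np.exp(z - z_max).sum()
    # Pass 2 (elementwise): each component only needs z_i, z_max and N.
    return np.array([np.exp(z_i - z_max) / N for z_i in z])

print(two_pass_softmax(np.array([1.0, 2.0, 3.0])))
```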
Implementation
The sequence performs 14 virtual steps:
| Step | Instruction | Description |
|------|-------------|-------------|
| 1 | VirtualConst(0) | Initialize zero tensor |
| 2 | Gte | Compute ge0 = (z >= 0) |
| 3 | Sub | Compute neg_z = -z |
| 4 | Select | Compute abs_z = select(ge0, z, -z) |
| 5 | VirtualPow2 | Compute `c = 2^{abs_z}` |
| 6 | VirtualConst(Q) | Constant quantization scalar |
| 7 | Div | Compute d_q_over_c = Q / c |
| 8 | Mul | Compute d_q_times_c = Q * c |
| 9 | Select | Select d = (z >= 0 ? Q * c : Q / c) |
| 10 | Sum | ReduceSum(d) to get total |
| 11 | Broadcast | Broadcast the sum as e_sum |
| 12 | Mul | Compute f = Q * d |
| 13 | Div | Normalize g = f / e_sum |
| 14 | VirtualMove | Write final result to output tensor |
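For reference, the NumPy sketch below mirrors the 14 steps above. It is an interpretation of the step descriptions rather than the actual virtual-machine code: the exact semantics of the virtual ops, the value of the quantization scalar Q, and the use of floating point here are assumptions made for illustration.

```python
import numpy as np

def softmax_virtual_steps(z, Q=2**16):
    """Sketch of the 14-step sequence; Q is the quantization scalar
    (its actual value in the system is an assumption here)."""
    zero    = np.zeros_like(z)                 # 1  VirtualConst(0)
    ge0     = z >= zero                        # 2  Gte: z >= 0
    neg_z   = zero - z                         # 3  Sub: -z
    abs_z   = np.where(ge0, z, neg_z)          # 4  Select: |z|
    c       = 2.0 ** abs_z                     # 5  VirtualPow2: c = 2^{abs_z}
    q       = np.full_like(z, Q)               # 6  VirtualConst(Q)
    q_over  = q / c                            # 7  Div: Q / c
    q_times = q * c                            # 8  Mul: Q * c
    d       = np.where(ge0, q_times, q_over)   # 9  Select: d = Q * 2^z
    total   = d.sum()                          # 10 Sum (reduce)
    e_sum   = np.broadcast_to(total, z.shape)  # 11 Broadcast the sum
    f       = q * d                            # 12 Mul: f = Q * d
    g       = f / e_sum                        # 13 Div: normalise
    return g                                   # 14 VirtualMove: write output

z = np.array([-1.0, 0.0, 2.0])
out = softmax_virtual_steps(z)
print(out / 2**16)   # equals the base-2 softmax of z, rescaled by 1/Q
```

Note that the output carries the fixed-point scale Q: since d = Q * 2^z, the final value is g = Q * 2^z / sum(2^z), i.e. Q times the base-2 softmax derived earlier.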