Softmax

The softmax operator, defined by

$$ \sigma: \mathbb{R}^K \rightarrow (0,1)^K, $$

where $K > 1$, takes a tuple

$$ \mathbf{z} = (z_1, \ldots, z_K) \in \mathbb{R}^K $$

and computes each component of the vector $\sigma(\mathbf{z}) \in (0,1)^K$ as

$$ \sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}. $$

We refer to $N = \sum_{j=1}^{K} e^{z_j}$ as the normalising factor. After applying softmax, each component lies in the interval $(0,1)$ and the components sum to 1, so they can be interpreted as probabilities.
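A direct transcription of the definition into Python may help as a reference point. This is a minimal sketch (the function name `softmax` is ours, not from the source) that computes $N$ first and then divides each exponential by it:

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    """Textbook softmax: exponentiate each component, then divide by N = sum(e^{z_j})."""
    if len(z) < 2:
        raise ValueError("softmax expects K > 1 components")
    exps = [math.exp(zi) for zi in z]
    n = sum(exps)  # the normalising factor N
    return [e / n for e in exps]


probs = softmax([1.0, 2.0, 3.0])
print(probs)       # each value lies in (0, 1)
print(sum(probs))  # 1.0 up to floating-point rounding
```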

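To avoid numerical instability, a variant called "safe softmax" replaces the original formula. It subtracts the maximum value of the input vector, $z_{\max} = \max_j z_j$, from every element before applying the exponential. This does not change the output, because softmax is translation-invariant: adding the same constant to every $z_i$ multiplies the numerator and the denominator by the same factor. It does improve numerical stability significantly: every exponent becomes $\le 0$, so no term can overflow, and the largest term equals $e^0 = 1$, so the denominator is at least 1.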

$$ \sigma(\mathbf{z})_i = \frac{e^{z_i - z_{\max}}}{\sum_{j=1}^{K} e^{z_j - z_{\max}}}. $$
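The effect of the shift is easy to see numerically. In this minimal sketch (the name `safe_softmax` is ours), inputs around 1000 would overflow `math.exp` in the unshifted formula, but after subtracting $z_{\max}$ every exponent is at most 0:

```python
import math
from typing import List


def safe_softmax(z: List[float]) -> List[float]:
    """Softmax with the maximum subtracted first, so every exponent is <= 0."""
    z_max = max(z)
    exps = [math.exp(zi - z_max) for zi in z]  # largest term is exactly e^0 = 1
    n = sum(exps)                              # normalising factor, always >= 1
    return [e / n for e in exps]


# math.exp(1000.0) raises OverflowError, so the unshifted formula fails here,
# while the shifted version only evaluates exp(-1.0) and exp(0.0).
print(safe_softmax([1000.0, 1001.0]))  # roughly [0.269, 0.731]
```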

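The choice of base is a matter of convenience rather than correctness. $e$ is conventional, but other bases $b > 1$, such as 2 and 3, also yield a valid probability distribution. Changing the base does not translate the logits; it rescales them, because $b^{x} = e^{(\ln b)\,x}$:

$$ \sigma_b(\mathbf{z})_i = \frac{b^{z_i}}{\sum_{j=1}^{K} b^{z_j}} = \frac{e^{(\ln b)\,z_i}}{\sum_{j=1}^{K} e^{(\ln b)\,z_j}} = \sigma\left( (\ln b)\,\mathbf{z} \right)_i. $$

So base-$b$ softmax is ordinary softmax applied to the logits multiplied by $\ln b$. In general it gives different probabilities from base $e$, but it preserves the ordering of the components, and translation invariance still holds for every base, so the max-subtraction trick carries over unchanged. To maximise the domain, we pick $b = 2$: it is the smallest integer base, so $2^{x}$ grows and decays more slowly than $e^{x}$ or $3^{x}$. The quantity we want to compute is therefore

$$ \sigma_2(\mathbf{z})_i = \frac{2^{z_i - z_{\max}}}{\sum_{j=1}^{K} 2^{z_j - z_{\max}}}. $$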
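The identity $\sigma_b(\mathbf{z}) = \sigma((\ln b)\,\mathbf{z})$ can be checked numerically. This sketch uses a hypothetical helper `softmax_base` that applies max subtraction for an arbitrary base $b > 1$. It shows that base 2 matches base-$e$ softmax on logits scaled by $\ln 2$, and that it differs from base-$e$ softmax on the raw logits while keeping the same ordering:

```python
import math
from typing import List


def softmax_base(z: List[float], b: float) -> List[float]:
    """Base-b softmax with max subtraction: b^(z_i - z_max) / sum_j b^(z_j - z_max)."""
    if b <= 1:
        raise ValueError("base must be greater than 1")
    z_max = max(z)
    terms = [b ** (zi - z_max) for zi in z]
    n = sum(terms)
    return [t / n for t in terms]


z = [0.0, 1.0, 2.0]
base2 = softmax_base(z, 2.0)                                      # approximately (1/7, 2/7, 4/7)
scaled = softmax_base([math.log(2.0) * zi for zi in z], math.e)  # base-e softmax of (ln 2) * z
print(base2)
print(scaled)                   # agrees with base2 up to rounding
print(softmax_base(z, math.e))  # ordinary softmax: different values, same ordering
```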

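The normalising factor $N = \sum_{j=1}^{K} e^{z_j}$ (or its base-2 counterpart) must be known before any individual $\sigma(\mathbf{z})_i$ can be computed, and it depends on every component of $\mathbf{z}$. Softmax therefore cannot be computed element by element in a single step: it requires at least two steps, a reduction over the whole vector to obtain $N$, followed by an element-wise division by $N$.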
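For instance, take the base-2 form with $\mathbf{z} = (0, 1, 2)$, so $z_{\max} = 2$ and the normalising factor is $N = \sum_{j} 2^{z_j - z_{\max}}$. The first step reduces over the whole vector to obtain $N$; only then can the second step divide each term by it:

$$ N = 2^{0-2} + 2^{1-2} + 2^{2-2} = \tfrac{1}{4} + \tfrac{1}{2} + 1 = \tfrac{7}{4}, $$

$$ \sigma_2(\mathbf{z}) = \left( \frac{1/4}{7/4},\ \frac{1/2}{7/4},\ \frac{1}{7/4} \right) = \left( \tfrac{1}{7},\ \tfrac{2}{7},\ \tfrac{4}{7} \right). $$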

Implementation

The sequence performs 14 virtual steps:

| Step | Operation | Description |
|---|---|---|
| 1 | VirtualConst(0) | Initialize a zero tensor |
| 2 | Gte | Compute `ge0 = (z >= 0)` |
| 3 | Sub | Compute `neg_z = -z` |
| 4 | Select | Compute `abs_z = select(ge0, z, neg_z)` |
| 5 | VirtualPow2 | Compute `c = 2^abs_z` |
| 6 | VirtualConst(Q) | Load the constant quantization scale `Q` |
| 7 | Div | Compute `d_q_over_c = Q / c` |
| 8 | Mul | Compute `d_q_times_c = Q * c` |
| 9 | Select | Compute `d = (z >= 0 ? Q * c : Q / c)` |
| 10 | Sum | Compute `ReduceSum(d)` to get the total |
| 11 | Broadcast | Broadcast the total to every element |
| 12 | Mul | Compute `f = Q * d` |
| 13 | Div | Normalize: `g = f / e_sum` |
| 14 | VirtualMove | Write the final result to the output tensor |
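To make the data flow of the table concrete, here is a minimal Python emulation of the 14 steps on plain integer lists. It is a sketch under stated assumptions, not the actual implementation: we assume integer inputs, that step 3 subtracts `z` from the zero tensor of step 1, that step 5 computes an exact `2**abs_z`, that `Div` is floor division, and that `q` plays the role of the quantization scale Q. The function name is hypothetical. Under these assumptions step 9 yields $d_i \approx Q \cdot 2^{z_i}$, so $g_i \approx Q \cdot \sigma_2(\mathbf{z})_i$, the base-2 softmax in fixed point with scale Q.

```python
from typing import List


def softmax_virtual_sequence(z: List[int], q: int) -> List[int]:
    """Hypothetical emulation of the 14-step sequence on integer lists.

    Assumptions (not confirmed by the source): integer inputs, Div = floor
    division, VirtualPow2 = exact 2**x, q = fixed-point quantization scale.
    """
    n = len(z)
    zero = [0] * n                                                     # 1  VirtualConst(0)
    ge0 = [zi >= 0 for zi in z]                                        # 2  Gte: z >= 0
    neg_z = [a - zi for a, zi in zip(zero, z)]                         # 3  Sub: 0 - z
    abs_z = [zi if pos else nz for zi, pos, nz in zip(z, ge0, neg_z)]  # 4  Select
    c = [2 ** a for a in abs_z]                                        # 5  VirtualPow2: 2^|z|
    scale = q                                                          # 6  VirtualConst(Q)
    d_q_over_c = [scale // ci for ci in c]                             # 7  Div: Q / c
    d_q_times_c = [scale * ci for ci in c]                             # 8  Mul: Q * c
    d = [t if pos else o                                               # 9  Select: ~Q * 2^z
         for pos, t, o in zip(ge0, d_q_times_c, d_q_over_c)]
    total = sum(d)                                                     # 10 Sum
    e_sum = [total] * n                                                # 11 Broadcast
    f = [scale * di for di in d]                                       # 12 Mul: Q * d
    g = [fi // si for fi, si in zip(f, e_sum)]                         # 13 Div: f / e_sum
    return list(g)                                                     # 14 VirtualMove


print(softmax_virtual_sequence([0, 1, 2], 1024))  # [146, 292, 585]
print(softmax_virtual_sequence([-1, 1], 1024))    # [204, 819]
```

For $\mathbf{z} = (0, 1, 2)$ the exact values are $1024 \cdot (1/7, 2/7, 4/7) \approx (146.3, 292.6, 585.1)$, so the floor divisions lose only the fractional parts. The listed steps do not subtract $z_{\max}$, so large negative inputs drive `Q // c` towards 0 and large positive inputs make `Q * c` grow quickly; Python's unbounded integers hide any overflow that fixed-width arithmetic could hit.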
