Faux Attention

One might contend that attention, as the term is used in transformers, is already misleading and an abuse of the term; calling this faux attention is even more egregious and bordering on meaningless. It makes for a succinct title, however.

Approximation the first

A quadratic approximation of $\exp(x)$ for nonnegative $x$ near zero is

$$f_+(x) = 1 + x + x^2/2.$$

Noting that $\exp(x) = 1/\exp(-x)$ gives us

$$f(x) = \begin{cases} 1/f_+(-x) & \text{when } x < 0 \\ f_+(x) & \text{otherwise} \end{cases} = \begin{cases} 1/(1 - x + x^2/2) & \text{when } x < 0 \\ 1 + x + x^2/2 & \text{otherwise.} \end{cases}$$

Noting that $e^x = (e^{x/a})^a$ gives us

$$g(x) = f(x/4)^4 = \left(f(x/4)^2\right)^2.$$

This can be seen here.
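As a worked check of these definitions (my own arithmetic, not from the original text), take $x = 1$ and $x = -1$. Because $f(-x) = 1/f(x)$ by construction, $g(-x) = 1/g(x)$ holds exactly, just as it does for $\exp$:

```latex
\begin{aligned}
f(1/4) &= f_+(1/4) = 1 + \tfrac{1}{4} + \tfrac{1}{32} = \tfrac{41}{32} = 1.28125, \\
g(1)   &= \left(\left(\tfrac{41}{32}\right)^2\right)^2 = \tfrac{2825761}{1048576} \approx 2.694856, \\
g(-1)  &= \left(\left(\tfrac{32}{41}\right)^2\right)^2 = \tfrac{1}{g(1)} \approx 0.371077,
\end{aligned}
```

against $e \approx 2.718282$ and $e^{-1} \approx 0.367879$.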

A straightforward implementation parallelizes well:

pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  for i in 0..16 {
    let rr = x[i] * 0.25;
    let sq = rr * rr * 0.5;
    let pos = (1.0 + rr) + sq;
    let neg = (1.0 - rr) + sq;
    let fst = if x[i] < 0.0 { neg.recip() } else { pos };
    let snd = fst * fst;
    y[i] = snd * snd;
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x3e800000
.CONST1: .long 0x3f000000
.CONST2: .long 0x3f800000
vexpps:
  vmovups zmm0, zmmword ptr [rdi]
  vbroadcastss  zmm2, dword ptr [rip + .CONST2]
  vxorps  xmm5, xmm5, xmm5
  vmulps  zmm1, zmm0, dword ptr [rip + .CONST0]{1to16}
  vcmpltps  k1, zmm0, zmm5
  vmulps  zmm3, zmm1, zmm1
  vmulps  zmm3, zmm3, dword ptr [rip + .CONST1]{1to16}
  vaddps  zmm4, zmm1, zmm2
  vsubps  zmm0, zmm2, zmm1
  vaddps  zmm4, zmm4, zmm3
  vaddps  zmm0, zmm0, zmm3
  vdivps  zmm4  {k1}, zmm2, zmm0
  vmulps  zmm0, zmm4, zmm4
  vmulps  zmm0, zmm0, zmm0
  vmovups zmmword ptr [rsi], zmm0
  vzeroupper
  ret

For comparison, the implementation of exp used in Rust can, I believe, be seen here; the compiler does not parallelize it.

The approximation has middling accuracy:

 input        approx       actual      abs error    rel error
−7.000000    +0.002977    +0.000912    +0.002065    +2.264217
−6.000000    +0.005791    +0.002479    +0.003312    +1.336334
−5.000000    +0.011844    +0.006738    +0.005106    +0.757864
−4.000000    +0.025600    +0.018316    +0.007284    +0.397713
−3.000000    +0.058742    +0.049787    +0.008955    +0.179859
−2.000000    +0.143412    +0.135335    +0.008077    +0.059682
−1.000000    +0.371077    +0.367879    +0.003198    +0.008693
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.694856    +2.718282    −0.023426    −0.008618
+2.000000    +6.972900    +7.389056    −0.416156    −0.056321
+3.000000   +17.023682   +20.085537    −3.061855    −0.152441
+4.000000   +39.062500   +54.598148   −15.535648    −0.284545
+5.000000   +84.428101  +148.413162   −63.985062    −0.431128
+6.000000  +172.676025  +403.428802  −230.752777    −0.571979
+7.000000  +335.955963 +1096.633179  −760.677246    −0.693648
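The error columns appear to be approx − actual and (approx − actual) / actual; in the first row, for instance, 0.002065 / 0.000912 ≈ 2.264. A scalar harness along the following lines should reproduce rows like these. It is only a sketch: the names `approx_first`, `error_row` and `print_table` are hypothetical, and `f32::exp` is assumed as the reference, which may not be how the table above was produced.

```rust
/// Scalar form of the first approximation, g(x) = f(x/4)^4.
pub fn approx_first(x: f32) -> f32 {
    let rr = x * 0.25;
    let sq = rr * rr * 0.5;
    let pos = (1.0 + rr) + sq; // f₊(x/4)
    let neg = (1.0 - rr) + sq; // f₊(-x/4)
    let fst = if x < 0.0 { neg.recip() } else { pos };
    // Two squarings give the fourth power.
    let snd = fst * fst;
    snd * snd
}

/// One row of an error table: (approx, actual, abs error, rel error),
/// with f32::exp assumed as the reference value.
pub fn error_row(x: f32, approx: impl Fn(f32) -> f32) -> (f32, f32, f32, f32) {
    let a = approx(x);
    let e = x.exp();
    let abs = a - e;
    (a, e, abs, abs / e)
}

/// Prints rows for the integers -7..=7 in roughly the layout used above.
pub fn print_table(approx: impl Fn(f32) -> f32) {
    println!("{:>10} {:>12} {:>12} {:>12} {:>10}", "input", "approx", "actual", "abs error", "rel error");
    for i in -7..=7 {
        let x = i as f32;
        let (a, e, abs, rel) = error_row(x, &approx);
        println!("{:>+10.6} {:>+12.6} {:>+12.6} {:>+12.6} {:>+10.6}", x, a, e, abs, rel);
    }
}
```

Passing the approximation as a closure lets the same harness be reused for the later approximations.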

Approximation the second

Adding a third-order term and reducing the range further, seen here,

$$f_+(x) = 1 + x + x^2/2 + x^3/6, \qquad g(x) = f(x/8)^8 = \left(\left(f(x/8)^2\right)^2\right)^2$$

improves this at the cost of three additional multiplications:

 input        approx       actual      abs error    rel error
−7.000000    +0.001006    +0.000912    +0.000095    +0.103716
−6.000000    +0.002628    +0.002479    +0.000149    +0.060299
−5.000000    +0.006951    +0.006738    +0.000213    +0.031565
−4.000000    +0.018574    +0.018316    +0.000259    +0.014124
−3.000000    +0.050031    +0.049787    +0.000244    +0.004906
−2.000000    +0.135480    +0.135335    +0.000145    +0.001068
−1.000000    +0.367906    +0.367879    +0.000027    +0.000073
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.718082    +2.718282    −0.000200    −0.000074
+2.000000    +7.381175    +7.389056    −0.007882    −0.001067
+3.000000   +19.987476   +20.085537    −0.098061    −0.004882
+4.000000   +53.837746   +54.598148    −0.760403    −0.013927
+5.000000  +143.871826  +148.413162    −4.541336    −0.030599
+6.000000  +380.485840  +403.428802   −22.942963    −0.056870
+7.000000  +993.582458 +1096.633179  −103.050720    −0.093970
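No code is given for the second approximation. A sketch in the same shape as the first might look like the following. The name `vexpps_cubic` is hypothetical, the cubic is evaluated in Horner form, and the reciprocal is taken of $f_+(|x|/8)$, which equals $1/f_+(-x/8)$ for negative $x$. I have not checked what the compiler emits for it.

```rust
/// Sketch of the second approximation for 16 lanes: g(x) = f(x/8)^8,
/// with f₊(r) = 1 + r + r²/2 + r³/6 and f(r) = 1 / f₊(-r) for r < 0.
pub fn vexpps_cubic(x: &[f32; 16], y: &mut [f32; 16]) {
    for i in 0..16 {
        // Work with |x|/8 and take the reciprocal afterwards for negative x.
        let r = x[i].abs() * 0.125;
        // Horner form of 1 + r + r²/2 + r³/6.
        let p = 1.0 + r * (1.0 + r * (0.5 + r * (1.0 / 6.0)));
        let fst = if x[i] < 0.0 { p.recip() } else { p };
        // Three squarings raise f(x/8) to the eighth power.
        let snd = fst * fst;
        let trd = snd * snd;
        y[i] = trd * trd;
    }
}
```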

Approximation the third

However, we can do better if we are willing to abuse the format of IEEE 754 single-precision floating-point numbers. The following is from “A Fast, Compact Approximation of the Exponential Function” by Nicol Schraudolph, which can be read here. My thanks to dear Cosmo, who found this and implemented it.

pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  const M : u32 = f32::MANTISSA_DIGITS - 1;
  const A : f32 = (1 << M) as f32 / std::f32::consts::LN_2;
  const B : f32 = (127u32 << M) as f32;

  for i in 0..16 {
    let aff = x[i].mul_add(A, B);
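    // SAFETY: assumes x[i] is not NaN and lies within about [-88, 88],
    // so that aff is nonnegative and below 2³², i.e. fits in a u32.
    let rnd = unsafe { aff.to_int_unchecked::<u32>() };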
    y[i] = f32::from_bits(rnd);
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x4b38aa3b
.CONST1: .long 0x4e7e0000
vexpps:
  vmovups      zmm1, zmmword ptr [rdi]
  vbroadcastss zmm0, dword ptr [rip + .CONST0]
  vfmadd213ps  zmm0, zmm1, dword ptr [rip + .CONST1]{1to16}
  vcvttps2udq  zmm0, zmm0
  vmovups      zmmword ptr [rsi], zmm0
  vzeroupper
  ret
 input        approx       actual      abs error    rel error
−7.000000    +0.000928    +0.000912    +0.000016    +0.017994
−6.000000    +0.002625    +0.002479    +0.000146    +0.058864
−5.000000    +0.006979    +0.006738    +0.000241    +0.035716
−4.000000    +0.019207    +0.018316    +0.000891    +0.048641
−3.000000    +0.052247    +0.049787    +0.002460    +0.049415
−2.000000    +0.139326    +0.135335    +0.003991    +0.029488
−1.000000    +0.389326    +0.367879    +0.021447    +0.058298
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.885376    +2.718282    +0.167094    +0.061471
+2.000000    +7.541565    +7.389056    +0.152509    +0.020640
+3.000000   +21.249268   +20.085537    +1.163731    +0.057939
+4.000000   +56.665039   +54.598148    +2.066891    +0.037856
+5.000000  +155.324219  +148.413162    +6.911057    +0.046566
+6.000000  +423.980469  +403.428802   +20.551666    +0.050942
+7.000000 +1125.234375 +1096.633179   +28.601196    +0.026081
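A scalar variant that avoids `to_int_unchecked` might use Rust's `as u32` cast instead, which truncates toward zero and saturates (NaN and negative values become 0), so out-of-range inputs are not undefined behaviour; whether that costs extra instructions I have not checked. This is a sketch: the names `exp_schraudolph` and `fields` are hypothetical, and `fields` merely splits the result into its biased exponent and mantissa fields to make the bit layout visible.

```rust
/// Scalar sketch of the third approximation with a saturating conversion.
/// The result is a normal float only for roughly -87.3 <= x < 88.7.
pub fn exp_schraudolph(x: f32) -> f32 {
    const M: u32 = f32::MANTISSA_DIGITS - 1; // 23 stored mantissa bits
    const A: f32 = (1u32 << M) as f32 / std::f32::consts::LN_2; // 2²³ ÷ log 2
    const B: f32 = (127u32 << M) as f32; // exponent bias moved into place
    let aff = x.mul_add(A, B);
    // `as u32` truncates toward zero and saturates: NaN and negative values
    // become 0, values of 2³² and above become u32::MAX.
    f32::from_bits(aff as u32)
}

/// Splits a float into its biased exponent field and 23-bit mantissa field.
pub fn fields(y: f32) -> (u32, u32) {
    let bits = y.to_bits();
    ((bits >> 23) & 0xff, bits & 0x7f_ffff)
}
```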
››› to-do ‹‹‹
write an intuitive explanation for why it is a piecewise linear function

It is effectively a piecewise linear function, as can be seen here, and therefore its derivative is a step function,

$$\frac{1}{\log 2} \times \exp\left(\lfloor x / \log 2 \rfloor \times \log 2\right),$$

as can be seen here.
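Both the shape and the derivative can be read off the bit layout. The following is a sketch of my own derivation; it ignores the truncation of `aff` to an integer, which changes the result by less than one unit in the last place, and assumes the result is a normal float, roughly $-87.3 \le x < 88.7$. The integer written into the bits is about $2^{23} t$, so its top bits hold the biased exponent $\lfloor t \rfloor$ and its low 23 bits hold the fraction $u$:

```latex
\begin{aligned}
t &= x/\log 2 + 127, \qquad u = t - \lfloor t \rfloor = x/\log 2 - \lfloor x/\log 2 \rfloor \in [0, 1), \\
f(x) &\approx 2^{\lfloor t \rfloor - 127}\,(1 + u) = 2^{\lfloor x/\log 2 \rfloor}\,(1 + u), \\
f'(x) &= \frac{2^{\lfloor x/\log 2 \rfloor}}{\log 2} = \frac{1}{\log 2} \times \exp\left(\lfloor x/\log 2 \rfloor \times \log 2\right), \\
\frac{f(x)}{\exp(x)} &= \frac{2^{\lfloor x/\log 2 \rfloor}\,(1 + u)}{2^{\lfloor x/\log 2 \rfloor + u}} = \frac{1 + u}{2^{u}}.
\end{aligned}
```

Since $2^u$ is convex and equals $1 + u$ at $u = 0$ and $u = 1$, the last ratio is at least 1; it peaks at $u = 1/\log 2 - 1 \approx 0.443$ with value $2/(e \log 2) \approx 1.061$. So the approximation overestimates $\exp$ by at most about 6.1%, in line with the relative-error column of its table.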

Because this approximation is a piecewise linear function, the error is not monotonic on either side of zero, and there are regions where even the quadratic approximation has smaller error. The first two approximations are usually better than this approximation near zero, but this behaves well over the entire domain.

Approximation the fourth

We can improve the accuracy of this approximation by again making use of the fact that $\exp(x) = 1/\exp(-x)$. If $f$ is the piecewise linear function, then

$$g(x) = \left(f(x) + 1/f(-x)\right)/2.$$

This can be optimized to the following. My thanks to Jonathan Hallström, who made this observation and implemented it. (The use of intrinsics is unfortunately necessary to coax the compiler to emit vrcp14ps.)

pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  unsafe {
    use std::arch::x86_64::{                 // I think you mean “std::arch::amd64”
      __m512              as f32x16,
      _mm512_set1_ps      as broadcast,
      _mm512_fmadd_ps     as mul_add,
      _mm512_fnmadd_ps    as neg_mul_add,
      _mm512_cvtps_epi32  as convert,
      _mm512_castsi512_ps as transmute,
      _mm512_rcp14_ps     as recip,
      _mm512_add_ps       as add
    };
    let A  = broadcast(0x00800000 as f32 / std::f32::consts::LN_2); // 2²³ ÷ log 2
    let B0 = broadcast(0x3f000000 as f32);                      // 2²³ × (127 − 1)
    let B1 = broadcast(0x40000000 as f32);                      // 2²³ × (127 + 1)

    // x is only guaranteed 4-byte alignment, so read it unaligned.
    let xs = std::ptr::read_unaligned(x.as_ptr() as *const f32x16);

    let aff0 = mul_add(xs, A, B0);
    let rnd0 = transmute(convert(aff0));

    let aff1 = neg_mul_add(xs, A, B1);
    let rnd1 = transmute(convert(aff1));
    let inv1 = recip(rnd1);

    let avg = add(rnd0, inv1);
    *y = *(&raw const avg as *const [f32; 16]);
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x4b38aa3b
.CONST1: .long 0x4e7c0000
.CONST2: .long 0x4e800000
vexpps:
  vmovups      zmm0, zmmword ptr [rdi]
  vbroadcastss zmm1, dword ptr [rip + .CONST0]
  vbroadcastss zmm2, dword ptr [rip + .CONST1]
  vfmadd231ps  zmm2, zmm0, zmm1
  vfnmadd213ps zmm1, zmm0, dword ptr [rip + .CONST2]{1to16}
  vcvtps2dq    zmm2, zmm2
  vcvtps2dq    zmm0, zmm1
  vrcp14ps     zmm0, zmm0
  vaddps       zmm0, zmm0, zmm2
  vmovups      zmmword ptr [rsi], zmm0
  vzeroupper
  ret
 input        approx       actual      abs error    rel error
−7.000000    +0.000909    +0.000912    −0.000003    −0.003707
−6.000000    +0.002492    +0.002479    +0.000013    +0.005186
−5.000000    +0.006708    +0.006738    −0.000030    −0.004385
−4.000000    +0.018427    +0.018316    +0.000111    +0.006081
−3.000000    +0.049653    +0.049787    −0.000134    −0.002685
−2.000000    +0.135962    +0.135335    +0.000627    +0.004634
−1.000000    +0.367951    +0.367879    +0.000072    +0.000196
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.726982    +2.718282    +0.008700    +0.003201
+2.000000    +7.359528    +7.389056    −0.029529    −0.003996
+3.000000   +20.194458   +20.085537    +0.108921    +0.005423
+4.000000   +54.365723   +54.598148    −0.232426    −0.004257
+5.000000  +149.310547  +148.413162    +0.897385    +0.006047
+6.000000  +402.488281  +403.428802    −0.940521    −0.002331
+7.000000 +1101.234375 +1096.633179    +4.601196    +0.004196
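The constants in the vectorized code fold the division by two into the exponent field: $B_0 = 2^{23} \times 126$ lowers the exponent by one, so `rnd0` is $f(x)/2$, and $B_1 = 2^{23} \times 128$ applied to $-x$ raises it by one, so `rnd1` is $2f(-x)$; hence `rnd0 + 1/rnd1` is $(f(x) + 1/f(-x))/2$. For machines without AVX-512, a portable scalar sketch might look like this. The name `exp_symmetric` is hypothetical, and exact division replaces `vrcp14ps`, so its results can differ slightly from the table.

```rust
/// Scalar sketch of the fourth approximation, g(x) = (f(x) + 1/f(-x)) / 2.
/// Meaningful for roughly |x| <= 86, where both halves are normal floats.
pub fn exp_symmetric(x: f32) -> f32 {
    const A: f32 = (1u32 << 23) as f32 / std::f32::consts::LN_2; // 2²³ ÷ log 2
    const B0: f32 = (126u32 << 23) as f32; // 2²³ × (127 - 1): gives f(x) / 2
    const B1: f32 = (128u32 << 23) as f32; // 2²³ × (127 + 1): gives 2 f(-x)

    // Round to nearest even, matching vcvtps2dq under the default rounding
    // mode, then reinterpret the integer as the bits of an f32.
    let to_f32 = |aff: f32| f32::from_bits(aff.round_ties_even() as u32);

    let half = to_f32(x.mul_add(A, B0)); // f(x) / 2
    let twice = to_f32((-x).mul_add(A, B1)); // 2 f(-x)
    // Exact division in place of vrcp14ps: f(x)/2 + 1/(2 f(-x)).
    half + 1.0 / twice
}
```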
››› to-do ‹‹‹
write a proof of C ¹ smoothness

Adjusting the constants can reduce the error a skosh further, but zero then no longer maps to one.

The Proposition

Won’t you propagate our gradients with me,
my love?

I was interested in finding an approximation of the exponential function so that I could perform interpolated arg max (“softmax”) without unreasonable overhead. I had in mind the following architecture.

The typical NNUE begins with two vectors, $i_w$ and $i_b$, of ones and zeroes which encode the positions of white’s and black’s pieces. There are then two sets of weights, the first-layer matrixes $W_1^s$ and $W_1^p$, that lead to the vectors $\sigma_1^w = W_1^s i_w + W_1^p i_b$ and $\sigma_1^b = W_1^p i_w + W_1^s i_b$. The $s$ and $p$ stand for same side and opposite side.

These are relabelled as $\sigma_1^m = \sigma_1^w$ and $\sigma_1^r = \sigma_1^b$ if white is the side to move and black is the side in repose, or vice versa, as $\sigma_1^m = \sigma_1^b$ and $\sigma_1^r = \sigma_1^w$, if black is the side to move and white is the side in repose.

Then an activation function $f_1$ is applied elementwise to yield the vectors $a_1^m = f_1(\sigma_1^m)$ and $a_1^r = f_1(\sigma_1^r)$. These are then reduced by pairwise multiplication: suppose the number of elements of $a_1^m$ or $a_1^r$ is $n_1$; then $x_1^m = a_1^m[1, \ldots, n_1/2] \odot a_1^m[n_1/2 + 1, \ldots, n_1]$ and $x_1^r = a_1^r[1, \ldots, n_1/2] \odot a_1^r[n_1/2 + 1, \ldots, n_1]$. We then concatenate these to form $x_1 = x_1^m \mathbin{\Vert} x_1^r$.
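The first layer described so far can be sketched densely as follows. This is an illustration only: the helper names (`mat_vec`, `add`, `pairwise`, `first_layer`) are hypothetical, matrices are assumed to be row-major `Vec<Vec<f32>>`, and $f_1$ is passed in as a closure because the text leaves it unspecified. A practical NNUE would instead update the accumulators incrementally from the sparse inputs.

```rust
/// Dense matrix-vector product; `w` is row-major.
pub fn mat_vec(w: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Elementwise sum of two vectors of equal length.
fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Pairwise multiplication: first half times second half, elementwise.
pub fn pairwise(a: &[f32]) -> Vec<f32> {
    let (lo, hi) = a.split_at(a.len() / 2);
    lo.iter().zip(hi).map(|(x, y)| x * y).collect()
}

/// Computes x₁ = x₁ᵐ ∥ x₁ʳ from the piece encodings i_w and i_b.
pub fn first_layer(
    i_w: &[f32],
    i_b: &[f32],
    w_s: &[Vec<f32>], // same-side weights W₁ˢ
    w_p: &[Vec<f32>], // opposite-side weights W₁ᵖ
    white_to_move: bool,
    f1: impl Fn(f32) -> f32,
) -> Vec<f32> {
    // σ₁ʷ = W₁ˢ i_w + W₁ᵖ i_b and σ₁ᵇ = W₁ᵖ i_w + W₁ˢ i_b.
    let sigma_w = add(&mat_vec(w_s, i_w), &mat_vec(w_p, i_b));
    let sigma_b = add(&mat_vec(w_p, i_w), &mat_vec(w_s, i_b));
    // Relabel by perspective: m is the side to move, r the side in repose.
    let (sigma_m, sigma_r) = if white_to_move { (sigma_w, sigma_b) } else { (sigma_b, sigma_w) };
    // Apply f₁ elementwise.
    let a_m: Vec<f32> = sigma_m.iter().map(|&s| f1(s)).collect();
    let a_r: Vec<f32> = sigma_r.iter().map(|&s| f1(s)).collect();
    // Reduce each side by pairwise multiplication, then concatenate.
    let mut x1 = pairwise(&a_m);
    x1.extend(pairwise(&a_r));
    x1
}
```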

Now the idea: there are two sets of weights, the second-layer matrixes $W_2^v$ and $W_2^k$, that lead to the vectors $\sigma_2^v = W_2^v x_1$ and $\sigma_2^k = W_2^k x_1$ of equal dimension.

Then exponentiation is performed elementwise on $\sigma_2^k$ to yield $a_2^k = \exp(\sigma_2^k)$ and this is then $\ell_1$-normalized so that we obtain $k_2 = a_2^k / \lVert a_2^k \rVert_1$ or $k_2 = s\,a_2^k$ where

$$s = \left(\sum_i a_2^k[i]\right)^{-1}$$

since each element of $a_2^k$ is positive. (This is the interpolated arg max.)

We then perform pairwise multiplication and obtain $x_2 = \sigma_2^v \odot k_2$.

Finally the elements of $x_2$ are summed (note that this can indeed simply be a sum rather than a weighted sum or dot product) and the output of the network is then

$$\sigma_3 = \sum_i x_2[i].$$
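The soft output heads can be sketched in the same dense style. Again this is illustrative: the name `soft_heads` is hypothetical, the weights are assumed row-major with the same number of rows in $W_2^v$ and $W_2^k$, and the exponential is a parameter so that `f32::exp` or one of the approximations from earlier can be plugged in. The keys $\sigma_2^k$ are assumed small enough that the exponential does not overflow.

```rust
/// Soft output heads: σ₂ᵛ = W₂ᵛ x₁, σ₂ᵏ = W₂ᵏ x₁, k₂ = exp(σ₂ᵏ) / ‖exp(σ₂ᵏ)‖₁,
/// x₂ = σ₂ᵛ ⊙ k₂, and the output σ₃ = Σᵢ x₂[i].
pub fn soft_heads(
    x1: &[f32],
    w2v: &[Vec<f32>], // W₂ᵛ, row-major
    w2k: &[Vec<f32>], // W₂ᵏ, row-major
    exp: impl Fn(f32) -> f32,
) -> f32 {
    // Dot product of one weight row with x₁.
    let dot = |row: &Vec<f32>| row.iter().zip(x1).map(|(w, x)| w * x).sum::<f32>();
    let sigma_v: Vec<f32> = w2v.iter().map(|row| dot(row)).collect();
    // a₂ᵏ = exp(σ₂ᵏ), elementwise.
    let a_k: Vec<f32> = w2k.iter().map(|row| exp(dot(row))).collect();
    // Every element of a₂ᵏ is positive, so its ℓ₁ norm is just its sum.
    let s = 1.0 / a_k.iter().sum::<f32>();
    // σ₃ = Σᵢ σ₂ᵛ[i] · (s · a₂ᵏ[i]).
    sigma_v.iter().zip(&a_k).map(|(v, a)| v * (s * a)).sum()
}
```

With all keys equal the weights are uniform and $\sigma_3$ is simply the mean of $\sigma_2^v$.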

Here the interpolated arg max and pairwise multiplication function as “soft output heads”. This was Cosmo’s observation, although she used the common term “output buckets”, which I simply refuse to adopt. But one might instead insert additional layers after $x_1 \to x_2$ before the output.

››› to-do ‹‹‹
add a diagram of the architecture described

One might also imagine applying an activation function elementwise to $\sigma_2^v$ before pairwise multiplication: $a_2^v = f_2(\sigma_2^v)$ and then $x_2 = a_2^v \odot k_2$. I would be curious to first see the result without $f_2$, however.