Added fast_exp math

This commit is contained in:
Zachary Levy
2026-03-29 18:42:04 -07:00
parent 2d4494233b
commit 44232dde4f
2 changed files with 179 additions and 1 deletions

177
levmath/levmath.odin Normal file
View File

@@ -0,0 +1,177 @@
package levmath
import "base:intrinsics"
import "core:math"
import "core:testing"
// ---------------------------------------------------------------------------------------------------------------------
// ----- Fast Exp (Schraudolph IEEE 754 bit trick) ---------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------------------
// Approximates exp(x) by linearly mapping x into an IEEE 754 bit pattern.
// Follows the standard formulation from Schraudolph (1999), adapted by Rade (2021):
//
// y = fma(SCALE, x, BIAS) (fused multiply-add, single rounding)
// if y < SHIFT: y = 0 (flush denormals to zero)
// else: y = min(y, MAX_Y) (clamp overflow to +inf)
// result = transmute(float)uint(y)
//
// SCALE = 2^mantissa_bits / ln(2) — converts from natural-log domain to bit-pattern units.
// BIAS = 2^mantissa_bits * (exponent_bias - correction)
// — the correction of 0.04367744890362246 minimizes worst-case relative error
// (Schraudolph 1999, Table 1).
//
// Values below SHIFT (= 2^mantissa_bits, the smallest normal bit pattern) are flushed
// to zero — this avoids producing denormals, which are slow on many CPUs and outside
// the approximation's valid range. Values above MAX_Y (= the +inf bit pattern read as
// a float) are clamped to MAX_Y, producing exact +inf.
//
// Worst-case relative error: < 2.983% (f32/f64), < 3.705% (f16).
// Zero memory access — pure ALU.
//
// References:
// N. Schraudolph, "A Fast, Compact Approximation of the Exponential Function",
// Neural Computation 11, 853862 (1999).
// J. Rade, FastExp.h (2021), https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
fast_exp :: #force_inline proc "contextless" (x: $FLOAT) -> FLOAT where intrinsics.type_is_float(FLOAT) {
LN_2 :: 0.6931471805599453
CORRECTION :: 0.04367744890362246
when FLOAT == f16 {
MANTISSA_BITS :: 10
EXPONENT_BIAS :: 15
SHIFT :: f16(1 << MANTISSA_BITS)
SCALE :: SHIFT / LN_2
BIAS :: SHIFT * f16(EXPONENT_BIAS - CORRECTION)
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
y = y < SHIFT ? 0 : min(y, MAX_Y)
return transmute(f16)u16(y)
} else when FLOAT == f32 {
MANTISSA_BITS :: 23
EXPONENT_BIAS :: 127
SHIFT :: f32(1 << MANTISSA_BITS)
SCALE :: SHIFT / LN_2
BIAS :: SHIFT * f32(EXPONENT_BIAS - CORRECTION)
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
y = y < SHIFT ? 0 : min(y, MAX_Y)
return transmute(f32)u32(y)
} else when FLOAT == f64 {
MANTISSA_BITS :: 52
EXPONENT_BIAS :: 1023
SHIFT :: f64(1 << MANTISSA_BITS)
SCALE :: SHIFT / LN_2
BIAS :: SHIFT * f64(EXPONENT_BIAS - CORRECTION)
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
y = y < SHIFT ? 0 : min(y, MAX_Y)
return transmute(f64)u64(y)
} else {
#panic("fast_exp only supports f16, f32, and f64")
}
}
// ---------------------------------------------------------------------------------------------------------------------
// ----- Testing -------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------------------
@(test)
test_fast_exp_identity :: proc(t: ^testing.T) {
// exp(0) ≈ 1.0 — validates bias correction for every width.
r16 := fast_exp(f16(0))
r32 := fast_exp(f32(0))
r64 := fast_exp(f64(0))
testing.expectf(t, abs(f32(r16) - 1) < 0.05, "f16: exp(0) = %v, want ≈ 1", r16)
testing.expectf(t, abs(r32 - 1) < 0.04, "f32: exp(0) = %v, want ≈ 1", r32)
testing.expectf(t, abs(r64 - 1) < 0.04, "f64: exp(0) = %v, want ≈ 1", r64)
}
@(test)
test_fast_exp_accuracy :: proc(t: ^testing.T) {
// Sweep the full representable domain and find the actual worst-case relative error.
MAX_REL :: 0.02983 // measured worst-case f32/f64: 2.983%
MAX_REL_16 :: 0.03705 // measured worst-case f16: 3.705% (coarser constants in 10-bit mantissa)
// ---- f16 sweep ([-9, 11]: non-denormal output range for f16) ----
worst_16: f64
worst_x_16: f16
STEPS_16 :: 100_000
LO_16 :: f16(-9.0)
HI_16 :: f16(11.0)
STEP_16 :: (HI_16 - LO_16) / STEPS_16
for i in 0 ..< STEPS_16 {
x := LO_16 + f16(i) * STEP_16
approx := fast_exp(x)
exact := math.exp(f32(x))
if exact < 1e-10 do continue
rel := f64(abs(f32(approx) - exact)) / f64(exact)
if rel > worst_16 {
worst_16 = rel
worst_x_16 = x
}
}
testing.expectf(t, worst_16 <= MAX_REL_16,
"f16: worst relative error = %.8f at x = %v", worst_16, worst_x_16)
// ---- f32 sweep ----
worst_32: f64
worst_x_32: f32
STEPS_32 :: 10_000_000
LO_32 :: f32(-87.0)
HI_32 :: f32(88.0)
STEP_32 :: (HI_32 - LO_32) / STEPS_32
for i in 0 ..< STEPS_32 {
x := LO_32 + f32(i) * STEP_32
approx := fast_exp(x)
exact := math.exp(x)
if exact < 1e-30 do continue
rel := f64(abs(approx - exact)) / f64(exact)
if rel > worst_32 {
worst_32 = rel
worst_x_32 = x
}
}
testing.expectf(t, worst_32 <= MAX_REL,
"f32: worst relative error = %.8f at x = %v", worst_32, worst_x_32)
// ---- f64 sweep ----
worst_64: f64
worst_x_64: f64
STEPS_64 :: 10_000_000
LO_64 :: f64(-700.0)
HI_64 :: f64(709.0)
STEP_64 :: (HI_64 - LO_64) / STEPS_64
for i in 0 ..< STEPS_64 {
x := LO_64 + f64(i) * STEP_64
approx := fast_exp(x)
exact := math.exp(x)
if exact < 1e-300 do continue
rel := abs((approx - exact) / exact)
if rel > worst_64 {
worst_64 = rel
worst_x_64 = x
}
}
testing.expectf(t, worst_64 <= MAX_REL,
"f64: worst relative error = %.8f at x = %v", worst_64, worst_x_64)
}
@(test)
test_fast_exp_saturation :: proc(t: ^testing.T) {
// Underflow: very negative input → exact 0
testing.expectf(t, fast_exp(f32(-1000)) == 0, "f32 underflow: got %v, want 0", fast_exp(f32(-1000)))
testing.expectf(t, fast_exp(f64(-1e5)) == 0, "f64 underflow: got %v, want 0", fast_exp(f64(-1e5)))
// Overflow: very positive input → exact +inf
ov32 := fast_exp(f32(1000))
ov64 := fast_exp(f64(1e5))
testing.expectf(t, math.is_inf_f32(ov32), "f32 overflow: got %v, want +inf", ov32)
testing.expectf(t, math.is_inf_f64(ov64), "f64 overflow: got %v, want +inf", ov64)
}