From 44232dde4ff878fa97e47eea2fc5346814e98c7c Mon Sep 17 00:00:00 2001
From: Zachary Levy
Date: Sun, 29 Mar 2026 18:42:04 -0700
Subject: [PATCH] Added fast_exp math

---
 levmath/levmath.odin | 177 +++++++++++++++++++++++++++++++++++++++++++
 odinfmt.json         |   3 +-
 2 files changed, 179 insertions(+), 1 deletion(-)
 create mode 100644 levmath/levmath.odin

diff --git a/levmath/levmath.odin b/levmath/levmath.odin
new file mode 100644
index 0000000..cbd6de3
--- /dev/null
+++ b/levmath/levmath.odin
@@ -0,0 +1,177 @@
+package levmath
+
+import "base:intrinsics"
+import "core:math"
+import "core:testing"
+
+// ---------------------------------------------------------------------------------------------------------------------
+// ----- Fast Exp (Schraudolph IEEE 754 bit trick) ---------------------------------------------------------------------
+// ---------------------------------------------------------------------------------------------------------------------
+
+// Approximates exp(x) by linearly mapping x into an IEEE 754 bit pattern.
+// Follows the standard formulation from Schraudolph (1999), adapted by Rade (2021):
+//
+//	y = fma(SCALE, x, BIAS)          (fused multiply-add, single rounding)
+//	if y < SHIFT: y = 0              (flush denormals to zero)
+//	else:         y = min(y, MAX_Y)  (clamp overflow to +inf)
+//	result = transmute(float)uint(y)
+//
+// SCALE = 2^mantissa_bits / ln(2) — converts from natural-log domain to bit-pattern units.
+// BIAS  = 2^mantissa_bits * (exponent_bias - correction)
+//	— the correction of 0.04367744890362246 minimizes worst-case relative error
+//	  (Schraudolph 1999, Table 1).
+//
+// Values below SHIFT (= 2^mantissa_bits, the smallest normal bit pattern) are flushed
+// to zero — this avoids producing denormals, which are slow on many CPUs and outside
+// the approximation's valid range. Values above MAX_Y (= the +inf bit pattern read as
+// a float) are clamped to MAX_Y, producing exact +inf.
+//
+// Worst-case relative error: < 2.983% (f32/f64), < 3.705% (f16).
+// Zero memory access — pure ALU.
+//
+// References:
+//	N. Schraudolph, "A Fast, Compact Approximation of the Exponential Function",
+//	Neural Computation 11, 853–862 (1999).
+//	J. Rade, FastExp.h (2021), https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
+fast_exp :: #force_inline proc "contextless" (x: $FLOAT) -> FLOAT where intrinsics.type_is_float(FLOAT) {
+	LN_2 :: 0.6931471805599453
+	CORRECTION :: 0.04367744890362246
+
+	when FLOAT == f16 {
+		MANTISSA_BITS :: 10
+		EXPONENT_BIAS :: 15
+		SHIFT :: f16(1 << MANTISSA_BITS)
+		SCALE :: SHIFT / LN_2
+		BIAS :: SHIFT * f16(EXPONENT_BIAS - CORRECTION)
+		MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
+
+		y := intrinsics.fused_mul_add(SCALE, x, BIAS)
+		y = y < SHIFT ? 0 : min(y, MAX_Y)
+		return transmute(f16)u16(y)
+	} else when FLOAT == f32 {
+		MANTISSA_BITS :: 23
+		EXPONENT_BIAS :: 127
+		SHIFT :: f32(1 << MANTISSA_BITS)
+		SCALE :: SHIFT / LN_2
+		BIAS :: SHIFT * f32(EXPONENT_BIAS - CORRECTION)
+		MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
+
+		y := intrinsics.fused_mul_add(SCALE, x, BIAS)
+		y = y < SHIFT ? 0 : min(y, MAX_Y)
+		return transmute(f32)u32(y)
+	} else when FLOAT == f64 {
+		MANTISSA_BITS :: 52
+		EXPONENT_BIAS :: 1023
+		SHIFT :: f64(1 << MANTISSA_BITS)
+		SCALE :: SHIFT / LN_2
+		BIAS :: SHIFT * f64(EXPONENT_BIAS - CORRECTION)
+		MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
+
+		y := intrinsics.fused_mul_add(SCALE, x, BIAS)
+		y = y < SHIFT ? 0 : min(y, MAX_Y)
+		return transmute(f64)u64(y)
+	} else {
+		#panic("fast_exp only supports f16, f32, and f64")
+	}
+}
+
+// ---------------------------------------------------------------------------------------------------------------------
+// ----- Testing -------------------------------------------------------------------------------------------------------
+// ---------------------------------------------------------------------------------------------------------------------
+
+@(test)
+test_fast_exp_identity :: proc(t: ^testing.T) {
+	// exp(0) ≈ 1.0 — validates bias correction for every width.
+	r16 := fast_exp(f16(0))
+	r32 := fast_exp(f32(0))
+	r64 := fast_exp(f64(0))
+	testing.expectf(t, abs(f32(r16) - 1) < 0.05, "f16: exp(0) = %v, want ≈ 1", r16)
+	testing.expectf(t, abs(r32 - 1) < 0.04, "f32: exp(0) = %v, want ≈ 1", r32)
+	testing.expectf(t, abs(r64 - 1) < 0.04, "f64: exp(0) = %v, want ≈ 1", r64)
+}
+
+@(test)
+test_fast_exp_accuracy :: proc(t: ^testing.T) {
+	// Sweep the full representable domain and find the actual worst-case relative error.
+	MAX_REL :: 0.02983 // measured worst-case f32/f64: 2.983%
+	MAX_REL_16 :: 0.03705 // measured worst-case f16: 3.705% (coarser constants in 10-bit mantissa)
+
+	// ---- f16 sweep ([-9, 11]: non-denormal output range for f16) ----
+	// x is accumulated in f64, not f16: f16(i) saturates past 65504 (f16 max finite)
+	// and loses integer precision past 2048, which would corrupt/duplicate sweep points.
+	worst_16: f64
+	worst_x_16: f16
+	STEPS_16 :: 100_000
+	LO_16 :: f64(-9.0)
+	HI_16 :: f64(11.0)
+	STEP_16 :: (HI_16 - LO_16) / STEPS_16
+
+	for i in 0 ..< STEPS_16 {
+		x := f16(LO_16 + f64(i) * STEP_16)
+		approx := fast_exp(x)
+		exact := math.exp(f32(x))
+		if exact < 1e-10 do continue
+		rel := f64(abs(f32(approx) - exact)) / f64(exact)
+		if rel > worst_16 {
+			worst_16 = rel
+			worst_x_16 = x
+		}
+	}
+	testing.expectf(t, worst_16 <= MAX_REL_16,
+		"f16: worst relative error = %.8f at x = %v", worst_16, worst_x_16)
+
+	// ---- f32 sweep ----
+	worst_32: f64
+	worst_x_32: f32
+	STEPS_32 :: 10_000_000
+	LO_32 :: f32(-87.0)
+	HI_32 :: f32(88.0)
+	STEP_32 :: (HI_32 - LO_32) / STEPS_32
+
+	for i in 0 ..< STEPS_32 {
+		x := LO_32 + f32(i) * STEP_32
+		approx := fast_exp(x)
+		exact := math.exp(x)
+		if exact < 1e-30 do continue
+		rel := f64(abs(approx - exact)) / f64(exact)
+		if rel > worst_32 {
+			worst_32 = rel
+			worst_x_32 = x
+		}
+	}
+	testing.expectf(t, worst_32 <= MAX_REL,
+		"f32: worst relative error = %.8f at x = %v", worst_32, worst_x_32)
+
+	// ---- f64 sweep ----
+	worst_64: f64
+	worst_x_64: f64
+	STEPS_64 :: 10_000_000
+	LO_64 :: f64(-700.0)
+	HI_64 :: f64(709.0)
+	STEP_64 :: (HI_64 - LO_64) / STEPS_64
+
+	for i in 0 ..< STEPS_64 {
+		x := LO_64 + f64(i) * STEP_64
+		approx := fast_exp(x)
+		exact := math.exp(x)
+		if exact < 1e-300 do continue
+		rel := abs((approx - exact) / exact)
+		if rel > worst_64 {
+			worst_64 = rel
+			worst_x_64 = x
+		}
+	}
+	testing.expectf(t, worst_64 <= MAX_REL,
+		"f64: worst relative error = %.8f at x = %v", worst_64, worst_x_64)
+}
+
+@(test)
+test_fast_exp_saturation :: proc(t: ^testing.T) {
+	// Underflow: very negative input → exact 0
+	testing.expectf(t, fast_exp(f32(-1000)) == 0, "f32 underflow: got %v, want 0", fast_exp(f32(-1000)))
+	testing.expectf(t, fast_exp(f64(-1e5)) == 0, "f64 underflow: got %v, want 0", fast_exp(f64(-1e5)))
+
+	// Overflow: very positive input → exact +inf
+	ov32 := fast_exp(f32(1000))
+	ov64 := fast_exp(f64(1e5))
+	testing.expectf(t, math.is_inf_f32(ov32), "f32 overflow: got %v, want +inf", ov32)
+	testing.expectf(t, math.is_inf_f64(ov64), "f64 overflow: got %v, want +inf", ov64)
+}
\ No newline at end of file
diff --git a/odinfmt.json b/odinfmt.json
index 601712b..f225fc4 100644
--- a/odinfmt.json
+++ b/odinfmt.json
@@ -1,5 +1,6 @@
 {
   "$schema": "https://raw.githubusercontent.com/DanielGavin/ols/master/misc/odinfmt.schema.json",
   "character_width": 110,
-  "tabs_width": 1
+  "tabs_width": 1,
+  "inline_single_stmt_case": true
 }
-- 
2.43.0