Files
levlib/levmath/levmath.odin
Zachary Levy 59c600d630 phased-executor (#4)
Co-authored-by: Zachary Levy <zachary@sunforge.is>
Reviewed-on: #4
2026-04-03 01:53:23 +00:00

177 lines
6.2 KiB
Odin
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package levmath
import "base:intrinsics"
import "core:math"
// ---------------------------------------------------------------------------------------------------------------------
// ----- Fast Exp (Schraudolph IEEE 754 bit trick) ---------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------------------

// Approximates exp(x) by mapping x linearly onto an IEEE 754 bit pattern.
// This is the standard formulation from Schraudolph (1999), as adapted by Rade (2021):
//
//     y = fma(SCALE, x, BIAS)       (fused multiply-add, single rounding)
//     if y < SHIFT:  y = 0          (flush denormals to zero)
//     else:          y = min(y, MAX_Y)   (clamp overflow to +inf)
//     result = transmute(float)uint(y)
//
// Constants (named differently in the code below, same meaning):
//     SHIFT = 2^mantissa_bits                  — the smallest normal bit pattern, read as a float.
//     SCALE = 2^mantissa_bits / ln(2)          — converts from the natural-log domain to bit-pattern units.
//     BIAS  = 2^mantissa_bits * (exponent_bias - correction)
//             — the correction of 0.04367744890362246 minimizes worst-case relative error
//               (Schraudolph 1999, Table 1).
//     MAX_Y = 2^mantissa_bits * (2 * exponent_bias + 1)
//             — the +inf bit pattern, read as a float.
//
// Inputs whose y falls below SHIFT return zero. This avoids producing denormals, which are
// slow on many CPUs and lie outside the approximation's valid range. Inputs whose y exceeds
// MAX_Y are clamped to MAX_Y, which transmutes to exact +inf.
//
// Worst-case relative error: < 2.983% (f32/f64), < 3.705% (f16).
// NOTE(review): the f16 figure is the value recorded by the accuracy test's sweep. Hand-evaluating
// x = 9.7421875 (y rounds down to the bit pattern of 16384) gives roughly 3.74%, so treat the f16
// bound as approximate until it is re-measured over every f16 input.
//
// Zero memory access — pure ALU.
//
// Only f16, f32 and f64 are implemented; any other float type hits the #panic below.
//
// References:
//     N. Schraudolph, "A Fast, Compact Approximation of the Exponential Function",
//         Neural Computation 11, 853–862 (1999).
//     J. Rade, FastExp.h (2021), https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
fast_exp :: #force_inline proc "contextless" (x: $FLOAT) -> FLOAT where intrinsics.type_is_float(FLOAT) {
	ERROR_CORRECTION :: 0.04367744890362246

	when FLOAT == f16 {
		FRACTION_BITS :: 10
		BIAS_EXPONENT :: 15
		MIN_NORMAL    :: f16(1 << FRACTION_BITS)
		SLOPE         :: MIN_NORMAL / math.LN2
		OFFSET        :: MIN_NORMAL * f16(BIAS_EXPONENT - ERROR_CORRECTION)
		INF_PATTERN   :: MIN_NORMAL * (2 * BIAS_EXPONENT + 1)

		pattern := intrinsics.fused_mul_add(SLOPE, x, OFFSET)
		if pattern < MIN_NORMAL {
			// Would be a denormal (or negative) bit pattern: flush to zero.
			return 0
		}
		// Clamp so that overflow lands exactly on the +inf bit pattern.
		return transmute(f16)u16(min(pattern, INF_PATTERN))
	} else when FLOAT == f32 {
		FRACTION_BITS :: 23
		BIAS_EXPONENT :: 127
		MIN_NORMAL    :: f32(1 << FRACTION_BITS)
		SLOPE         :: MIN_NORMAL / math.LN2
		OFFSET        :: MIN_NORMAL * f32(BIAS_EXPONENT - ERROR_CORRECTION)
		INF_PATTERN   :: MIN_NORMAL * (2 * BIAS_EXPONENT + 1)

		pattern := intrinsics.fused_mul_add(SLOPE, x, OFFSET)
		if pattern < MIN_NORMAL {
			// Would be a denormal (or negative) bit pattern: flush to zero.
			return 0
		}
		// Clamp so that overflow lands exactly on the +inf bit pattern.
		return transmute(f32)u32(min(pattern, INF_PATTERN))
	} else when FLOAT == f64 {
		FRACTION_BITS :: 52
		BIAS_EXPONENT :: 1023
		MIN_NORMAL    :: f64(1 << FRACTION_BITS)
		SLOPE         :: MIN_NORMAL / math.LN2
		OFFSET        :: MIN_NORMAL * f64(BIAS_EXPONENT - ERROR_CORRECTION)
		INF_PATTERN   :: MIN_NORMAL * (2 * BIAS_EXPONENT + 1)

		pattern := intrinsics.fused_mul_add(SLOPE, x, OFFSET)
		if pattern < MIN_NORMAL {
			// Would be a denormal (or negative) bit pattern: flush to zero.
			return 0
		}
		// Clamp so that overflow lands exactly on the +inf bit pattern.
		return transmute(f64)u64(min(pattern, INF_PATTERN))
	} else {
		#panic("fast_exp only supports f16, f32, and f64")
	}
}
// ---------------------------------------------------------------------------------------------------------------------
// ----- Testing -------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------------------
import "core:testing"
@(test)
test_fast_exp_identity :: proc(t: ^testing.T) {
	// fast_exp(0) must land close to 1.0 at every supported width. With x = 0 the
	// fma collapses to the bias constant alone, so this isolates the bias correction.
	half   := fast_exp(f16(0))
	single := fast_exp(f32(0))
	double := fast_exp(f64(0))

	// Distance from the exact answer; the f16 result is widened to f32 before comparing.
	half_err   := abs(f32(half) - 1)
	single_err := abs(single - 1)
	double_err := abs(double - 1)

	testing.expectf(t, half_err < 0.05, "f16: exp(0) = %v, want ≈ 1", half)
	testing.expectf(t, single_err < 0.04, "f32: exp(0) = %v, want ≈ 1", single)
	testing.expectf(t, double_err < 0.04, "f64: exp(0) = %v, want ≈ 1", double)
}
@(test)
test_fast_exp_accuracy :: proc(t: ^testing.T) {
	// Compares fast_exp against math.exp across the input range whose results are
	// normal (non-denormal, finite) numbers, tracks the worst relative error per
	// width, and fails if it exceeds the bound for that width.
	MAX_REL :: 0.02983 // measured worst-case f32/f64: 2.983%

	// f16 bound over *every* f16 input in [-9, 11].
	// In f16 the slope constant rounds to 1477, the bias constant to 15312, and the
	// fma result is rounded to a multiple of 16 in the upper octaves, so the error is
	// larger than for f32/f64.
	// NOTE(review): hand analysis puts the worst case just under 3.95% (for example
	// x = 9.7421875 gives ≈3.74%), which is above the 3.705% the earlier sweep
	// reported. That sweep built x as LO + f16(i) * STEP; f16(i) overflows to +inf for
	// i >= 65520 and is heavily quantized below that, so the upper part of the range
	// was never really sampled. Confirm this bound by running the test.
	MAX_REL_16 :: 0.04

	// ---- f16: exhaustive over all bit patterns, filtered to [-9, 11] ----
	worst_16: f64
	worst_x_16: f16
	F16_PATTERNS :: 1 << 16
	LO_16 :: f16(-9.0)
	HI_16 :: f16(11.0)
	for bits in 0 ..< F16_PATTERNS {
		x := transmute(f16)u16(bits)
		// Written as a negated conjunction so NaN patterns are rejected as well.
		if !(x >= LO_16 && x <= HI_16) do continue
		approx := fast_exp(x)
		exact := math.exp(f32(x))
		rel := f64(abs(f32(approx) - exact)) / f64(exact)
		if rel > worst_16 {
			worst_16 = rel
			worst_x_16 = x
		}
	}
	testing.expectf(t, worst_16 <= MAX_REL_16,
		"f16: worst relative error = %.8f at x = %v", worst_16, worst_x_16)

	// ---- f32 sweep ----
	// The step index stays below 2^24, so f32(i) is exact.
	worst_32: f64
	worst_x_32: f32
	STEPS_32 :: 10_000_000
	LO_32 :: f32(-87.0)
	HI_32 :: f32(88.0)
	STEP_32 :: (HI_32 - LO_32) / STEPS_32
	for i in 0 ..< STEPS_32 {
		x := LO_32 + f32(i) * STEP_32
		approx := fast_exp(x)
		exact := math.exp(x)
		if exact < 1e-30 do continue
		rel := f64(abs(approx - exact)) / f64(exact)
		if rel > worst_32 {
			worst_32 = rel
			worst_x_32 = x
		}
	}
	testing.expectf(t, worst_32 <= MAX_REL,
		"f32: worst relative error = %.8f at x = %v", worst_32, worst_x_32)

	// ---- f64 sweep ----
	worst_64: f64
	worst_x_64: f64
	STEPS_64 :: 10_000_000
	LO_64 :: f64(-700.0)
	HI_64 :: f64(709.0)
	STEP_64 :: (HI_64 - LO_64) / STEPS_64
	for i in 0 ..< STEPS_64 {
		x := LO_64 + f64(i) * STEP_64
		approx := fast_exp(x)
		exact := math.exp(x)
		if exact < 1e-300 do continue
		rel := abs((approx - exact) / exact)
		if rel > worst_64 {
			worst_64 = rel
			worst_x_64 = x
		}
	}
	testing.expectf(t, worst_64 <= MAX_REL,
		"f64: worst relative error = %.8f at x = %v", worst_64, worst_x_64)
}
@(test)
test_fast_exp_saturation :: proc(t: ^testing.T) {
	// Checks both saturation paths of fast_exp for every supported width:
	// far-negative inputs must flush to zero and far-positive inputs must clamp
	// to positive infinity.

	// Underflow: very negative input → exact 0
	uf16 := fast_exp(f16(-100))
	uf32 := fast_exp(f32(-1000))
	uf64 := fast_exp(f64(-1e5))
	testing.expectf(t, uf16 == 0, "f16 underflow: got %v, want 0", uf16)
	testing.expectf(t, uf32 == 0, "f32 underflow: got %v, want 0", uf32)
	testing.expectf(t, uf64 == 0, "f64 underflow: got %v, want 0", uf64)

	// Overflow: very positive input → exact +inf
	// sign = 1 restricts the check to positive infinity; the default (0) would
	// also accept -inf, which the failure message does not allow for.
	ov16 := fast_exp(f16(100))
	ov32 := fast_exp(f32(1000))
	ov64 := fast_exp(f64(1e5))
	testing.expectf(t, math.is_inf_f16(ov16, 1), "f16 overflow: got %v, want +inf", ov16)
	testing.expectf(t, math.is_inf_f32(ov32, 1), "f32 overflow: got %v, want +inf", ov32)
	testing.expectf(t, math.is_inf_f64(ov64, 1), "f64 overflow: got %v, want +inf", ov64)
}