177 lines
6.2 KiB
Odin
177 lines
6.2 KiB
Odin
package levmath
|
||
|
||
import "base:intrinsics"
|
||
import "core:math"
|
||
import "core:testing"
|
||
|
||
// ---------------------------------------------------------------------------------------------------------------------
|
||
// ----- Fast Exp (Schraudolph IEEE 754 bit trick) ---------------------------------------------------------------------
|
||
// ---------------------------------------------------------------------------------------------------------------------
|
||
|
||
// Approximates exp(x) by linearly mapping x into an IEEE 754 bit pattern.
|
||
// Follows the standard formulation from Schraudolph (1999), adapted by Rade (2021):
|
||
//
|
||
// y = fma(SCALE, x, BIAS) (fused multiply-add, single rounding)
|
||
// if y < SHIFT: y = 0 (flush denormals to zero)
|
||
// else: y = min(y, MAX_Y) (clamp overflow to +inf)
|
||
// result = transmute(float)uint(y)
|
||
//
|
||
// SCALE = 2^mantissa_bits / ln(2) — converts from natural-log domain to bit-pattern units.
|
||
// BIAS = 2^mantissa_bits * (exponent_bias - correction)
|
||
// — the correction of 0.04367744890362246 minimizes worst-case relative error
|
||
// (Schraudolph 1999, Table 1).
|
||
//
|
||
// Values below SHIFT (= 2^mantissa_bits, the smallest normal bit pattern) are flushed
|
||
// to zero — this avoids producing denormals, which are slow on many CPUs and outside
|
||
// the approximation's valid range. Values above MAX_Y (= the +inf bit pattern read as
|
||
// a float) are clamped to MAX_Y, producing exact +inf.
|
||
//
|
||
// Worst-case relative error: < 2.983% (f32/f64), < 3.705% (f16).
|
||
// Zero memory access — pure ALU.
|
||
//
|
||
// References:
|
||
// N. Schraudolph, "A Fast, Compact Approximation of the Exponential Function",
|
||
// Neural Computation 11, 853–862 (1999).
|
||
// J. Rade, FastExp.h (2021), https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
|
||
fast_exp :: #force_inline proc "contextless" (x: $FLOAT) -> FLOAT where intrinsics.type_is_float(FLOAT) {
|
||
LN_2 :: 0.6931471805599453
|
||
CORRECTION :: 0.04367744890362246
|
||
|
||
when FLOAT == f16 {
|
||
MANTISSA_BITS :: 10
|
||
EXPONENT_BIAS :: 15
|
||
SHIFT :: f16(1 << MANTISSA_BITS)
|
||
SCALE :: SHIFT / LN_2
|
||
BIAS :: SHIFT * f16(EXPONENT_BIAS - CORRECTION)
|
||
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
|
||
|
||
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
|
||
y = y < SHIFT ? 0 : min(y, MAX_Y)
|
||
return transmute(f16)u16(y)
|
||
} else when FLOAT == f32 {
|
||
MANTISSA_BITS :: 23
|
||
EXPONENT_BIAS :: 127
|
||
SHIFT :: f32(1 << MANTISSA_BITS)
|
||
SCALE :: SHIFT / LN_2
|
||
BIAS :: SHIFT * f32(EXPONENT_BIAS - CORRECTION)
|
||
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
|
||
|
||
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
|
||
y = y < SHIFT ? 0 : min(y, MAX_Y)
|
||
return transmute(f32)u32(y)
|
||
} else when FLOAT == f64 {
|
||
MANTISSA_BITS :: 52
|
||
EXPONENT_BIAS :: 1023
|
||
SHIFT :: f64(1 << MANTISSA_BITS)
|
||
SCALE :: SHIFT / LN_2
|
||
BIAS :: SHIFT * f64(EXPONENT_BIAS - CORRECTION)
|
||
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
|
||
|
||
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
|
||
y = y < SHIFT ? 0 : min(y, MAX_Y)
|
||
return transmute(f64)u64(y)
|
||
} else {
|
||
#panic("fast_exp only supports f16, f32, and f64")
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------------------------------------------------
|
||
// ----- Testing -------------------------------------------------------------------------------------------------------
|
||
// ---------------------------------------------------------------------------------------------------------------------
|
||
|
||
@(test)
|
||
test_fast_exp_identity :: proc(t: ^testing.T) {
|
||
// exp(0) ≈ 1.0 — validates bias correction for every width.
|
||
r16 := fast_exp(f16(0))
|
||
r32 := fast_exp(f32(0))
|
||
r64 := fast_exp(f64(0))
|
||
testing.expectf(t, abs(f32(r16) - 1) < 0.05, "f16: exp(0) = %v, want ≈ 1", r16)
|
||
testing.expectf(t, abs(r32 - 1) < 0.04, "f32: exp(0) = %v, want ≈ 1", r32)
|
||
testing.expectf(t, abs(r64 - 1) < 0.04, "f64: exp(0) = %v, want ≈ 1", r64)
|
||
}
|
||
|
||
@(test)
|
||
test_fast_exp_accuracy :: proc(t: ^testing.T) {
|
||
// Sweep the full representable domain and find the actual worst-case relative error.
|
||
MAX_REL :: 0.02983 // measured worst-case f32/f64: 2.983%
|
||
MAX_REL_16 :: 0.03705 // measured worst-case f16: 3.705% (coarser constants in 10-bit mantissa)
|
||
|
||
// ---- f16 sweep ([-9, 11]: non-denormal output range for f16) ----
|
||
worst_16: f64
|
||
worst_x_16: f16
|
||
STEPS_16 :: 100_000
|
||
LO_16 :: f16(-9.0)
|
||
HI_16 :: f16(11.0)
|
||
STEP_16 :: (HI_16 - LO_16) / STEPS_16
|
||
|
||
for i in 0 ..< STEPS_16 {
|
||
x := LO_16 + f16(i) * STEP_16
|
||
approx := fast_exp(x)
|
||
exact := math.exp(f32(x))
|
||
if exact < 1e-10 do continue
|
||
rel := f64(abs(f32(approx) - exact)) / f64(exact)
|
||
if rel > worst_16 {
|
||
worst_16 = rel
|
||
worst_x_16 = x
|
||
}
|
||
}
|
||
testing.expectf(t, worst_16 <= MAX_REL_16,
|
||
"f16: worst relative error = %.8f at x = %v", worst_16, worst_x_16)
|
||
|
||
// ---- f32 sweep ----
|
||
worst_32: f64
|
||
worst_x_32: f32
|
||
STEPS_32 :: 10_000_000
|
||
LO_32 :: f32(-87.0)
|
||
HI_32 :: f32(88.0)
|
||
STEP_32 :: (HI_32 - LO_32) / STEPS_32
|
||
|
||
for i in 0 ..< STEPS_32 {
|
||
x := LO_32 + f32(i) * STEP_32
|
||
approx := fast_exp(x)
|
||
exact := math.exp(x)
|
||
if exact < 1e-30 do continue
|
||
rel := f64(abs(approx - exact)) / f64(exact)
|
||
if rel > worst_32 {
|
||
worst_32 = rel
|
||
worst_x_32 = x
|
||
}
|
||
}
|
||
testing.expectf(t, worst_32 <= MAX_REL,
|
||
"f32: worst relative error = %.8f at x = %v", worst_32, worst_x_32)
|
||
|
||
// ---- f64 sweep ----
|
||
worst_64: f64
|
||
worst_x_64: f64
|
||
STEPS_64 :: 10_000_000
|
||
LO_64 :: f64(-700.0)
|
||
HI_64 :: f64(709.0)
|
||
STEP_64 :: (HI_64 - LO_64) / STEPS_64
|
||
|
||
for i in 0 ..< STEPS_64 {
|
||
x := LO_64 + f64(i) * STEP_64
|
||
approx := fast_exp(x)
|
||
exact := math.exp(x)
|
||
if exact < 1e-300 do continue
|
||
rel := abs((approx - exact) / exact)
|
||
if rel > worst_64 {
|
||
worst_64 = rel
|
||
worst_x_64 = x
|
||
}
|
||
}
|
||
testing.expectf(t, worst_64 <= MAX_REL,
|
||
"f64: worst relative error = %.8f at x = %v", worst_64, worst_x_64)
|
||
}
|
||
|
||
@(test)
|
||
test_fast_exp_saturation :: proc(t: ^testing.T) {
|
||
// Underflow: very negative input → exact 0
|
||
testing.expectf(t, fast_exp(f32(-1000)) == 0, "f32 underflow: got %v, want 0", fast_exp(f32(-1000)))
|
||
testing.expectf(t, fast_exp(f64(-1e5)) == 0, "f64 underflow: got %v, want 0", fast_exp(f64(-1e5)))
|
||
|
||
// Overflow: very positive input → exact +inf
|
||
ov32 := fast_exp(f32(1000))
|
||
ov64 := fast_exp(f64(1e5))
|
||
testing.expectf(t, math.is_inf_f32(ov32), "f32 overflow: got %v, want +inf", ov32)
|
||
testing.expectf(t, math.is_inf_f64(ov64), "f64 overflow: got %v, want +inf", ov64)
|
||
} |