Added fast_exp math
This commit is contained in:
177
levmath/levmath.odin
Normal file
177
levmath/levmath.odin
Normal file
@@ -0,0 +1,177 @@
|
||||
package levmath
|
||||
|
||||
import "base:intrinsics"
|
||||
import "core:math"
|
||||
import "core:testing"
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
// ----- Fast Exp (Schraudolph IEEE 754 bit trick) ---------------------------------------------------------------------
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
// Approximates exp(x) by linearly mapping x into an IEEE 754 bit pattern.
|
||||
// Follows the standard formulation from Schraudolph (1999), adapted by Rade (2021):
|
||||
//
|
||||
// y = fma(SCALE, x, BIAS) (fused multiply-add, single rounding)
|
||||
// if y < SHIFT: y = 0 (flush denormals to zero)
|
||||
// else: y = min(y, MAX_Y) (clamp overflow to +inf)
|
||||
// result = transmute(float)uint(y)
|
||||
//
|
||||
// SCALE = 2^mantissa_bits / ln(2) — converts from natural-log domain to bit-pattern units.
|
||||
// BIAS = 2^mantissa_bits * (exponent_bias - correction)
|
||||
// — the correction of 0.04367744890362246 minimizes worst-case relative error
|
||||
// (Schraudolph 1999, Table 1).
|
||||
//
|
||||
// Values below SHIFT (= 2^mantissa_bits, the smallest normal bit pattern) are flushed
|
||||
// to zero — this avoids producing denormals, which are slow on many CPUs and outside
|
||||
// the approximation's valid range. Values above MAX_Y (= the +inf bit pattern read as
|
||||
// a float) are clamped to MAX_Y, producing exact +inf.
|
||||
//
|
||||
// Worst-case relative error: < 2.983% (f32/f64), < 3.705% (f16).
|
||||
// Zero memory access — pure ALU.
|
||||
//
|
||||
// References:
|
||||
// N. Schraudolph, "A Fast, Compact Approximation of the Exponential Function",
|
||||
// Neural Computation 11, 853–862 (1999).
|
||||
// J. Rade, FastExp.h (2021), https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
|
||||
fast_exp :: #force_inline proc "contextless" (x: $FLOAT) -> FLOAT where intrinsics.type_is_float(FLOAT) {
|
||||
LN_2 :: 0.6931471805599453
|
||||
CORRECTION :: 0.04367744890362246
|
||||
|
||||
when FLOAT == f16 {
|
||||
MANTISSA_BITS :: 10
|
||||
EXPONENT_BIAS :: 15
|
||||
SHIFT :: f16(1 << MANTISSA_BITS)
|
||||
SCALE :: SHIFT / LN_2
|
||||
BIAS :: SHIFT * f16(EXPONENT_BIAS - CORRECTION)
|
||||
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
|
||||
|
||||
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
|
||||
y = y < SHIFT ? 0 : min(y, MAX_Y)
|
||||
return transmute(f16)u16(y)
|
||||
} else when FLOAT == f32 {
|
||||
MANTISSA_BITS :: 23
|
||||
EXPONENT_BIAS :: 127
|
||||
SHIFT :: f32(1 << MANTISSA_BITS)
|
||||
SCALE :: SHIFT / LN_2
|
||||
BIAS :: SHIFT * f32(EXPONENT_BIAS - CORRECTION)
|
||||
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
|
||||
|
||||
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
|
||||
y = y < SHIFT ? 0 : min(y, MAX_Y)
|
||||
return transmute(f32)u32(y)
|
||||
} else when FLOAT == f64 {
|
||||
MANTISSA_BITS :: 52
|
||||
EXPONENT_BIAS :: 1023
|
||||
SHIFT :: f64(1 << MANTISSA_BITS)
|
||||
SCALE :: SHIFT / LN_2
|
||||
BIAS :: SHIFT * f64(EXPONENT_BIAS - CORRECTION)
|
||||
MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1)
|
||||
|
||||
y := intrinsics.fused_mul_add(SCALE, x, BIAS)
|
||||
y = y < SHIFT ? 0 : min(y, MAX_Y)
|
||||
return transmute(f64)u64(y)
|
||||
} else {
|
||||
#panic("fast_exp only supports f16, f32, and f64")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
// ----- Testing -------------------------------------------------------------------------------------------------------
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
@(test)
|
||||
test_fast_exp_identity :: proc(t: ^testing.T) {
|
||||
// exp(0) ≈ 1.0 — validates bias correction for every width.
|
||||
r16 := fast_exp(f16(0))
|
||||
r32 := fast_exp(f32(0))
|
||||
r64 := fast_exp(f64(0))
|
||||
testing.expectf(t, abs(f32(r16) - 1) < 0.05, "f16: exp(0) = %v, want ≈ 1", r16)
|
||||
testing.expectf(t, abs(r32 - 1) < 0.04, "f32: exp(0) = %v, want ≈ 1", r32)
|
||||
testing.expectf(t, abs(r64 - 1) < 0.04, "f64: exp(0) = %v, want ≈ 1", r64)
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_fast_exp_accuracy :: proc(t: ^testing.T) {
|
||||
// Sweep the full representable domain and find the actual worst-case relative error.
|
||||
MAX_REL :: 0.02983 // measured worst-case f32/f64: 2.983%
|
||||
MAX_REL_16 :: 0.03705 // measured worst-case f16: 3.705% (coarser constants in 10-bit mantissa)
|
||||
|
||||
// ---- f16 sweep ([-9, 11]: non-denormal output range for f16) ----
|
||||
worst_16: f64
|
||||
worst_x_16: f16
|
||||
STEPS_16 :: 100_000
|
||||
LO_16 :: f16(-9.0)
|
||||
HI_16 :: f16(11.0)
|
||||
STEP_16 :: (HI_16 - LO_16) / STEPS_16
|
||||
|
||||
for i in 0 ..< STEPS_16 {
|
||||
x := LO_16 + f16(i) * STEP_16
|
||||
approx := fast_exp(x)
|
||||
exact := math.exp(f32(x))
|
||||
if exact < 1e-10 do continue
|
||||
rel := f64(abs(f32(approx) - exact)) / f64(exact)
|
||||
if rel > worst_16 {
|
||||
worst_16 = rel
|
||||
worst_x_16 = x
|
||||
}
|
||||
}
|
||||
testing.expectf(t, worst_16 <= MAX_REL_16,
|
||||
"f16: worst relative error = %.8f at x = %v", worst_16, worst_x_16)
|
||||
|
||||
// ---- f32 sweep ----
|
||||
worst_32: f64
|
||||
worst_x_32: f32
|
||||
STEPS_32 :: 10_000_000
|
||||
LO_32 :: f32(-87.0)
|
||||
HI_32 :: f32(88.0)
|
||||
STEP_32 :: (HI_32 - LO_32) / STEPS_32
|
||||
|
||||
for i in 0 ..< STEPS_32 {
|
||||
x := LO_32 + f32(i) * STEP_32
|
||||
approx := fast_exp(x)
|
||||
exact := math.exp(x)
|
||||
if exact < 1e-30 do continue
|
||||
rel := f64(abs(approx - exact)) / f64(exact)
|
||||
if rel > worst_32 {
|
||||
worst_32 = rel
|
||||
worst_x_32 = x
|
||||
}
|
||||
}
|
||||
testing.expectf(t, worst_32 <= MAX_REL,
|
||||
"f32: worst relative error = %.8f at x = %v", worst_32, worst_x_32)
|
||||
|
||||
// ---- f64 sweep ----
|
||||
worst_64: f64
|
||||
worst_x_64: f64
|
||||
STEPS_64 :: 10_000_000
|
||||
LO_64 :: f64(-700.0)
|
||||
HI_64 :: f64(709.0)
|
||||
STEP_64 :: (HI_64 - LO_64) / STEPS_64
|
||||
|
||||
for i in 0 ..< STEPS_64 {
|
||||
x := LO_64 + f64(i) * STEP_64
|
||||
approx := fast_exp(x)
|
||||
exact := math.exp(x)
|
||||
if exact < 1e-300 do continue
|
||||
rel := abs((approx - exact) / exact)
|
||||
if rel > worst_64 {
|
||||
worst_64 = rel
|
||||
worst_x_64 = x
|
||||
}
|
||||
}
|
||||
testing.expectf(t, worst_64 <= MAX_REL,
|
||||
"f64: worst relative error = %.8f at x = %v", worst_64, worst_x_64)
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_fast_exp_saturation :: proc(t: ^testing.T) {
|
||||
// Underflow: very negative input → exact 0
|
||||
testing.expectf(t, fast_exp(f32(-1000)) == 0, "f32 underflow: got %v, want 0", fast_exp(f32(-1000)))
|
||||
testing.expectf(t, fast_exp(f64(-1e5)) == 0, "f64 underflow: got %v, want 0", fast_exp(f64(-1e5)))
|
||||
|
||||
// Overflow: very positive input → exact +inf
|
||||
ov32 := fast_exp(f32(1000))
|
||||
ov64 := fast_exp(f64(1e5))
|
||||
testing.expectf(t, math.is_inf_f32(ov32), "f32 overflow: got %v, want +inf", ov32)
|
||||
testing.expectf(t, math.is_inf_f64(ov64), "f64 overflow: got %v, want +inf", ov64)
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"$schema": "https://raw.githubusercontent.com/DanielGavin/ols/master/misc/odinfmt.schema.json",
|
||||
"character_width": 110,
|
||||
"tabs_width": 1
|
||||
"tabs_width": 1,
|
||||
"inline_single_stmt_case": true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user