package levmath import "base:intrinsics" import "core:math" import "core:testing" // --------------------------------------------------------------------------------------------------------------------- // ----- Fast Exp (Schraudolph IEEE 754 bit trick) --------------------------------------------------------------------- // --------------------------------------------------------------------------------------------------------------------- // Approximates exp(x) by linearly mapping x into an IEEE 754 bit pattern. // Follows the standard formulation from Schraudolph (1999), adapted by Rade (2021): // // y = fma(SCALE, x, BIAS) (fused multiply-add, single rounding) // if y < SHIFT: y = 0 (flush denormals to zero) // else: y = min(y, MAX_Y) (clamp overflow to +inf) // result = transmute(float)uint(y) // // SCALE = 2^mantissa_bits / ln(2) — converts from natural-log domain to bit-pattern units. // BIAS = 2^mantissa_bits * (exponent_bias - correction) // — the correction of 0.04367744890362246 minimizes worst-case relative error // (Schraudolph 1999, Table 1). // // Values below SHIFT (= 2^mantissa_bits, the smallest normal bit pattern) are flushed // to zero — this avoids producing denormals, which are slow on many CPUs and outside // the approximation's valid range. Values above MAX_Y (= the +inf bit pattern read as // a float) are clamped to MAX_Y, producing exact +inf. // // Worst-case relative error: < 2.983% (f32/f64), < 3.705% (f16). // Zero memory access — pure ALU. // // References: // N. Schraudolph, "A Fast, Compact Approximation of the Exponential Function", // Neural Computation 11, 853–862 (1999). // J. 
Rade, FastExp.h (2021), https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d

// Fast approximate exp(x) for f16/f32/f64 using the Schraudolph IEEE 754 bit
// trick — see the header comment above for the derivation and error bounds.
// Pure ALU, no memory access; worst-case relative error < 2.983% (f32/f64),
// < 3.705% (f16).
fast_exp :: #force_inline proc "contextless" (x: $FLOAT) -> FLOAT where intrinsics.type_is_float(FLOAT) {
	// Error-minimizing exponent correction (Schraudolph 1999, Table 1),
	// shared by every width.
	CORRECTION :: 0.04367744890362246

	when FLOAT == f16 {
		MANTISSA_BITS :: 10
		EXPONENT_BIAS :: 15
		SHIFT :: f16(1 << MANTISSA_BITS)         // smallest normal bit pattern
		SCALE :: SHIFT / math.LN2                // natural-log domain -> bit-pattern units
		BIAS :: SHIFT * f16(EXPONENT_BIAS - CORRECTION)
		MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1) // +inf bit pattern read as a float

		y := intrinsics.fused_mul_add(SCALE, x, BIAS)
		y = y < SHIFT ? 0 : min(y, MAX_Y) // flush denormals to 0, clamp overflow to +inf
		return transmute(f16)u16(y)
	} else when FLOAT == f32 {
		MANTISSA_BITS :: 23
		EXPONENT_BIAS :: 127
		SHIFT :: f32(1 << MANTISSA_BITS)         // smallest normal bit pattern
		SCALE :: SHIFT / math.LN2                // natural-log domain -> bit-pattern units
		BIAS :: SHIFT * f32(EXPONENT_BIAS - CORRECTION)
		MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1) // +inf bit pattern read as a float

		y := intrinsics.fused_mul_add(SCALE, x, BIAS)
		y = y < SHIFT ? 0 : min(y, MAX_Y) // flush denormals to 0, clamp overflow to +inf
		return transmute(f32)u32(y)
	} else when FLOAT == f64 {
		MANTISSA_BITS :: 52
		EXPONENT_BIAS :: 1023
		SHIFT :: f64(1 << MANTISSA_BITS)         // smallest normal bit pattern
		SCALE :: SHIFT / math.LN2                // natural-log domain -> bit-pattern units
		BIAS :: SHIFT * f64(EXPONENT_BIAS - CORRECTION)
		MAX_Y :: SHIFT * (2 * EXPONENT_BIAS + 1) // +inf bit pattern read as a float

		y := intrinsics.fused_mul_add(SCALE, x, BIAS)
		y = y < SHIFT ? 0 : min(y, MAX_Y) // flush denormals to 0, clamp overflow to +inf
		return transmute(f64)u64(y)
	} else {
		#panic("fast_exp only supports f16, f32, and f64")
	}
}

// ---------------------------------------------------------------------------------------------------------------------
// ----- Testing -------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------------------

@(test)
test_fast_exp_identity :: proc(t: ^testing.T) {
	// exp(0) ≈ 1.0 — validates bias correction for every width.
	r16 := fast_exp(f16(0))
	r32 := fast_exp(f32(0))
	r64 := fast_exp(f64(0))
	testing.expectf(t, abs(f32(r16) - 1) < 0.05, "f16: exp(0) = %v, want ≈ 1", r16)
	testing.expectf(t, abs(r32 - 1) < 0.04, "f32: exp(0) = %v, want ≈ 1", r32)
	testing.expectf(t, abs(r64 - 1) < 0.04, "f64: exp(0) = %v, want ≈ 1", r64)
}

@(test)
test_fast_exp_accuracy :: proc(t: ^testing.T) {
	// Sweep the full representable domain and find the actual worst-case relative error.
	MAX_REL :: 0.02983    // measured worst-case f32/f64: 2.983%
	MAX_REL_16 :: 0.03705 // measured worst-case f16: 3.705% (coarser constants in 10-bit mantissa)

	// ---- f16 sweep ([-9, 11]: non-denormal output range for f16) ----
	// BUGFIX: the grid must be generated in f32, not f16. f16 cannot represent
	// integers above 2048 exactly, and a loop index ≳ 65 520 converts to +inf,
	// so `LO_16 + f16(i) * STEP_16` silently turned the upper part of the sweep
	// into x = +inf (rel = NaN, never recorded). Only the final sample point is
	// rounded to f16.
	worst_16: f64
	worst_x_16: f16
	STEPS_16 :: 100_000
	LO_16 :: f32(-9.0)
	HI_16 :: f32(11.0)
	STEP_16 :: (HI_16 - LO_16) / STEPS_16
	for i in 0 ..< STEPS_16 {
		x := f16(LO_16 + f32(i) * STEP_16)
		approx := fast_exp(x)
		exact := math.exp(f32(x))
		if exact < 1e-10 do continue // skip the flushed-to-zero tail
		rel := f64(abs(f32(approx) - exact)) / f64(exact)
		if rel > worst_16 {
			worst_16 = rel
			worst_x_16 = x
		}
	}
	testing.expectf(t, worst_16 <= MAX_REL_16, "f16: worst relative error = %.8f at x = %v", worst_16, worst_x_16)

	// ---- f32 sweep ----
	worst_32: f64
	worst_x_32: f32
	STEPS_32 :: 10_000_000
	LO_32 :: f32(-87.0)
	HI_32 :: f32(88.0)
	STEP_32 :: (HI_32 - LO_32) / STEPS_32
	for i in 0 ..< STEPS_32 {
		x := LO_32 + f32(i) * STEP_32
		approx := fast_exp(x)
		exact := math.exp(x)
		if exact < 1e-30 do continue // skip the flushed-to-zero tail
		rel := f64(abs(approx - exact)) / f64(exact)
		if rel > worst_32 {
			worst_32 = rel
			worst_x_32 = x
		}
	}
	testing.expectf(t, worst_32 <= MAX_REL, "f32: worst relative error = %.8f at x = %v", worst_32, worst_x_32)

	// ---- f64 sweep ----
	worst_64: f64
	worst_x_64: f64
	STEPS_64 :: 10_000_000
	LO_64 :: f64(-700.0)
	HI_64 :: f64(709.0)
	STEP_64 :: (HI_64 - LO_64) / STEPS_64
	for i in 0 ..< STEPS_64 {
		x := LO_64 + f64(i) * STEP_64
		approx := fast_exp(x)
		exact := math.exp(x)
		if exact < 1e-300 do continue // skip the flushed-to-zero tail
		rel := abs((approx - exact) / exact)
		if rel > worst_64 {
			worst_64 = rel
			worst_x_64 = x
		}
	}
	testing.expectf(t, worst_64 <= MAX_REL, "f64: worst relative error = %.8f at x = %v", worst_64, worst_x_64)
}

@(test)
test_fast_exp_saturation :: proc(t: ^testing.T) {
	// Underflow: very negative input → exact 0
	testing.expectf(t, fast_exp(f32(-1000)) == 0, "f32 underflow: got %v, want 0", fast_exp(f32(-1000)))
	testing.expectf(t, fast_exp(f64(-1e5)) == 0, "f64 underflow: got %v, want 0", fast_exp(f64(-1e5)))

	// Overflow: very positive input → exact +inf
	ov32 := fast_exp(f32(1000))
	ov64 := fast_exp(f64(1e5))
	testing.expectf(t, math.is_inf_f32(ov32), "f32 overflow: got %v, want +inf", ov32)
	testing.expectf(t, math.is_inf_f64(ov64), "f64 overflow: got %v, want +inf", ov64)
}