// levsort — radix-based partial sorting of float-keyed slices (Odin).
package levsort
|
|
|
|
import "base:intrinsics"
|
|
import "core:math/bits"
|
|
|
|
// Threshold for switching to heap-based selection when k is small.
// Below this, heap-select is O(n + k log k) and avoids allocating keys[n].
@(private = "file")
HEAP_SELECT_K_THRESHOLD :: 32

// Threshold for switching to insertion sort for small arrays.
@(private = "file")
INSERTION_THRESHOLD :: 32

// MSD select threshold - below this size, finish with insertion sort.
@(private = "file")
MSD_SMALL_THRESHOLD :: 64

// Threshold for using 11-bit radix in LSD sort (needs enough elements to amortize larger counts array)
@(private = "file")
LSD_11BIT_THRESHOLD :: 8192
|
|
// Radix-based partial sort for arrays of floats.
// Moves the smallest k elements, in ascending order, to the front of the
// slice; everything past position k is left in an unspecified order.
//
// `allocator` is where the temporary buffers used by the radix passes are
// allocated (defaults to the temp allocator).
partial_sort_float :: #force_inline proc(
	data: []$FLOAT,
	k: int,
	allocator := context.temp_allocator,
) where intrinsics.type_is_float(FLOAT) {
	// Each element is its own key, so forward with an identity extractor.
	identity := proc(x: FLOAT) -> FLOAT {return x}
	partial_sort_by_fkey(data, k, identity, allocator)
}
|
|
|
// Radix-based partial sort using a float key extraction function.
// This is faster than comparison-based sorting for large arrays.
// The `key` procedure extracts a float value from each element for comparison.
// Sorts the smallest k elements to the front of the slice.
// Elements after position k are in unspecified order.
//
// The `allocator` parameter specifies where to allocate the temporary buffer
// used during radix sorting.
//
// Strategy selection:
//   k <= HEAP_SELECT_K_THRESHOLD  -> streaming heap-select, then LSD sort of prefix
//   2*k >= n                      -> full LSD radix sort (prefix is then sorted)
//   otherwise                     -> MSD radix select to partition, then LSD sort of prefix
partial_sort_by_fkey :: proc(
	data: []$T,
	k: int,
	key: proc(val: T) -> $FLOAT,
	allocator := context.temp_allocator,
) where intrinsics.type_is_float(FLOAT) {
	n := len(data)
	if k <= 0 || n <= 1 do return

	k := min(k, n)

	// Map the float width to an unsigned key type of the same size.
	when FLOAT == f16 {
		NUM_BITS :: 16
		Key_Type :: u16
	} else when FLOAT == f32 {
		NUM_BITS :: 32
		Key_Type :: u32
	} else when FLOAT == f64 {
		NUM_BITS :: 64
		Key_Type :: u64
	} else {
		#panic("partial_sort_by_fkey only supports f16, f32, and f64")
	}

	// Algorithm selection based on k and n (integer math avoids float division)
	// Using 2*k >= n for ratio >= 0.5
	if k <= HEAP_SELECT_K_THRESHOLD {
		// Heap-select for small k: O(n + k log k)
		// Don't allocate keys[n] - compute keys on-the-fly during heap selection
		// This avoids an extra full pass over memory for the "tiny k, huge n" case
		heap_keys := make([]Key_Type, k, allocator)
		defer delete(heap_keys, allocator)
		tmp_items := make([]T, k, allocator)
		defer delete(tmp_items, allocator)
		tmp_keys := make([]Key_Type, k, allocator)
		defer delete(tmp_keys, allocator)

		// After heap_select_streaming, data[:k] holds the k smallest and
		// heap_keys their sortable keys; the LSD sort orders that prefix.
		heap_select_streaming(data, heap_keys, k, key)
		radix_sort_lsd_pingpong(data[:k], heap_keys, tmp_items, tmp_keys, NUM_BITS)
	} else if 2 * k >= n {
		// Full radix sort when k is a large fraction of n
		// Precompute sortable keys once
		keys := make([]Key_Type, n, allocator)
		defer delete(keys, allocator)
		for i := 0; i < n; i += 1 {
			keys[i] = float_to_sortable_typed(key(data[i]))
		}
		tmp_items := make([]T, n, allocator)
		defer delete(tmp_items, allocator)
		tmp_keys := make([]Key_Type, n, allocator)
		defer delete(tmp_keys, allocator)

		radix_sort_lsd_pingpong(data, keys, tmp_items, tmp_keys, NUM_BITS)
	} else {
		// MSD select path - precompute all keys
		keys := make([]Key_Type, n, allocator)
		defer delete(keys, allocator)
		for i := 0; i < n; i += 1 {
			keys[i] = float_to_sortable_typed(key(data[i]))
		}
		tmp_items := make([]T, n, allocator)
		defer delete(tmp_items, allocator)
		tmp_keys := make([]Key_Type, n, allocator)
		defer delete(tmp_keys, allocator)

		// MSD radix select to partition k smallest elements to front
		// Pass k-1 as the 0-based rank of the kth smallest element
		radix_select_msd_digit(data, keys, k - 1, NUM_BITS, tmp_items, tmp_keys)
		// Sort the k smallest elements using LSD radix sort
		radix_sort_lsd_pingpong(data[:k], keys[:k], tmp_items[:k], tmp_keys[:k], NUM_BITS)
	}
}
|
|
|
// Convert float bits to a sortable unsigned integer representation.
// For positive floats: flip the sign bit (makes them larger than negatives)
// For negative floats: flip all bits (reverses their order correctly)
// NaNs are mapped to max value so they sort to the end.
//
// Uses bit-pattern NaN detection for better performance in hot loops.
// Returns properly-sized key type to minimize memory bandwidth.
//
// Overload set: resolves to the matching-width converter at compile time.
@(private = "file")
float_to_sortable_typed :: proc {
	float_to_sortable_f16,
	float_to_sortable_f32,
	float_to_sortable_f64,
}
|
|
|
// f16 -> u16 order-preserving key. NaN sorts last; negatives sort before
// positives in true numeric order.
@(private = "file")
float_to_sortable_f16 :: #force_inline proc "contextless" (f: f16) -> u16 {
	EXP_MASK :: u16(0x7C00)  // exponent field (all ones for Inf/NaN)
	MANT_MASK :: u16(0x03FF) // mantissa field (nonzero distinguishes NaN from Inf)
	b := transmute(u16)f
	// Any NaN (either sign) maps to the largest key.
	if (b & EXP_MASK) == EXP_MASK && (b & MANT_MASK) != 0 {
		return max(u16)
	}
	// Arithmetic shift fills with the sign bit: all ones for negatives,
	// all zeros for non-negatives.
	sign_fill := u16(i16(b) >> 15)
	// Negatives: flip every bit. Non-negatives: flip only the sign bit.
	return b ~ (sign_fill | (1 << 15))
}
|
|
|
// f32 -> u32 order-preserving key. NaN sorts last; negatives sort before
// positives in true numeric order.
@(private = "file")
float_to_sortable_f32 :: #force_inline proc "contextless" (f: f32) -> u32 {
	EXP_MASK :: u32(0x7F800000)  // exponent field (all ones for Inf/NaN)
	MANT_MASK :: u32(0x007FFFFF) // mantissa field (nonzero distinguishes NaN from Inf)
	b := transmute(u32)f
	// Any NaN (either sign) maps to the largest key.
	if (b & EXP_MASK) == EXP_MASK && (b & MANT_MASK) != 0 {
		return max(u32)
	}
	// Arithmetic shift fills with the sign bit: all ones for negatives,
	// all zeros for non-negatives.
	sign_fill := u32(i32(b) >> 31)
	// Negatives: flip every bit. Non-negatives: flip only the sign bit.
	return b ~ (sign_fill | (1 << 31))
}
|
|
|
// f64 -> u64 order-preserving key. NaN sorts last; negatives sort before
// positives in true numeric order.
@(private = "file")
float_to_sortable_f64 :: #force_inline proc "contextless" (f: f64) -> u64 {
	EXP_MASK :: u64(0x7FF0000000000000)  // exponent field (all ones for Inf/NaN)
	MANT_MASK :: u64(0x000FFFFFFFFFFFFF) // mantissa field (nonzero distinguishes NaN from Inf)
	b := transmute(u64)f
	// Any NaN (either sign) maps to the largest key.
	if (b & EXP_MASK) == EXP_MASK && (b & MANT_MASK) != 0 {
		return max(u64)
	}
	// Arithmetic shift fills with the sign bit: all ones for negatives,
	// all zeros for non-negatives.
	sign_fill := u64(i64(b) >> 63)
	// Negatives: flip every bit. Non-negatives: flip only the sign bit.
	return b ~ (sign_fill | (1 << 63))
}
|
|
|
// Legacy version returning uint (for backward compatibility with heap_select).
//
// All three supported widths delegate to the typed overload set, which picks
// the correct bit-width conversion — the previous per-type branches were
// byte-identical and have been collapsed into one.
//
// NOTE(review): `uint` is pointer-sized; on a 32-bit target an f64 key would
// be truncated here and could compare incorrectly. Confirm 64-bit-only
// targets, or migrate the remaining callers to the typed converters.
@(private = "file")
float_to_sortable :: #force_inline proc "contextless" (
	f: $FLOAT,
) -> (
	result: uint,
) where intrinsics.type_is_float(FLOAT) {
	when FLOAT == f16 || FLOAT == f32 || FLOAT == f64 {
		return uint(float_to_sortable_typed(f))
	} else {
		#panic("float_to_sortable only supports f16, f32, and f64")
	}
}
|
|
|
// Wide-digit MSD radix select - partitions data so smallest k elements are in data[:k].
// Uses 8-11 bit digits instead of 1-bit-at-a-time for fewer passes and less memory traffic.
// k_rank is the 0-based rank of the target element (for "smallest k", pass k-1).
// Uses real ping-pong buffers - scatter src→dst, swap pointers, copy back only once at end.
//
// Invariant maintained per level: elements before `lo` are <= every element in
// [lo, hi), elements from `hi` on are >= them, and k_rel is the rank of the
// target element relative to `lo`.
@(private = "file")
radix_select_msd_digit :: proc(
	data: []$T,
	keys: []$Key,
	k_rank: int,
	total_bits: int,
	tmp_items: []T,
	tmp_keys: []Key,
) where Key == u16 ||
	Key == u32 ||
	Key == u64 {
	n := len(data)
	if n <= 1 || k_rank < 0 do return

	// Working range [lo, hi) and relative rank within that range
	lo := 0
	hi := n
	k_rel := k_rank

	// Ping-pong buffer state - swap whole arrays, copy back once at end
	src_items := data
	src_keys := keys
	dst_items := tmp_items
	dst_keys := tmp_keys
	in_temp := false

	// Choose radix bits based on float type
	// For f64: use 11 bits for large ranges, 8 bits otherwise
	// For f32/f16: use 8 bits
	radix_bits_large :: 11
	radix_bits_small :: 8

	// Start from most significant bit
	bit_pos := total_bits

	for bit_pos > 0 && hi - lo > MSD_SMALL_THRESHOLD {
		m := hi - lo

		// Choose digit width based on range size and type
		radix_bits: int
		if total_bits == 64 && m >= 32768 {
			radix_bits = radix_bits_large
		} else {
			radix_bits = radix_bits_small
		}

		// Clamp to remaining bits
		if bit_pos < radix_bits {
			radix_bits = bit_pos
		}

		radix_size := 1 << uint(radix_bits)
		radix_mask := Key(radix_size - 1)
		shift := uint(bit_pos - radix_bits)

		base_lo := lo

		// Early check: sample first few elements to detect likely single-bucket case
		// (the full histogram pass below re-verifies over the whole range)
		first_digit: Key = (src_keys[lo] >> shift) & radix_mask
		all_same_digit := true
		sample_end := min(lo + 16, hi)
		for i := lo + 1; i < sample_end; i += 1 {
			if ((src_keys[i] >> shift) & radix_mask) != first_digit {
				all_same_digit = false
				break
			}
		}

		// Full histogram with single-bucket tracking
		bucket_start := 0
		bucket_end := 0

		if radix_bits <= 8 {
			counts: [256]u32
			for i := lo; i < hi; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				counts[digit] += 1
				if digit != first_digit do all_same_digit = false
			}

			// Skip this digit level if all in one bucket
			// (no scatter and no buffer swap happened, so src still holds the data)
			if all_same_digit {
				bit_pos -= radix_bits
				continue
			}

			// Find the bucket containing k_rel (using > for 0-based rank)
			cumsum := 0
			for d := 0; d < radix_size; d += 1 {
				cd := int(counts[d])
				if cumsum + cd > k_rel {
					bucket_start = cumsum
					bucket_end = cumsum + cd
					break
				}
				cumsum += cd
			}

			// Convert counts to offsets for scatter
			offset: u32 = 0
			for i := 0; i < radix_size; i += 1 {
				c := counts[i]
				counts[i] = offset
				offset += c
			}

			// Scatter [lo, hi) from src to dst by digit
			for i := lo; i < hi; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				dst_idx := base_lo + int(counts[digit])
				dst_items[dst_idx] = src_items[i]
				dst_keys[dst_idx] = src_keys[i]
				counts[digit] += 1
			}
		} else {
			// 11-bit radix path
			counts: [2048]u32
			for i := lo; i < hi; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				counts[digit] += 1
				if digit != first_digit do all_same_digit = false
			}

			// Skip this digit level if all in one bucket
			if all_same_digit {
				bit_pos -= radix_bits
				continue
			}

			// Find the bucket containing k_rel
			cumsum := 0
			for d := 0; d < radix_size; d += 1 {
				cd := int(counts[d])
				if cumsum + cd > k_rel {
					bucket_start = cumsum
					bucket_end = cumsum + cd
					break
				}
				cumsum += cd
			}

			// Convert counts to offsets for scatter
			offset: u32 = 0
			for i := 0; i < radix_size; i += 1 {
				c := counts[i]
				counts[i] = offset
				offset += c
			}

			// Scatter [lo, hi) from src to dst by digit
			for i := lo; i < hi; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				dst_idx := base_lo + int(counts[digit])
				dst_items[dst_idx] = src_items[i]
				dst_keys[dst_idx] = src_keys[i]
				counts[digit] += 1
			}
		}

		// Preserve untouched segments so dst becomes a full valid array
		for i := 0; i < lo; i += 1 {
			dst_items[i] = src_items[i]
			dst_keys[i] = src_keys[i]
		}
		for i := hi; i < n; i += 1 {
			dst_items[i] = src_items[i]
			dst_keys[i] = src_keys[i]
		}

		// Swap whole buffers (ping-pong)
		src_items, dst_items = dst_items, src_items
		src_keys, dst_keys = dst_keys, src_keys
		in_temp = !in_temp

		// Narrow to bucket containing k_rel
		lo = base_lo + bucket_start
		hi = base_lo + bucket_end
		k_rel = k_rel - bucket_start

		bit_pos -= radix_bits
	}

	// Finish with insertion sort on small range if needed
	// (if the loop exited with bit_pos == 0 instead, all keys in [lo, hi) are equal)
	if hi - lo > 1 && hi - lo <= MSD_SMALL_THRESHOLD {
		insertion_sort_with_keys(src_items[lo:hi], src_keys[lo:hi])
	}

	// Copy back once if we ended in temp buffers
	if in_temp {
		for i := 0; i < n; i += 1 {
			data[i] = src_items[i]
			keys[i] = src_keys[i]
		}
	}
}
|
|
|
// LSD radix sort with ping-pong buffers - no per-pass copy overhead.
// Operates on parallel (items, keys) arrays.
// Uses ctz to skip identical low digits.
//
// Skipping a pass (either via start_pass/end_pass or the per-pass all_same
// check) is safe because a pass over a digit shared by every element is the
// identity permutation for a stable counting sort.
@(private = "file")
radix_sort_lsd_pingpong :: proc(
	data: []$T,
	keys: []$Key,
	tmp_items: []T,
	tmp_keys: []Key,
	total_bits: int,
) where Key == u16 ||
	Key == u32 ||
	Key == u64 {
	n := len(data)
	if n <= 1 do return

	// For small arrays, use insertion sort
	if n <= INSERTION_THRESHOLD {
		insertion_sort_with_keys(data, keys)
		return
	}

	// Choose radix bits: 11-bit for f64 with large n, 8-bit otherwise
	radix_bits: int
	if total_bits == 64 && n >= LSD_11BIT_THRESHOLD {
		radix_bits = 11
	} else {
		radix_bits = 8
	}

	radix_mask := Key((1 << uint(radix_bits)) - 1)

	// Compute true variability mask: var_bits has a 1 wherever ANY element differs from base
	// This is safe for both low and high pass skipping (unlike min^max which misses intermediate values)
	base := keys[0]
	var_bits: Key = 0
	for i := 1; i < n; i += 1 {
		var_bits |= (keys[i] ~ base)
	}
	if var_bits == 0 do return // All keys identical

	// Safe skip of identical low digits + safe limit of high digits
	start_pass: int
	end_pass: int
	when Key == u16 {
		low_bit := int(bits.count_trailing_zeros(u16(var_bits)))
		high_bit := 15 - int(bits.count_leading_zeros(u16(var_bits)))
		start_pass = low_bit / radix_bits
		end_pass = (high_bit + radix_bits) / radix_bits
	} else when Key == u32 {
		low_bit := int(bits.count_trailing_zeros(u32(var_bits)))
		high_bit := 31 - int(bits.count_leading_zeros(u32(var_bits)))
		start_pass = low_bit / radix_bits
		end_pass = (high_bit + radix_bits) / radix_bits
	} else {
		low_bit := int(bits.count_trailing_zeros(u64(var_bits)))
		high_bit := 63 - int(bits.count_leading_zeros(u64(var_bits)))
		start_pass = low_bit / radix_bits
		end_pass = (high_bit + radix_bits) / radix_bits
	}

	// Ping-pong buffer state
	src_items := data
	src_keys := keys
	dst_items := tmp_items
	dst_keys := tmp_keys
	in_temp := false

	// Process passes from start_pass to end_pass (skipping identical low digits)
	if radix_bits == 8 {
		for pass := start_pass; pass < end_pass; pass += 1 {
			shift := uint(pass * radix_bits)

			// Count occurrences, tracking if all same digit
			counts: [256]u32
			first_digit := (src_keys[0] >> shift) & radix_mask
			all_same := true

			for i := 0; i < n; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				counts[digit] += 1
				if digit != first_digit do all_same = false
			}

			// Skip pass if all elements in one bucket
			if all_same do continue

			// Convert counts to offsets (prefix sum)
			offset: u32 = 0
			for i := 0; i < 256; i += 1 {
				c := counts[i]
				counts[i] = offset
				offset += c
			}

			// Scatter from src to dst (stable: equal digits keep relative order)
			for i := 0; i < n; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				idx := int(counts[digit])
				dst_items[idx] = src_items[i]
				dst_keys[idx] = src_keys[i]
				counts[digit] += 1
			}

			// Swap src and dst (ping-pong)
			src_items, dst_items = dst_items, src_items
			src_keys, dst_keys = dst_keys, src_keys
			in_temp = !in_temp
		}
	} else {
		// 11-bit radix path
		for pass := start_pass; pass < end_pass; pass += 1 {
			shift := uint(pass * radix_bits)

			// Count occurrences, tracking if all same digit
			counts: [2048]u32
			first_digit := (src_keys[0] >> shift) & radix_mask
			all_same := true

			for i := 0; i < n; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				counts[digit] += 1
				if digit != first_digit do all_same = false
			}

			// Skip pass if all elements in one bucket
			if all_same do continue

			// Convert counts to offsets (prefix sum)
			offset: u32 = 0
			for i := 0; i < 2048; i += 1 {
				c := counts[i]
				counts[i] = offset
				offset += c
			}

			// Scatter from src to dst (stable: equal digits keep relative order)
			for i := 0; i < n; i += 1 {
				digit := (src_keys[i] >> shift) & radix_mask
				idx := int(counts[digit])
				dst_items[idx] = src_items[i]
				dst_keys[idx] = src_keys[i]
				counts[digit] += 1
			}

			// Swap src and dst (ping-pong)
			src_items, dst_items = dst_items, src_items
			src_keys, dst_keys = dst_keys, src_keys
			in_temp = !in_temp
		}
	}

	// If result ended up in temp buffer, copy back to original
	if in_temp {
		for i := 0; i < n; i += 1 {
			data[i] = src_items[i]
			keys[i] = src_keys[i]
		}
	}
}
|
|
|
// Streaming heap-based selection - computes keys on-the-fly.
// Avoids allocating keys[n] for the "tiny k, huge n" case.
// heap_keys is pre-allocated to size k.
//
// On return, data[:k] holds the k smallest elements (unsorted) and
// heap_keys[:k] holds their sortable keys in matching positions.
@(private = "file")
heap_select_streaming :: proc(
	data: []$T,
	heap_keys: []$Key,
	k: int,
	key_fn: proc(_: T) -> $FLOAT,
) where Key == u16 ||
	Key == u32 ||
	Key == u64,
	intrinsics.type_is_float(FLOAT) {
	n := len(data)
	if k <= 0 || n <= 1 do return

	k := min(k, n)

	// Initialize heap with first k elements, computing keys on the fly
	for i := 0; i < k; i += 1 {
		heap_keys[i] = float_to_sortable_typed(key_fn(data[i]))
	}

	// Build max-heap from first k elements
	for i := k / 2 - 1; i >= 0; i -= 1 {
		sift_down_max_with_keys(data[:k], heap_keys[:k], i, k)
	}

	// Scan remaining elements, computing key for each candidate
	for i := k; i < n; i += 1 {
		candidate_key := float_to_sortable_typed(key_fn(data[i]))
		if candidate_key < heap_keys[0] {
			// Swap item into heap root; the evicted max lands at position i,
			// and its key is overwritten by the candidate's key below.
			data[0], data[i] = data[i], data[0]
			heap_keys[0] = candidate_key
			sift_down_max_with_keys(data[:k], heap_keys[:k], 0, k)
		}
	}

	// The k smallest are now in data[:k] with their keys in heap_keys[:k]
	// Sorting will be done by the caller with radix sort
}
|
|
|
// Heap-based selection using precomputed keys (for non-heap paths that already have keys).
// Moves the k smallest elements (by key) into data[:k] / keys[:k], unsorted.
@(private = "file")
heap_select_with_keys :: proc(
	data: []$T,
	keys: []$Key,
	k: int,
) where Key == u16 ||
	Key == u32 ||
	Key == u64 {
	n := len(data)
	if k <= 0 || n <= 1 do return

	k := min(k, n)

	heap_items := data[:k]
	heap_key_view := keys[:k]

	// Heapify the first k entries into a max-heap keyed on `keys`.
	for root := k / 2 - 1; root >= 0; root -= 1 {
		sift_down_max_with_keys(heap_items, heap_key_view, root, k)
	}

	// Any later element smaller than the current maximum replaces it.
	for idx := k; idx < n; idx += 1 {
		if keys[idx] >= keys[0] do continue
		data[0], data[idx] = data[idx], data[0]
		keys[0], keys[idx] = keys[idx], keys[0]
		sift_down_max_with_keys(heap_items, heap_key_view, 0, k)
	}

	// The k smallest are now in data[:k] but not sorted.
}
|
|
|
// Restores the max-heap property (ordered by `keys`) for the subtree rooted
// at `root`, keeping `data` and `keys` in lockstep. `n` is the heap size.
@(private = "file")
sift_down_max_with_keys :: proc(
	data: []$T,
	keys: []$Key,
	root: int,
	n: int,
) where Key == u16 ||
	Key == u32 ||
	Key == u64 {
	node := root
	for {
		// Pick the largest of node and its (up to two) children.
		top := node
		l := 2 * node + 1
		r := l + 1
		if l < n && keys[l] > keys[top] do top = l
		if r < n && keys[r] > keys[top] do top = r

		// Heap property holds here; done.
		if top == node do return

		data[node], data[top] = data[top], data[node]
		keys[node], keys[top] = keys[top], keys[node]
		node = top
	}
}
|
|
|
// Insertion sort with parallel keys array: orders both slices in lockstep by
// ascending key. Stable; intended for small ranges only.
@(private = "file")
insertion_sort_with_keys :: proc(data: []$T, keys: []$Key) where Key == u16 || Key == u32 || Key == u64 {
	for i in 1 ..< len(data) {
		pending_item := data[i]
		pending_key := keys[i]
		pos := i
		// Shift strictly-larger keys one slot right until the hole is correct.
		for pos > 0 && keys[pos - 1] > pending_key {
			data[pos] = data[pos - 1]
			keys[pos] = keys[pos - 1]
			pos -= 1
		}
		data[pos] = pending_item
		keys[pos] = pending_key
	}
}
|
|
|
// Legacy heap select (kept for potential direct usage)
//
// Unlike the streaming variant, this recomputes keys on every comparison and
// finishes by heapsorting the prefix, so data[:k] is fully sorted on return.
@(private = "file")
heap_select :: proc(data: []$T, k: int, key: proc(_: T) -> $FLOAT) where intrinsics.type_is_float(FLOAT) {
	n := len(data)
	if k <= 0 || n <= 1 do return

	k := min(k, n)

	// Build max-heap from first k elements
	heap := data[:k]
	for i := k / 2 - 1; i >= 0; i -= 1 {
		sift_down_max_by_key(heap, i, k, key)
	}

	// Scan remaining elements, replace max if smaller found
	for i := k; i < n; i += 1 {
		if float_to_sortable(key(data[i])) < float_to_sortable(key(heap[0])) {
			heap[0], data[i] = data[i], heap[0]
			sift_down_max_by_key(heap, 0, k, key)
		}
	}

	// Extract elements from heap in sorted order (heapsort the prefix)
	for i := k - 1; i > 0; i -= 1 {
		heap[0], heap[i] = heap[i], heap[0]
		sift_down_max_by_key(heap[:i], 0, i, key)
	}
}
|
|
|
// Restores the max-heap property for the subtree rooted at `root`, comparing
// elements through the key extractor (via float_to_sortable). `n` is the
// heap size.
@(private = "file")
sift_down_max_by_key :: proc(
	heap: []$T,
	root: int,
	n: int,
	key: proc(_: T) -> $FLOAT,
) where intrinsics.type_is_float(FLOAT) {
	node := root
	for {
		// Pick the largest of node and its (up to two) children.
		top := node
		l := 2 * node + 1
		r := l + 1
		if l < n && float_to_sortable(key(heap[l])) > float_to_sortable(key(heap[top])) {
			top = l
		}
		if r < n && float_to_sortable(key(heap[r])) > float_to_sortable(key(heap[top])) {
			top = r
		}

		// Heap property holds here; done.
		if top == node do return

		heap[node], heap[top] = heap[top], heap[node]
		node = top
	}
}
|
|
|
// ---------------------------------------------------------------------------------------------------------------------
|
|
// ----- Testing -------------------------------------------------------------------------------------------------------
|
|
// ---------------------------------------------------------------------------------------------------------------------
|
|
import "core:math"
|
|
import "core:math/rand"
|
|
import "core:slice"
|
|
import "core:sort"
|
|
import "core:testing"
|
|
|
|
//----- Helper Procedures ----------------------------------

// Identity key extractors for each float width, used by the by-key tests.
@(private = "file")
f64_key :: proc(x: f64) -> f64 {return x}

@(private = "file")
f32_key :: proc(x: f32) -> f32 {return x}

@(private = "file")
f16_key :: proc(x: f16) -> f16 {return x}
|
|
|
// Reports whether the first k elements of data are in non-decreasing order.
@(private = "file")
is_prefix_sorted :: proc(data: []$T, k: int) -> bool {
	if k <= 1 do return true
	for i in 1 ..< k {
		if data[i - 1] > data[i] do return false
	}
	return true
}
|
|
|
// Verifies the partial-sort partition property: no element at or beyond
// position k is smaller than the largest element in the k-prefix.
@(private = "file")
partition_property_holds :: proc(data: []$T, k: int) -> bool {
	if k <= 0 || k >= len(data) do return true

	prefix_max := data[0]
	for i in 1 ..< k {
		prefix_max = max(prefix_max, data[i])
	}

	for i in k ..< len(data) {
		if data[i] < prefix_max do return false
	}
	return true
}
|
|
|
// Checks that result[:k] holds exactly the k smallest values of `original`
// (as a multiset), by comparing against a fully sorted reference copy.
@(private = "file")
has_correct_elements :: proc(original: []$T, result: []T, k: int) -> bool {
	if k <= 0 do return true

	expected := make([]T, len(original))
	defer delete(expected)
	copy(expected, original)
	sort.quick_sort(expected)

	actual := make([]T, k)
	defer delete(actual)
	copy(actual, result[:k])
	sort.quick_sort(actual)

	for i in 0 ..< k {
		if actual[i] != expected[i] do return false
	}
	return true
}
|
|
|
// Checks that `result` is a permutation of `original` (no elements lost,
// duplicated, or corrupted) by comparing sorted copies of both.
@(private = "file")
elements_preserved :: proc(original: []$T, result: []T) -> bool {
	if len(original) != len(result) do return false

	lhs := make([]T, len(original))
	defer delete(lhs)
	copy(lhs, original)
	sort.quick_sort(lhs)

	rhs := make([]T, len(result))
	defer delete(rhs)
	copy(rhs, result)
	sort.quick_sort(rhs)

	for i in 0 ..< len(lhs) {
		if lhs[i] != rhs[i] do return false
	}
	return true
}
|
|
|
// Runs the full battery of partial-sort correctness checks against `result`
// and returns (false, reason) for the first violated property.
@(private = "file")
validate_partial_sort :: proc(original: []$T, result: []T, k: int) -> (ok: bool, reason: string) {
	// Clamp so callers may pass k > len(result) (mirrors the sort's own clamp).
	k := min(k, len(result))

	if !elements_preserved(original, result) {
		return false, "Elements were lost or corrupted"
	}
	if !is_prefix_sorted(result, k) {
		return false, "Prefix is not sorted"
	}
	if !partition_property_holds(result, k) {
		return false, "Partition property violated"
	}
	if !has_correct_elements(original, result, k) {
		return false, "Wrong elements in prefix"
	}
	return true, ""
}
|
|
//----- partial_sort_float wrapper test ----------------------------------

// The wrapper must produce exactly the same sorted prefix as the by-key
// entry point with an identity key.
@(test)
test_partial_sort_float_wrapper :: proc(t: ^testing.T) {
	size := 100
	original := make([]f64, size)
	defer delete(original)
	for i := 0; i < size; i += 1 {
		original[i] = f64(rand.int31() % 1000)
	}

	data1 := make([]f64, size)
	defer delete(data1)
	copy(data1, original)

	data2 := make([]f64, size)
	defer delete(data2)
	copy(data2, original)

	k := 25

	partial_sort_float(data1, k)
	partial_sort_by_fkey(data2, k, f64_key)

	// The sorted prefix is deterministic, so element-wise equality must hold.
	for i := 0; i < k; i += 1 {
		testing.expectf(
			t,
			data1[i] == data2[i],
			"Mismatch at position %d: partial_sort_float=%v, partial_sort_by_fkey=%v",
			i,
			data1[i],
			data2[i],
		)
	}
}
|
|
|
//----- Edge case tests ----------------------------------

// Degenerate inputs: empty slices, non-positive k, single elements, and k at
// or beyond the slice length must all be handled without panics.

@(test)
test_empty_array :: proc(t: ^testing.T) {
	data: []f64
	partial_sort_float(data, 5)
	testing.expect(t, len(data) == 0, "Empty array should remain empty")
}

@(test)
test_k_zero :: proc(t: ^testing.T) {
	original := []f64{5, 3, 8, 1, 9}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, 0)

	// k == 0 is a no-op: every element must stay in place.
	for i := 0; i < len(data); i += 1 {
		testing.expectf(t, data[i] == original[i], "k=0 should not modify array, index %d changed", i)
	}
}

@(test)
test_k_negative :: proc(t: ^testing.T) {
	original := []f64{5, 3, 8, 1, 9}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, -5)

	// Negative k is also a no-op.
	for i := 0; i < len(data); i += 1 {
		testing.expectf(t, data[i] == original[i], "Negative k should not modify array, index %d changed", i)
	}
}

@(test)
test_single_element :: proc(t: ^testing.T) {
	data := []f64{42}
	partial_sort_float(data, 1)
	testing.expect(t, data[0] == 42, "Single element should be unchanged")
}

@(test)
test_k_one :: proc(t: ^testing.T) {
	original := []f64{9, 5, 2, 8, 1, 7, 3}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, 1)

	testing.expect(t, data[0] == 1, "First element should be minimum when k=1")
	testing.expect(t, elements_preserved(original, data), "Elements should be preserved")
}

@(test)
test_k_equals_length :: proc(t: ^testing.T) {
	original := []f64{5, 2, 8, 1, 9, 3, 7, 4, 6}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, len(data))

	testing.expect(t, slice.is_sorted(data), "k=n should fully sort the array")
	testing.expect(t, elements_preserved(original, data), "Elements should be preserved")
}

@(test)
test_k_exceeds_length :: proc(t: ^testing.T) {
	original := []f64{5, 2, 8, 1, 9}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	// k is clamped internally, so this behaves like k == n.
	partial_sort_float(data, 100)

	testing.expect(t, slice.is_sorted(data), "k>n should fully sort the array")
	testing.expect(t, elements_preserved(original, data), "Elements should be preserved")
}

@(test)
test_two_elements_sorted :: proc(t: ^testing.T) {
	data := []f64{1, 2}
	partial_sort_float(data, 1)
	testing.expect(t, data[0] == 1, "Minimum should be first")
}

@(test)
test_two_elements_reversed :: proc(t: ^testing.T) {
	data := []f64{2, 1}
	partial_sort_float(data, 1)
	testing.expect(t, data[0] == 1, "Minimum should be first after sort")
}

@(test)
test_two_elements_k_two :: proc(t: ^testing.T) {
	data := []f64{2, 1}
	partial_sort_float(data, 2)
	testing.expect(t, data[0] == 1 && data[1] == 2, "Two elements should be fully sorted")
}
|
|
|
//----- Input pattern tests ----------------------------------

// Common input distributions: pre-sorted, reversed, constant, low-cardinality,
// duplicate-heavy, and sign-mixed data.

@(test)
test_already_sorted :: proc(t: ^testing.T) {
	original := []f64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 5

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Already sorted input failed: %s", reason)

	// The prefix of a sorted input must come back unchanged: 1..k.
	for i := 0; i < k; i += 1 {
		testing.expectf(t, data[i] == f64(i + 1), "Expected %v at position %d, got %v", f64(i + 1), i, data[i])
	}
}

@(test)
test_reverse_sorted :: proc(t: ^testing.T) {
	original := []f64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 5

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Reverse sorted input failed: %s", reason)
}

@(test)
test_all_equal_elements :: proc(t: ^testing.T) {
	original := []f64{7, 7, 7, 7, 7, 7, 7, 7}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 4

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "All equal elements failed: %s", reason)

	for i := 0; i < len(data); i += 1 {
		testing.expect(t, data[i] == 7, "Element value changed unexpectedly")
	}
}

@(test)
test_two_distinct_values :: proc(t: ^testing.T) {
	original := []f64{1, 0, 1, 0, 1, 0, 1, 0}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 4

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Two distinct values failed: %s", reason)

	// Exactly four zeros exist, so the k=4 prefix must be all zeros.
	for i := 0; i < k; i += 1 {
		testing.expectf(t, data[i] == 0, "Expected 0 at position %d, got %v", i, data[i])
	}
}

@(test)
test_many_duplicates :: proc(t: ^testing.T) {
	original := []f64{3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 7

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Many duplicates failed: %s", reason)
}

@(test)
test_negative_numbers :: proc(t: ^testing.T) {
	original := []f64{-5, 3, -8, 0, 2, -1, 7, -3}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 4

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Negative numbers failed: %s", reason)

	// Negatives exercise the all-bits-flip branch of the key conversion.
	expected := []f64{-8, -5, -3, -1}
	for i := 0; i < k; i += 1 {
		testing.expectf(t, data[i] == expected[i], "Expected %v at position %d, got %v", expected[i], i, data[i])
	}
}

@(test)
test_mixed_positive_negative_zero :: proc(t: ^testing.T) {
	original := []f64{0, -1, 1, 0, -2, 2, 0, -3, 3}
	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 5

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Mixed positive/negative/zero failed: %s", reason)
}
|
|
|
//----- Worst-case pattern tests ----------------------------------
|
|
|
|
@(test)
test_pipe_organ_pattern :: proc(t: ^testing.T) {
	// Ascending-then-descending "pipe organ" input, a classic quicksort stressor.
	original := []f64{1, 2, 3, 4, 5, 5, 4, 3, 2, 1}
	k := 5

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Pipe organ pattern failed: %s", why)
}
|
|
|
|
@(test)
test_sawtooth_pattern :: proc(t: ^testing.T) {
	// Alternating high/low "sawtooth" input.
	original := []f64{10, 1, 9, 2, 8, 3, 7, 4, 6, 5}
	k := 5

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Sawtooth pattern failed: %s", why)
}
|
|
|
|
@(test)
test_median_of_three_killer :: proc(t: ^testing.T) {
	// Mirrored sequence {1..8, 8..1}, an adversarial pattern for
	// median-of-three pivot selection.
	original := make([]f64, 16)
	defer delete(original)

	for i in 0 ..< 8 {
		v := f64(i + 1)
		original[i] = v
		original[15 - i] = v
	}

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 8

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Median-of-three killer failed: %s", why)
}
|
|
|
|
@(test)
test_all_same_except_one_small :: proc(t: ^testing.T) {
	// A single small outlier among identical values must surface at index 0.
	original := []f64{9, 9, 9, 9, 1, 9, 9, 9, 9, 9}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	testing.expect(t, data[0] == 1, "Single minimum should be first")
	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "All same except one failed: %s", why)
}
|
|
|
|
@(test)
test_all_same_except_one_large :: proc(t: ^testing.T) {
	// A single large outlier must stay out of the k-smallest prefix.
	original := []f64{1, 1, 1, 1, 9, 1, 1, 1, 1, 1}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "All same except one large failed: %s", why)

	for i in 0 ..< k {
		testing.expect(t, data[i] == 1, "Prefix should contain only 1s")
	}
}
|
|
|
|
//----- Floating point specific tests ----------------------------------
|
|
|
|
@(test)
test_float_basic :: proc(t: ^testing.T) {
	// Smoke test on a handful of non-integer f64 values.
	original := []f64{3.14, 2.71, 1.41, 1.73, 2.23, 0.57}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Float basic failed: %s", why)
}
|
|
|
|
@(test)
test_float_with_negative :: proc(t: ^testing.T) {
	// Fractional values spanning both signs and zero.
	original := []f64{-1.5, 2.5, -3.5, 0.0, 1.0, -0.5}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Float with negative failed: %s", why)
}
|
|
|
|
@(test)
test_float_very_close_values :: proc(t: ^testing.T) {
	// Values differing only in the low mantissa bits; ordering must still hold.
	original := []f64{1.0000001, 1.0000002, 1.0000000, 1.0000003, 1.0000004}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	testing.expect(t, is_prefix_sorted(data, k), "Prefix should be sorted")
	testing.expect(t, elements_preserved(original, data), "Elements should be preserved")
}
|
|
|
|
@(test)
test_float_subnormal :: proc(t: ^testing.T) {
	// Subnormal (denormalized) f64 values mixed with zero.
	original := []f64{1e-310, 1e-320, 1e-300, 1e-315, 0.0}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	testing.expect(t, is_prefix_sorted(data, k), "Prefix should be sorted with subnormals")
	testing.expect(t, elements_preserved(original, data), "Elements should be preserved")
}
|
|
|
|
@(test)
test_float_infinity :: proc(t: ^testing.T) {
	// Infinities must order correctly: -inf smallest, +inf largest.
	inf := math.inf_f64(1)
	neg_inf := math.inf_f64(-1)
	original := []f64{inf, 0, neg_inf, 1, -1, inf}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	testing.expect(t, data[0] == neg_inf, "Negative infinity should be first")
	testing.expect(t, is_prefix_sorted(data, k), "Prefix should be sorted")
}
|
|
|
|
@(test)
test_float_negative_zero :: proc(t: ^testing.T) {
	// -0.0 (sign bit set) alongside +0.0; sort must handle both zero encodings.
	neg_zero := transmute(f64)u64(1 << 63)
	pos_zero := f64(0.0)
	original := []f64{1.0, neg_zero, -1.0, pos_zero, 0.5}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	testing.expect(t, is_prefix_sorted(data, k), "Prefix should be sorted")
	testing.expect(t, elements_preserved(original, data), "Elements should be preserved")
}
|
|
|
|
@(test)
test_float_nan_positive :: proc(t: ^testing.T) {
	// Positive NaNs must be pushed past the k-smallest prefix.
	nan := math.nan_f64()
	original := []f64{3.0, nan, 1.0, 5.0, 2.0, nan}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	// No NaN may appear inside the prefix.
	for i in 0 ..< k {
		testing.expectf(t, !math.is_nan(data[i]), "NaN should not be in prefix at position %d", i)
	}
	// The prefix holds the k smallest non-NaN values, in order.
	testing.expect(t, data[0] == 1.0, "First should be 1.0")
	testing.expect(t, data[1] == 2.0, "Second should be 2.0")
	testing.expect(t, data[2] == 3.0, "Third should be 3.0")
}
|
|
|
|
@(test)
test_float_nan_negative :: proc(t: ^testing.T) {
	// NaNs with the sign bit set must also be excluded from the prefix.
	neg_nan := transmute(f64)(u64(0xFFF8_0000_0000_0000)) // quiet NaN, sign bit set
	pos_nan := math.nan_f64()
	original := []f64{3.0, neg_nan, 1.0, pos_nan, -5.0, 2.0}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	for i in 0 ..< k {
		testing.expectf(t, !math.is_nan(data[i]), "NaN should not be in prefix at position %d", i)
	}
	// Smallest non-NaN values are -5, 1, 2.
	testing.expect(t, data[0] == -5.0, "First should be -5.0")
	testing.expect(t, data[1] == 1.0, "Second should be 1.0")
	testing.expect(t, data[2] == 2.0, "Third should be 2.0")
}
|
|
|
|
@(test)
test_float_all_nan :: proc(t: ^testing.T) {
	// Degenerate input of only NaNs (both sign variants): must not crash,
	// and every element must still be NaN afterwards.
	nan := math.nan_f64()
	neg_nan := transmute(f64)(u64(0xFFF8_0000_0000_0000))
	original := []f64{nan, neg_nan, nan, neg_nan}
	k := 2

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	for v in data {
		testing.expect(t, math.is_nan(v), "All elements should still be NaN")
	}
}
|
|
|
|
@(test)
test_float_nan_with_infinity :: proc(t: ^testing.T) {
	// NaNs sort after +inf, so the prefix must be -inf, 0, 1.
	nan := math.nan_f64()
	inf := math.inf_f64(1)
	neg_inf := math.inf_f64(-1)
	original := []f64{inf, nan, neg_inf, 0.0, nan, 1.0}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	for i in 0 ..< k {
		testing.expectf(t, !math.is_nan(data[i]), "NaN should not be in prefix at position %d", i)
	}
	testing.expect(t, data[0] == neg_inf, "First should be -inf")
	testing.expect(t, data[1] == 0.0, "Second should be 0.0")
	testing.expect(t, data[2] == 1.0, "Third should be 1.0")
}
|
|
|
|
//----- f32 tests ----------------------------------
|
|
|
|
@(test)
test_f32_basic :: proc(t: ^testing.T) {
	// Same smoke test as the f64 version, instantiated for f32.
	original := []f32{3.14, 2.71, 1.41, 1.73, 2.23, 0.57}
	k := 3

	data := make([]f32, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "f32 basic failed: %s", why)
}
|
|
|
|
@(test)
test_f32_with_negative :: proc(t: ^testing.T) {
	// f32 values spanning both signs and zero.
	original := []f32{-1.5, 2.5, -3.5, 0.0, 1.0, -0.5}
	k := 3

	data := make([]f32, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "f32 with negative failed: %s", why)
}
|
|
|
|
@(test)
test_f32_large_array :: proc(t: ^testing.T) {
	// 1000 random f32 values in [-5000, 5000); selects the smallest 50.
	size := 1000
	original := make([]f32, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f32(rand.int31() % 10000) - 5000
	}

	data := make([]f32, size)
	defer delete(data)
	copy(data, original)
	k := 50

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "f32 large array failed: %s", why)
}
|
|
|
|
//----- f16 tests ----------------------------------
|
|
|
|
@(test)
test_f16_basic :: proc(t: ^testing.T) {
	// Half-precision instantiation of the basic smoke test.
	original := []f16{3.14, 2.71, 1.41, 1.73, 2.23, 0.57}
	k := 3

	data := make([]f16, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "f16 basic failed: %s", why)
}
|
|
|
|
@(test)
test_f16_with_negative :: proc(t: ^testing.T) {
	// f16 values spanning both signs and zero.
	original := []f16{-1.5, 2.5, -3.5, 0.0, 1.0, -0.5}
	k := 3

	data := make([]f16, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "f16 with negative failed: %s", why)
}
|
|
|
|
@(test)
test_f32_nan :: proc(t: ^testing.T) {
	// f32 NaNs of either sign must stay out of the prefix.
	nan := math.nan_f32()
	neg_nan := transmute(f32)(u32(0xFFC0_0000)) // quiet NaN, sign bit set
	original := []f32{3.0, nan, 1.0, neg_nan, -5.0, 2.0}
	k := 3

	data := make([]f32, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	for i in 0 ..< k {
		testing.expectf(t, !math.is_nan(data[i]), "NaN should not be in prefix at position %d", i)
	}
	// Smallest non-NaN values are -5, 1, 2.
	testing.expect(t, data[0] == -5.0, "First should be -5.0")
	testing.expect(t, data[1] == 1.0, "Second should be 1.0")
	testing.expect(t, data[2] == 2.0, "Third should be 2.0")
}
|
|
|
|
//----- Early termination / all-equal tests ----------------------------------
|
|
|
|
@(test)
test_all_equal_early_termination :: proc(t: ^testing.T) {
	// All-equal input exercises the early-termination optimization;
	// every element must remain untouched.
	size := 1000
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = 42.0
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 100

	partial_sort_float(data, k)

	for i in 0 ..< size {
		testing.expectf(t, data[i] == 42.0, "Element at %d changed from 42.0 to %v", i, data[i])
	}
}
|
|
|
|
@(test)
test_mostly_equal_with_outliers :: proc(t: ^testing.T) {
	// Mostly-equal input with three smaller outliers; exercises partition
	// behavior when nearly all keys compare equal.
	size := 100
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = 5.0
	}
	original[10] = 1.0
	original[50] = 2.0
	original[90] = 3.0

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 5

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Mostly equal with outliers failed: %s", why)

	// The three outliers must lead the prefix, in order.
	testing.expect(t, data[0] == 1.0, "First should be 1.0")
	testing.expect(t, data[1] == 2.0, "Second should be 2.0")
	testing.expect(t, data[2] == 3.0, "Third should be 3.0")
}
|
|
|
|
//----- Boundary K value tests ----------------------------------
|
|
|
|
@(test)
test_k_equals_length_minus_one :: proc(t: ^testing.T) {
	// Boundary case: k is one less than the array length.
	original := []f64{5, 1, 9, 3, 7, 2, 8, 4, 6}

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := len(data) - 1

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "k=n-1 failed: %s", why)
}
|
|
|
|
@(test)
test_k_half_length :: proc(t: ^testing.T) {
	// Boundary case: k is exactly half the array length.
	original := []f64{10, 2, 8, 4, 6, 5, 7, 3, 9, 1}

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := len(data) / 2

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "k=n/2 failed: %s", why)
}
|
|
|
|
//----- Randomized stress tests ----------------------------------
|
|
|
|
@(test)
test_random_small_arrays :: proc(t: ^testing.T) {
	// Exhaustive sweep over every (size, k) pair for small random arrays.
	for size in 2 ..= 20 {
		for k in 1 ..= size {
			original := make([]f64, size)
			defer delete(original)
			for i in 0 ..< size {
				original[i] = f64(rand.int31() % 100)
			}

			data := make([]f64, size)
			defer delete(data)
			copy(data, original)

			partial_sort_float(data, k)

			valid, why := validate_partial_sort(original, data, k)
			testing.expectf(t, valid, "Random small (size=%d, k=%d) failed: %s", size, k, why)
		}
	}
}
|
|
|
|
@(test)
test_random_medium_array :: proc(t: ^testing.T) {
	// Several trials over a 1000-element random array, sweeping representative
	// k values from 1 to n.
	size := 1000

	for trial in 0 ..< 10 {
		original := make([]f64, size)
		defer delete(original)
		for i in 0 ..< size {
			original[i] = f64(rand.int31())
		}

		// BUG FIX: this previously read `for _, k in ([]int{...})`, which in
		// Odin binds `k` to the slice *index* (0..6), not the listed values —
		// so k=50..1000 were never exercised and k=0 was a no-op call.
		for k in ([]int{1, 10, 50, 100, 500, 999, 1000}) {
			data := make([]f64, size)
			defer delete(data)
			copy(data, original)

			partial_sort_float(data, k)

			ok, reason := validate_partial_sort(original, data, k)
			testing.expectf(t, ok, "Random medium (trial=%d, k=%d) failed: %s", trial, k, reason)
		}
	}
}
|
|
|
|
@(test)
test_random_large_array :: proc(t: ^testing.T) {
	// One 10000-element random array, sweeping representative k values up to n.
	size := 10000

	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(rand.int31())
	}

	// BUG FIX: this previously read `for _, k in ([]int{...})`, which in Odin
	// binds `k` to the slice *index* (0..6), not the listed values — so the
	// large-k cases were never exercised and k=0 was a no-op call.
	for k in ([]int{1, 10, 100, 1000, 5000, 9999, 10000}) {
		data := make([]f64, size)
		defer delete(data)
		copy(data, original)

		partial_sort_float(data, k)

		ok, reason := validate_partial_sort(original, data, k)
		testing.expectf(t, ok, "Random large (k=%d) failed: %s", k, reason)
	}
}
|
|
|
|
//----- Regression / specific bug pattern tests ----------------------------------
|
|
|
|
@(test)
test_three_elements_all_permutations :: proc(t: ^testing.T) {
	// Every permutation of three distinct values, for every valid k.
	perms := [][3]f64{{1, 2, 3}, {1, 3, 2}, {2, 1, 3}, {2, 3, 1}, {3, 1, 2}, {3, 2, 1}}

	for perm in perms {
		for k in 1 ..= 3 {
			data := []f64{perm[0], perm[1], perm[2]}
			original := []f64{perm[0], perm[1], perm[2]}

			partial_sort_float(data, k)

			valid, why := validate_partial_sort(original, data, k)
			testing.expectf(t, valid, "Permutation %v with k=%d failed: %s", perm, k, why)
		}
	}
}
|
|
|
|
@(test)
test_duplicate_at_kth_position :: proc(t: ^testing.T) {
	// The value at the k-th boundary is duplicated; all its copies that belong
	// in the prefix must land there.
	original := []f64{1, 3, 3, 3, 5, 6, 7}
	k := 4

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Duplicate at kth position failed: %s", why)

	testing.expect(t, data[0] == 1, "First element should be 1")
	count_threes := 0
	for i in 1 ..< k {
		if data[i] == 3 {
			count_threes += 1
		}
	}
	testing.expect(t, count_threes == 3, "Should have exactly 3 threes in prefix")
}
|
|
|
|
//----- Special sequence tests ----------------------------------
|
|
|
|
@(test)
test_fibonacci_sequence :: proc(t: ^testing.T) {
	// Already-sorted Fibonacci input (with a duplicate 1 at the start).
	original := []f64{1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144}
	k := 6

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Fibonacci sequence failed: %s", why)
}
|
|
|
|
@(test)
test_powers_of_two_reversed :: proc(t: ^testing.T) {
	// Descending powers of two; verifies the exact ascending prefix.
	original := []f64{256, 128, 64, 32, 16, 8, 4, 2, 1}
	k := 5

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Powers of two reversed failed: %s", why)

	expected := []f64{1, 2, 4, 8, 16}
	for want, i in expected {
		testing.expectf(t, data[i] == want, "Expected %v at position %d, got %v", want, i, data[i])
	}
}
|
|
|
|
@(test)
test_arithmetic_sequence :: proc(t: ^testing.T) {
	// Strictly descending arithmetic sequence 100, 95, ..., 5.
	original := make([]f64, 20)
	defer delete(original)
	for i in 0 ..< 20 {
		original[i] = f64(100 - i * 5)
	}

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)
	k := 10

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Arithmetic sequence failed: %s", why)
}
|
|
|
|
//----- Float boundary tests ----------------------------------
|
|
|
|
@(test)
test_max_min_float :: proc(t: ^testing.T) {
	// Extreme finite f64 magnitudes mixed with small values.
	original := []f64{max(f64), -max(f64), 0, max(f64), -max(f64), 1, -1}
	k := 4

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Max/min float failed: %s", why)
}
|
|
|
|
//----- Repeated partial sort tests ----------------------------------
|
|
|
|
@(test)
test_idempotent :: proc(t: ^testing.T) {
	// Sorting an already-partial-sorted array must leave the prefix unchanged.
	original := []f64{9, 1, 8, 2, 7, 3, 6, 4, 5}
	k := 5

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_float(data, k)
	first_result := make([]f64, len(data))
	defer delete(first_result)
	copy(first_result, data)

	partial_sort_float(data, k)

	for i in 0 ..< k {
		testing.expectf(t, data[i] == first_result[i], "Partial sort not idempotent at position %d", i)
	}
}
|
|
|
|
@(test)
test_increasing_k :: proc(t: ^testing.T) {
	// Sweep every k from 1 to n against a fresh copy of the same input.
	original := []f64{10, 2, 8, 4, 6, 1, 9, 3, 7, 5}

	for k in 1 ..= len(original) {
		data := make([]f64, len(original))
		defer delete(data)
		copy(data, original)

		partial_sort_float(data, k)

		valid, why := validate_partial_sort(original, data, k)
		testing.expectf(t, valid, "Increasing k=%d failed: %s", k, why)
	}
}
|
|
|
|
//----- Key extraction tests (partial_sort_by_fkey) ----------------------------------
|
|
|
|
@(test)
test_struct_field_f64_key :: proc(t: ^testing.T) {
	// Sort structs by an f64 field via partial_sort_by_fkey.
	Item :: struct {
		priority: f64,
		name:     string,
	}

	original := []Item{{5.0, "e"}, {2.0, "b"}, {8.0, "h"}, {1.0, "a"}, {9.0, "i"}, {3.0, "c"}}
	k := 3

	data := make([]Item, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_by_fkey(data, k, proc(x: Item) -> f64 {return x.priority})

	testing.expect(t, data[0].priority == 1, "First should have priority 1")
	testing.expect(t, data[1].priority == 2, "Second should have priority 2")
	testing.expect(t, data[2].priority == 3, "Third should have priority 3")

	for i in 0 ..< k - 1 {
		testing.expect(t, data[i].priority <= data[i + 1].priority, "Prefix should be sorted by priority")
	}
}
|
|
|
|
@(test)
test_struct_field_f32_key :: proc(t: ^testing.T) {
	// Sort structs by an f32 field via partial_sort_by_fkey.
	Item :: struct {
		score: f32,
		id:    int,
	}

	original := []Item{{5.5, 0}, {2.2, 1}, {8.8, 2}, {1.1, 3}, {9.9, 4}, {3.3, 5}}
	k := 3

	data := make([]Item, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_by_fkey(data, k, proc(x: Item) -> f32 {return x.score})

	testing.expect(t, data[0].score == 1.1, "First should have score 1.1")
	testing.expect(t, data[1].score == 2.2, "Second should have score 2.2")
	testing.expect(t, data[2].score == 3.3, "Third should have score 3.3")
}
|
|
|
|
@(test)
test_negative_key :: proc(t: ^testing.T) {
	// Negating the key yields a descending sort: the prefix holds the largest values.
	original := []f64{1, 5, 2, 8, 3, 9, 4}
	k := 3

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_by_fkey(data, k, proc(x: f64) -> f64 {return -x})

	testing.expect(t, data[0] == 9, "First should be 9")
	testing.expect(t, data[1] == 8, "Second should be 8")
	testing.expect(t, data[2] == 5, "Third should be 5")
}
|
|
|
|
@(test)
test_absolute_value_key :: proc(t: ^testing.T) {
	// Sort by absolute value; verifies the prefix's |x| values.
	original := []f64{-5, 3, -1, 4, -2, 0, 2, -3}
	k := 4

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_by_fkey(data, k, proc(x: f64) -> f64 {
		if x < 0 {
			return -x
		}
		return x
	})

	abs_values := make([]f64, k)
	defer delete(abs_values)
	for i in 0 ..< k {
		v := data[i]
		if v < 0 {
			v = -v
		}
		abs_values[i] = v
	}

	testing.expect(t, abs_values[0] == 0, "First should be 0")
	testing.expect(t, abs_values[1] == 1, "Second should have abs value 1")
	testing.expect(t, abs_values[2] == 2, "Third should have abs value 2")
	testing.expect(t, abs_values[3] == 2, "Fourth should have abs value 2")

	for i in 0 ..< k - 1 {
		testing.expect(t, abs_values[i] <= abs_values[i + 1], "Prefix should be sorted by absolute value")
	}
}
|
|
|
|
@(test)
test_squared_key :: proc(t: ^testing.T) {
	// Sort by x*x; the four smallest squares are 0, 1, 1, 4.
	original := []f64{-3, 1, -2, 0, 2, -1, 3}
	k := 4

	data := make([]f64, len(original))
	defer delete(data)
	copy(data, original)

	partial_sort_by_fkey(data, k, proc(x: f64) -> f64 {return x * x})

	squared := make([]f64, k)
	defer delete(squared)
	for i in 0 ..< k {
		squared[i] = data[i] * data[i]
	}

	testing.expect(t, squared[0] == 0, "First squared should be 0")
	testing.expect(t, squared[1] == 1, "Second squared should be 1")
	testing.expect(t, squared[2] == 1, "Third squared should be 1")
	testing.expect(t, squared[3] == 4, "Fourth squared should be 4")
}
|
|
|
|
//----- Pre-computed sort path tests (large k) ----------------------------------
|
|
|
|
@(test)
test_precompute_sort_path :: proc(t: ^testing.T) {
	// Large k relative to PRECOMPUTE_SORT_THRESHOLD to exercise the
	// pre-computed keys sort path.
	size := 1000
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(rand.int31() % 10000) - 5000
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 300 // large enough to trigger the precompute path

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Precompute sort path failed: %s", why)
}
|
|
|
|
@(test)
test_precompute_sort_f32 :: proc(t: ^testing.T) {
	// f32 variant of the pre-computed keys sort path test.
	size := 500
	original := make([]f32, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f32(rand.int31() % 10000) - 5000
	}

	data := make([]f32, size)
	defer delete(data)
	copy(data, original)
	k := 300

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Precompute sort f32 failed: %s", why)
}
|
|
|
|
@(test)
test_precompute_sort_with_nan :: proc(t: ^testing.T) {
	// NaNs sprinkled into a large-k input: the precompute path must still
	// keep them out of the prefix.
	size := 500
	nan := math.nan_f64()
	neg_nan := transmute(f64)(u64(0xFFF8_0000_0000_0000))

	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(rand.int31() % 10000) - 5000
	}
	original[10] = nan
	original[100] = neg_nan
	original[200] = nan
	original[300] = neg_nan

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 300

	partial_sort_float(data, k)

	nan_in_prefix := false
	for i in 0 ..< k {
		if math.is_nan(data[i]) {
			nan_in_prefix = true
		}
	}
	testing.expect(t, !nan_in_prefix, "NaNs should not be in prefix with precompute sort")
}
|
|
|
|
@(test)
test_precompute_sort_all_negative :: proc(t: ^testing.T) {
	// Precompute path with exclusively negative keys.
	size := 500
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = -f64(rand.int31() % 10000) - 1
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 300

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Precompute sort all negative failed: %s", why)
}
|
|
|
|
//----- Heap selection path tests (small k, large n) ----------------------------------
|
|
|
|
@(test)
test_heap_select_path :: proc(t: ^testing.T) {
	// This test specifically exercises the heap selection path:
	// k < HEAP_SELECT_K_THRESHOLD (32) and n > HEAP_SELECT_N_THRESHOLD (131072).
	// NOTE(review): this comment previously said the k threshold was 64, but
	// HEAP_SELECT_K_THRESHOLD is declared as 32 at the top of this file.
	size := 150000 // > 131072, large enough for the heap path's n threshold
	original := make([]f64, size)
	defer delete(original)
	for i := 0; i < size; i += 1 {
		original[i] = f64(rand.int31())
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 10 // < 32 (HEAP_SELECT_K_THRESHOLD), triggers heap path

	partial_sort_float(data, k)

	ok, reason := validate_partial_sort(original, data, k)
	testing.expectf(t, ok, "Heap select path failed: %s", reason)
}
|
|
|
|
@(test)
test_heap_select_with_key :: proc(t: ^testing.T) {
	// Heap selection path combined with a key-extraction proc on structs.
	Item :: struct {
		value: f64,
		id:    int,
	}

	size := 150000
	original := make([]Item, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = Item{f64(rand.int31()), i}
	}

	data := make([]Item, size)
	defer delete(data)
	copy(data, original)
	k := 20

	partial_sort_by_fkey(data, k, proc(x: Item) -> f64 {return x.value})

	// The prefix must be ordered by the extracted key.
	for i in 0 ..< k - 1 {
		testing.expectf(t, data[i].value <= data[i + 1].value, "Heap select prefix not sorted at %d", i)
	}
}
|
|
|
|
@(test)
test_heap_select_negative_values :: proc(t: ^testing.T) {
	// Heap selection path with values shifted to straddle zero.
	size := 150000
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(rand.int31()) - f64(max(i32) / 2)
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 32

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Heap select with negatives failed: %s", why)
}
|
|
|
|
//----- Skip empty passes tests ----------------------------------
|
|
|
|
@(test)
test_skip_empty_passes_uniform :: proc(t: ^testing.T) {
	// Keys whose low radix digits coincide, letting the sort skip empty passes.
	size := 100
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(i) * 1000.0 // values differ only in higher bits
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 20

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Skip empty passes uniform failed: %s", why)
}
|
|
|
|
@(test)
test_skip_empty_passes_integers :: proc(t: ^testing.T) {
	// Integer-valued floats (zero low mantissa bits), stored in reverse order;
	// verifies both validity and the exact sorted prefix.
	size := 100
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(size - i)
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 30

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "Skip empty passes integers failed: %s", why)

	for i in 0 ..< k {
		testing.expectf(t, data[i] == f64(i + 1), "Expected %v at %d, got %v", f64(i + 1), i, data[i])
	}
}
|
|
|
|
//----- Key extraction tests (partial_sort_by_fkey) ----------------------------------
|
|
|
|
@(test)
test_large_struct_array_by_key :: proc(t: ^testing.T) {
	// Key-based partial sort over large (88+ byte) structs; checks both the
	// sorted prefix and the partition property against the suffix.
	Record :: struct {
		data:  [10]int,
		score: f64,
	}

	size := 500
	original := make([]Record, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i].score = f64(rand.int31() % 10000)
	}

	data := make([]Record, size)
	defer delete(data)
	copy(data, original)
	k := 20

	partial_sort_by_fkey(data, k, proc(x: Record) -> f64 {return x.score})

	// Prefix must be non-decreasing by score.
	for i in 0 ..< k - 1 {
		testing.expectf(
			t,
			data[i].score <= data[i + 1].score,
			"Prefix not sorted at position %d: %v > %v",
			i,
			data[i].score,
			data[i + 1].score,
		)
	}

	// Everything past k must be >= the prefix maximum.
	if k < size {
		max_in_prefix := data[0].score
		for i in 1 ..< k {
			if data[i].score > max_in_prefix {
				max_in_prefix = data[i].score
			}
		}
		for i in k ..< size {
			testing.expectf(t, data[i].score >= max_in_prefix, "Partition property violated at %d", i)
		}
	}
}
|
|
|
|
//----- var_bits ctz skip tests ----------------------------------
|
|
|
|
@(test)
test_var_bits_low_bits_identical :: proc(t: ^testing.T) {
	// Keys differing only in the high bits (low byte always zero), so low
	// radix passes can be skipped via ctz(var_bits).
	size := 200
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64((i + 1) * 256) // 256, 512, 768, ...
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 50

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "var_bits low bits identical failed: %s", why)
}
|
|
|
|
@(test)
test_var_bits_high_bits_identical :: proc(t: ^testing.T) {
	// Keys differing only in the low bits, so high radix passes can be
	// limited via clz(var_bits).
	size := 200
	original := make([]f64, size)
	defer delete(original)
	for i in 0 ..< size {
		original[i] = f64(i % 256) + 1000.0
	}

	data := make([]f64, size)
	defer delete(data)
	copy(data, original)
	k := 50

	partial_sort_float(data, k)

	valid, why := validate_partial_sort(original, data, k)
	testing.expectf(t, valid, "var_bits high bits identical failed: %s", why)
}
|
|
|
|
@(test)
test_var_bits_single_differing_bit :: proc(t: ^testing.T) {
	// Input alternates between exactly two values, so only a single bit of
	// the key pattern varies — the degenerate extreme of the var_bits skip.
	size := 100
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = 1000.0 if i % 2 == 0 else 1001.0
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 50
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "var_bits single differing bit failed: %s", reason)

	// Exactly 50 copies of 1000.0 exist (even indices), so the entire prefix
	// must be 1000.0.
	for i in 0 ..< k {
		testing.expectf(t, work[i] == 1000.0, "Expected 1000.0 at %d, got %v", i, work[i])
	}
}
|
|
|
|
//----- MSD ping-pong tests ----------------------------------
|
|
|
|
@(test)
test_msd_multiple_levels :: proc(t: ^testing.T) {
	// Widely spread random keys force the MSD recursion through several
	// digit levels before the candidate range shrinks.
	// NOTE(review): uses the unseeded global RNG, so exact values vary per
	// run; the validation is value-independent so the test stays sound.
	size := 1000
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31())
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	// k must exceed the heap threshold (32) while 2*k stays below n, which
	// routes the call onto the MSD path.
	k := 100
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "MSD multiple levels failed: %s", reason)
}
|
|
|
|
@(test)
test_msd_narrow_range :: proc(t: ^testing.T) {
	// Keys confined to a narrow band around 50000 share most high bits, so
	// the MSD recursion should bottom out after very few levels.
	size := 500
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31() % 1000) + 50000.0
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 100
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "MSD narrow range failed: %s", reason)
}
|
|
|
|
@(test)
test_msd_to_insertion_sort :: proc(t: ^testing.T) {
	// Only 100 distinct small values over 200 elements: the MSD candidate
	// range collapses quickly, exercising the insertion-sort finish.
	size := 200
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31() % 100)
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 50
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "MSD to insertion sort failed: %s", reason)
}
|
|
|
|
//----- 11-bit radix threshold tests ----------------------------------
|
|
|
|
@(test)
test_11bit_threshold_below :: proc(t: ^testing.T) {
	// n = 8000 sits just under LSD_11BIT_THRESHOLD (8192), so the LSD path
	// should stay on the 8-bit radix.
	size := 8000
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31())
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	// A full sort (k == n) is what routes the call through the LSD code.
	k := size
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "11-bit threshold below failed: %s", reason)
}
|
|
|
|
@(test)
test_11bit_threshold_at :: proc(t: ^testing.T) {
	// n = 8192 is exactly LSD_11BIT_THRESHOLD: the boundary case where f64
	// input should switch over to the 11-bit radix.
	size := 8192
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31())
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := size
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "11-bit threshold at failed: %s", reason)
}
|
|
|
|
@(test)
test_11bit_threshold_above :: proc(t: ^testing.T) {
	// n = 10000 is comfortably past LSD_11BIT_THRESHOLD (8192): f64 input
	// should take the 11-bit radix.
	size := 10000
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31())
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := size
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "11-bit threshold above failed: %s", reason)
}
|
|
|
|
//----- Streaming heap edge cases ----------------------------------
|
|
|
|
@(test)
test_streaming_heap_k_at_threshold :: proc(t: ^testing.T) {
	// k sits exactly on HEAP_SELECT_K_THRESHOLD (32) — the last value that
	// should still take the heap-selection path.
	size := 10000
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31())
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 32
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "Streaming heap k at threshold failed: %s", reason)
}
|
|
|
|
@(test)
test_streaming_heap_k_above_threshold :: proc(t: ^testing.T) {
	// k = 33 is one past HEAP_SELECT_K_THRESHOLD, the first value that should
	// route onto the MSD path instead of the heap.
	size := 10000
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(rand.int31())
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 33
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "Streaming heap k above threshold failed: %s", reason)
}
|
|
|
|
@(test)
test_streaming_heap_all_same_then_one_smaller :: proc(t: ^testing.T) {
	// A plateau of identical values with a lone smaller outlier at the very
	// end: the heap must still replace its root on the final element.
	size := 10000
	baseline := make([]f64, size)
	defer delete(baseline)
	for &v in baseline {
		v = 100.0
	}
	baseline[size - 1] = 1.0 // the single smaller element, placed last

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 10
	partial_sort_float(work, k)

	testing.expect(t, work[0] == 1.0, "Smallest element should be first")
	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "Streaming heap one smaller failed: %s", reason)
}
|
|
|
|
@(test)
test_streaming_heap_descending :: proc(t: ^testing.T) {
	// Strictly descending input (size down to 1) is the heap's worst case:
	// every streamed element is smaller than the current root and forces a
	// replacement.
	size := 10000
	baseline := make([]f64, size)
	defer delete(baseline)
	for i in 0 ..< size {
		baseline[i] = f64(size - i)
	}

	work := make([]f64, size)
	defer delete(work)
	copy(work, baseline)

	k := 20
	partial_sort_float(work, k)

	ok, reason := validate_partial_sort(baseline, work, k)
	testing.expectf(t, ok, "Streaming heap descending failed: %s", reason)

	// The smallest k values are exactly 1..k, so the prefix is fully known.
	for i in 0 ..< k {
		testing.expectf(t, work[i] == f64(i + 1), "Expected %d at position %d, got %v", i + 1, i, work[i])
	}
}
|