SIMD Programming
Codon's simd module and its Vec[T, N] type provide direct, ergonomic access to SIMD instructions, with:
- Explicit control over vector width and element type
- Portable syntax for arithmetic, logic, comparisons, and math intrinsics
- Reductions, masking, and overflow-aware operations
Vector types
simd.Vec[T, N] represents an LLVM vector <N x T>, where:
- T is a scalar numeric type (e.g. float32, int, u32)
- N is an integer literal giving the number of lanes
Conceptually:
- Vec[float32, 8] ≈ "an 8-wide float32 SIMD register"
- Vec[u8, 16] ≈ "16 byte-wide lanes"
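The lane count in these examples is the register width divided by the element width. The plain-Python sketch below shows that arithmetic, assuming 256-bit (AVX2-class) and 128-bit (SSE/NEON-class) registers. The helper lanes_per_register is hypothetical and is not part of the simd module. N does not have to match the hardware width, since LLVM typically splits or widens other vector sizes, but matching it is a reasonable starting point.

```python
def lanes_per_register(register_bits: int, element_bits: int) -> int:
    # How many elements of `element_bits` fit in one register of `register_bits`.
    if register_bits % element_bits != 0:
        raise ValueError("element width must divide the register width")
    return register_bits // element_bits

# 256-bit register with 32-bit float32 lanes -> 8 lanes, i.e. Vec[float32, 8]
print(lanes_per_register(256, 32))
# 128-bit register with 8-bit u8 lanes -> 16 lanes, i.e. Vec[u8, 16]
print(lanes_per_register(128, 8))
```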
You would typically use Vec in hot loops where:
- Work is data-parallel (same operation applied to many elements)
- The data is laid out in contiguous memory (arrays, lists, strings)
- You care about predictable vectorization (not relying on the auto-vectorizer)
Creating vectors
Broadcasting a scalar
The simplest way to create a vector is to broadcast a scalar into all lanes:
from simd import Vec
# 8-lane vector of all ones
v = Vec[float, 8](1.0)
# 4-lane vector of all -3
w = Vec[int, 4](-3)
Loading from pointers and arrays
Vec can also load data from a pointer, such as the underlying buffer of a
NumPy array. Here is an example that implements hand-vectorized addition of
two arrays:
import numpy as np
from simd import Vec
def add_arrays(a: np.ndarray[float32, 1],
               b: np.ndarray[float32, 1],
               out: np.ndarray[float32, 1]):
    W: Literal[int] = 8  # vector width
    i = 0
    # main loop: W elements per iteration
    while i + W <= len(a):
        va = Vec[float32, W](a.data + i)  # load 8 elements from a[i..i+7]
        vb = Vec[float32, W](b.data + i)  # load 8 elements from b[i..i+7]
        vc = va + vb                      # SIMD add
        Ptr[Vec[float32, W]](out.data + i)[0] = vc  # store back
        i += W
    # handle remaining tail elements (scalar)
    while i < len(a):
        out[i] = a[i] + b[i]
        i += 1
Note that:
- Vec[T, N](ptr) treats ptr as a Ptr[Vec[T, N]] and loads one SIMD vector
- To store, cast the output pointer to Ptr[Vec[T, N]] and assign to element [0]
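The loop structure above (full W-wide steps followed by a scalar tail) can be modeled in plain Python, which is handy as a reference when testing a hand-vectorized kernel. This is a sketch: add_chunked is a hypothetical name, and the inner for loop stands in for a single Vec operation. For n = 19 and W = 8, two vector steps cover 16 elements and the tail handles the remaining 3.

```python
def add_chunked(a, b, W=8):
    # Plain-Python model of add_arrays: W-wide chunks, then a scalar tail.
    n = len(a)
    out = [0.0] * n
    i = 0
    while i + W <= n:
        # One "vector" step: these W lanes would be a single SIMD add.
        for lane in range(W):
            out[i + lane] = a[i + lane] + b[i + lane]
        i += W
    # Scalar tail: the last n % W elements.
    while i < n:
        out[i] = a[i] + b[i]
        i += 1
    return out

a = [float(k) for k in range(19)]
b = [1.0] * 19
print(add_chunked(a, b)[-3:])  # the three tail elements
```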
You can also construct vectors from lists:
data = [1.0, 2.0, 3.0, 4.0]
v = Vec[float, 4](data) # load from data[0..3]
Loading from strings (bytes)
For byte-sized element types (i8, u8, byte), you can read directly from a string buffer:
# Load 16 bytes from a string
s = "ABCDEFGHIJKLMNOPQRST"
v = Vec[u8, 16](s)  # bytes 'A'..'P'
# Skip the first 4 bytes: loads 'E'..'T', so s must hold at least 4 + 16 bytes
v2 = Vec[u8, 16](s, 4)
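Because a load always reads N bytes, an offset load needs offset + N to stay within the string. The plain-Python sketch below models which bytes land in the lanes and checks that bound. It assumes, without confirmation from the simd module, that Vec[u8, N](s, offset) does not bounds-check on its own; bytes_in_lanes is a hypothetical helper.

```python
def bytes_in_lanes(s: str, offset: int, n: int) -> list:
    # Model of the bytes Vec[u8, n](s, offset) would see: s[offset:offset + n].
    data = s.encode("ascii")
    if offset < 0 or offset + n > len(data):
        # Assumption: the real SIMD load would read past the end of the buffer here.
        raise IndexError(f"{n}-byte load at offset {offset} exceeds {len(data)}-byte string")
    return list(data[offset:offset + n])

lanes = bytes_in_lanes("ABCDEFGHIJKLMNOPQRST", 4, 16)
print(lanes[0], lanes[-1])  # byte values of 'E' and 'T'
```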
SIMD arithmetic
All basic arithmetic operations are lane-wise:
- +, -, * on integer and float vectors
- / on float vectors (true division)
- // and % on integer vectors
- and so on
Example: fused multiply-add style accumulation for a dot product:
from simd import Vec
def dot(a: Ptr[float32], b: Ptr[float32], n: int) -> float32:
    W: Literal[int] = 8
    i = 0
    acc = Vec[float32, W](0.0f32)  # W independent partial sums
    while i + W <= n:
        va = Vec[float32, W](a + i)
        vb = Vec[float32, W](b + i)
        acc = acc + va * vb  # lane-wise multiply + add
        i += W
    # horizontal reduce SIMD accumulator
    total = acc.sum()
    # handle tail scalars
    while i < n:
        total += a[i] * b[i]
        i += 1
    return total
Note that:
- va * vb multiplies lane by lane
- acc.sum() adds all lanes, producing a scalar
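The accumulator in dot holds one partial sum per lane, and acc.sum() combines them at the end. The plain-Python model below makes that explicit; dot_lanes is a hypothetical reference, not library code. Integer inputs are used because with floats the SIMD version adds values in a different order than a sequential loop, so results can differ in the last bits.

```python
def dot_lanes(a, b, W=8):
    # Model of the SIMD dot product with one partial sum per lane.
    n = len(a)
    acc = [0] * W                  # stands in for Vec[T, W](0)
    i = 0
    while i + W <= n:
        for lane in range(W):
            acc[lane] += a[i + lane] * b[i + lane]  # acc = acc + va * vb
        i += W
    total = sum(acc)               # horizontal reduction, like acc.sum()
    while i < n:                   # scalar tail
        total += a[i] * b[i]
        i += 1
    return total

a = list(range(1, 20))  # 19 elements: 2 vector steps + 3 tail elements
b = [2] * 19
print(dot_lanes(a, b))
```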
Masks, comparisons and branchless code
Comparisons between vectors (or between a vector and a scalar) produce mask vectors:
v: Vec[float, 8] = ...
mask_negative = v < Vec[float, 8](0.0) # Vec[u1, 8]
You can then use the mask to select between two vectors lane by lane, without branches:
def relu_vec(v: Vec[float, 8]) -> Vec[float, 8]:
    zero = Vec[float, 8](0.0)
    m = v > zero              # positive lanes
    return v.mask(zero, m)    # if m: v else zero
The general pattern is:
- Build a mask via comparisons (<, <=, >, >=, ==, !=)
- Use mask(self, other, mask) to compute mask ? self : other lane-wise
This turns control flow into data flow, which suits SIMD programming. Typical uses include:
- Clamping (min/max/conditionals)
- Thresholding (e.g. x > t ? x : 0)
- Selectively updating a subset of lanes
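Clamping and thresholding both reduce to comparisons plus one or two selects. The sketch below models that in plain Python, with lists standing in for vector lanes: each list comprehension corresponds to one vector comparison, and select mirrors self.mask(other, mask) as described above. The names select, clamp, and threshold are hypothetical.

```python
def select(mask, a, b):
    # Lane-wise blend: a[i] where mask[i] is true, else b[i] (like a.mask(b, mask)).
    return [x if m else y for m, x, y in zip(mask, a, b)]

def clamp(v, lo, hi):
    n = len(v)
    los, his = [lo] * n, [hi] * n   # broadcasts, like Vec[T, N](lo)
    below = [x < lo for x in v]     # mask from one comparison
    above = [x > hi for x in v]
    return select(below, los, select(above, his, v))

def threshold(v, t):
    zeros = [0.0] * len(v)
    return select([x > t for x in v], v, zeros)  # x > t ? x : 0

print(clamp([-5.0, 0.5, 3.0], 0.0, 1.0))
print(threshold([0.2, 0.8, 1.5], 0.5))
```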
Reductions and horizontal operations
Vec provides an addition reduction, sum(), for both integer and floating-point vectors:
v = Vec[float32, 8](...)
s = v.sum() # returns float32
Internally this uses LLVM's vector reduction intrinsics and is typically much faster than manually scattering and summing.
You can combine this with loops to build aggregate operations. Below is an example that implements the L2 norm:
import numpy as np
from simd import Vec
def l2_norm(xs: Ptr[float32], n: int) -> float32:
    W: Literal[int] = 8
    i = 0
    acc = Vec[float32, W](0.0f32)
    while i + W <= n:
        v = Vec[float32, W](xs + i)
        acc = acc + v * v     # lane-wise sum of squares
        i += W
    total = acc.sum()
    while i < n:              # scalar tail
        total += xs[i] * xs[i]
        i += 1
    return np.sqrt(total)
Integer-specific operations
Integer types support additional operations, such as:
- bitwise &, |, ^
- shifts <<, >>
- min, max
- overflow-aware add/sub
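Bitwise operations combine naturally with comparison masks. As an illustration, ASCII letters can be lowercased by setting bit 5 (| 0x20), but only lanes holding 'A'..'Z' should change, so a range mask selects the result. The sketch below models this per byte in plain Python; in a Codon kernel each step would apply to a whole Vec[u8, 16] loaded from a string. ascii_lower is hypothetical, and whether two masks can be combined with & directly in Codon is an assumption.

```python
def ascii_lower(data: bytes) -> list:
    # Per-lane model of a branchless ASCII lowercase.
    out = []
    for c in data:
        is_upper = (c >= 0x41) & (c <= 0x5A)   # mask: 'A'..'Z'
        lowered = c | 0x20                      # bit 5 maps 'A' (0x41) to 'a' (0x61)
        out.append(lowered if is_upper else c)  # blend, like lowered.mask(c, is_upper)
    return out

print(bytes(ascii_lower(b"Hello, World!")))
```

The mask matters: without it, '@' (0x40) would become '`' (0x60).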
Saturating add example
Suppose you want to add two u8 images but clamp to 255 on overflow.
You can use the overflow-aware addition and a mask:
def saturating_add_u8(a: Ptr[u8], b: Ptr[u8],
                      out: Ptr[u8], n: int):
    W: Literal[int] = 16
    i = 0
    max_val = Vec[u8, W](255u8)
    while i + W <= n:
        va = Vec[u8, W](a + i)
        vb = Vec[u8, W](b + i)
        (sum_vec, overflow) = va.add(vb, overflow=True)
        # where overflow[i] == 1, clamp to 255
        clamped = max_val.mask(sum_vec, overflow)
        Ptr[Vec[u8, W]](out + i)[0] = clamped
        i += W
    # scalar tail
    while i < n:
        s = int(a[i]) + int(b[i])
        out[i] = u8(255 if s > 255 else s)
        i += 1
Note:
- va.add(vb, overflow=True) returns (result, overflow_mask)
- Use mask to blend between a "safe" value and the raw result
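A scalar reference is useful for checking saturating_add_u8 on random inputs. The plain-Python sketch below models one lane of va.add(vb, overflow=True), assuming the result is the wrapped 8-bit sum plus a carry flag (the behavior of LLVM's unsigned add-with-overflow), and then applies the same blend as the kernel. Both function names are hypothetical.

```python
def add_u8_overflow(a: int, b: int):
    # One lane of an overflow-aware u8 add: (wrapped sum, overflow flag).
    s = a + b
    return s & 0xFF, s > 0xFF

def saturating_add_ref(a: list, b: list) -> list:
    out = []
    for x, y in zip(a, b):
        wrapped, overflow = add_u8_overflow(x, y)
        out.append(255 if overflow else wrapped)  # like max_val.mask(sum_vec, overflow)
    return out

print(add_u8_overflow(200, 100))
print(saturating_add_ref([200, 10, 255], [100, 20, 0]))
```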
Debugging
Scatter to list
Vectors can be converted to lists with scatter():
v = Vec[int, 4](...)
lst = v.scatter() # List[int] of length 4
print(lst) # e.g. [1, 1, 1, 1]
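Once a vector is scattered to a list, it can be compared against a scalar reference lane by lane. The plain-Python helper below is a hypothetical sketch that reports which lanes disagree, using a tolerance because SIMD float code may round differently from a sequential loop. In Codon you would pass v.scatter() as got.

```python
import math

def check_lanes(got, expected, rel_tol=1e-6, abs_tol=1e-9):
    # Return the indices of lanes whose values differ beyond the tolerances.
    if len(got) != len(expected):
        raise ValueError(f"lane count mismatch: {len(got)} vs {len(expected)}")
    return [i for i, (g, e) in enumerate(zip(got, expected))
            if not math.isclose(g, e, rel_tol=rel_tol, abs_tol=abs_tol)]

print(check_lanes([1.0, 2.0, 3.5], [1.0, 2.0, 3.0]))  # lane 2 disagrees
```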
Printing vectors
Vectors can be printed directly:
print(v) # e.g. "<1,1,1,1>"