
Using Jai's Unique and Powerful Compiler for Typesafe Units
May 24th, 2024
I've always wanted a nice library for writing typesafe math. For example, if you're solving for distance in Meters but somehow wind up with MetersPerSecond, you know you did something wrong. Catching that at compile time can prevent mistakes like the one that infamously cost NASA the Mars Climate Orbiter, which was lost because one piece of software produced imperial units while another expected metric!
Many units libraries exist in various languages. None of them have sparked my joy. They're complex, have bad error messages, and make trade-offs I don't like.
A while back I learned Jai via Advent of Code. Since then I haven't done much with Jai, but over the past year I've found myself missing several of its features.
Feeling a spark of inspiration I wrote a small, proof-of-life units library in Jai. In doing so I learned about some pretty crazy compiler features that kinda blew me away. It makes C++ templates and Rust traits feel disgustingly archaic, imho.
I'm honestly not sure if this post is for Jai programmers, the Jai curious, or C++/Rust programmers. It's a mix.
What are Typesafe Units?
What do I mean when I say "typesafe units"? This is best shown with an example.
Meters<float> CalcBallisticRange(
    MetersPerSecond<float> speed,
    MetersPerSecond2<float> gravity,
    Meters<float> initial_height)
{
    float d2r = 0.01745329252f;
    float angle = 45.0f * d2r;
    float cos = std::cos(angle);
    float sin = std::sin(angle);

    auto range = (speed*cos / gravity)
        * (speed*sin + std::sqrt(speed*speed*sin*sin + 2.0f*gravity*initial_height));
    return range;
}
This is a simple function to compute the range of a ballistic trajectory launched at 45 degrees from an initial height. The computation is non-trivial and involves a variety of intermediate units. However, all the terms cancel out and the final result should be Meters. If it's not Meters, then we have a bug and should get a compile error.
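Here is the same formula written out, with θ = 45°, v = speed, g = gravity, and h = initial height. It shows the unit cancellation the function relies on:

$$
R = \frac{v\cos\theta}{g}\left(v\sin\theta + \sqrt{v^2\sin^2\theta + 2gh}\right)
$$

$$
\left[\frac{v\cos\theta}{g}\right] = \frac{\mathrm{m/s}}{\mathrm{m/s^2}} = \mathrm{s},
\qquad
[v^2\sin^2\theta] = [2gh] = \frac{\mathrm{m^2}}{\mathrm{s^2}}
\;\Rightarrow\;
\left[v\sin\theta + \sqrt{v^2\sin^2\theta + 2gh}\right] = \frac{\mathrm{m}}{\mathrm{s}}
$$

$$
[R] = \mathrm{s}\cdot\frac{\mathrm{m}}{\mathrm{s}} = \mathrm{m}
$$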
I have a few requirements:

- Units logic must be strictly compile-time, with zero run-time cost in either memory or CPU.
- The storage type should be user-specified: float, int64, etc.
- Values should not be normalized to base units (meters, kilograms, seconds).
- Types must be generated by the compiler, not declared manually.
- The code should be readable by humans: no inscrutable templates or macros.
A Rust Implementation
Way back in 2019 I wrote my own units library in Rust. There are existing crates, like uom, but none of them meet all my requirements.
My Rust implementation is built entirely in the type system. It relies heavily on the popular typenum crate, which provides type-level integers because Rust's const generics are still woefully incomplete.
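Here is a minimal sketch of how far stable Rust gets without typenum. It makes the const-generics complaint concrete. The Quantity type, its three exponent parameters, and the Meters/Seconds aliases are hypothetical and are not part of my library. Adding two values with identical units works. Multiplication would need arithmetic on const parameters inside a type, and that is gated behind the unstable generic_const_exprs feature.

```rust
use std::ops::Add;

/// Hypothetical quantity type: the SI exponents for length, mass, and time
/// are const generic parameters, so they exist only at compile time.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantity<const L: i32, const M: i32, const T: i32> {
    pub amount: f32,
}

pub type Meters = Quantity<1, 0, 0>;
pub type Seconds = Quantity<0, 0, 1>;

// Same units in, same units out: this compiles on stable Rust.
impl<const L: i32, const M: i32, const T: i32> Add for Quantity<L, M, T> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Quantity { amount: self.amount + rhs.amount }
    }
}

// Meters + Seconds is a type mismatch, so it fails to compile.
//
// Multiplication is where stable Rust stops. The impl would need to be
//     impl<...> Mul<Quantity<L2, M2, T2>> for Quantity<L1, M1, T1> {
//         type Output = Quantity<{ L1 + L2 }, { M1 + M2 }, { T1 + T2 }>;
//     }
// and `{ L1 + L2 }` in a type requires the unstable `generic_const_exprs` feature.
```

This gap is why the implementation below encodes every number as a typenum type instead.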
// Declare a quantity struct that stores a value and units
pub struct QuantityT<T, U>
where
    T: Amount,
{
    amount: T,
    _u: PhantomData<U>,
}

// Declare SI types
pub trait Ratio {
    fn numerator() -> i64;
    fn denominator() -> i64;
}

pub trait SIRatios {
    type Length: Ratio;
    type Mass: Ratio;
    type Time: Ratio;
}

pub trait SIExponents {
    type Length: Integer;
    type Mass: Integer;
    type Time: Integer;
}

pub trait SIUnits {
    type Ratios: SIRatios;
    type Exponents: SIExponents;
}

pub struct SIUnitsT<RATIOS, EXPONENTS>
where
    RATIOS: SIRatios,
    EXPONENTS: SIExponents,
{
    _r: PhantomData<RATIOS>,
    _e: PhantomData<EXPONENTS>,
}
These types and traits let us declare types such as:
// via helpers, nice
type Kilometers = Length<Kilo>;

// expanded type, blech
type Kilometers = QuantityT<f32,
    SIUnitsT<
        SIRatiosT<RatioT<Exp<P10, P3>, P1>, RatioT<Z0, Z0>, RatioT<Z0, Z0>>,
        SIExponentsT<P1, Z0, Z0>>>;
At least the end result is pretty delightful.
fn calc_ballistic_range(
    speed: MetersPerSecond<f32>,
    gravity: MetersPerSecond2<f32>,
    initial_height: Meters<f32>,
) -> Meters<f32> {
    let d2r = 0.01745329252;
    let angle: f32 = 45.0 * d2r;
    let cos = angle.cos();
    let sin = angle.sin();

    let range = (speed * cos / gravity)
        * (speed * sin + (speed * speed * sin * sin + 2.0 * gravity * initial_height).sqrt());
    range
}
Unfortunately the implementation is trait hell. For example:
/// Multiplying a pair of SIUnits is complicated
/// Ratios: must be XOR-able
///     meters * meters is fine.
///     meters * kilometers is not. The output type is ambiguous.
/// Exponents: can be anything. will be added together
///     meters * meters = meters squared
/// However exponents which become zero must zero out their corresponding ratio
///     10 kilometers per hour * 2 hours = 20 kilometers.
///     The time Ratio was hours, but needs to be zeroed out.
impl<R0, E0, R1, E1> std::ops::Mul<SIUnitsT<R1, E1>> for SIUnitsT<R0, E0>
where
    R0: SIRatios + Xor<R1>,
    E0: SIExponents + std::ops::Add<E1>,
    R1: SIRatios,
    E1: SIExponents,
    XorOutput<R0, R1>: SIRatios + ReduceWith<AddOutput<E0, E1>> + Default,
    AddOutput<E0, E1>: SIExponents + Default,
    ReduceWithOutput<XorOutput<R0, R1>, AddOutput<E0, E1>>: SIRatios + Default,
    ReduceOutput<SIUnitsT<XorOutput<R0, R1>, AddOutput<E0, E1>>>: Default,
{
    type Output = ReduceOutput<SIUnitsT<XorOutput<R0, R1>, AddOutput<E0, E1>>>;

    fn mul(self, _other: SIUnitsT<R1, E1>) -> Self::Output {
        Default::default()
    }
}
The nice thing is we're using the type system. The `SIUnits` will compile away to nothing.
The gross thing is that type system logic is hideously complex. Worse, it's effectively a different language.
All that type magic doesn't look or feel like Rust code. It's alien. On top of that, the logic is spread across six different impls just for std::ops::Mul.
The error messages for typenum are also genuinely awful. For example:
expected struct `QuantityT<_, SIUnitsT<_, SIExponentsT<typenum::int::PInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>>, _, typenum::int::NInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>>>>>` found struct `QuantityT<f32, SIUnitsT<SIRatiosT<RatioT<typenum::int::PInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>>, typenum::int::PInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>>>, RatioZero, RatioT<typenum::int::PInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>>, typenum::int::PInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>>>>, SIExponentsT<typenum::int::PInt<typenum::uint::UInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>, typenum::bit::B0>>, typenum::int::Z0, typenum::int::NInt<typenum::uint::UInt<typenum::uint::UInt<typenum::uint::UTerm, typenum::bit::B1>, typenum::bit::B0>>>>>`
Here's a Gist of my implementation: Rust units Gist.
It's about 1500 lines. It contains 11 structs, 21 traits, and 74 impls. I never released it because I keep waiting on const generics. Maybe some day.
I can't find a decent C++ units system. If one existed, it would involve similarly grotesque templates full of std::enable_if and assorted wizardry.
Jai - A Breath of Fresh Air
Now let's implement a typesafe units module in Jai.
I want to be clear: I barely know what I'm doing with Jai. I had to ask the community for help half a dozen times. But in the process I learned a bunch of tricks!
My first attempt was to replicate what I did in Rust with nested structs. That didn't work out, so I settled on a flattened structure.
SIQuantity :: struct(
    $DataType: Type,
    $LengthNum: int, $LengthDenom: int, $LengthExp: int,
    $MassNum: int,   $MassDenom: int,   $MassExp: int,
    $TimeNum: int,   $TimeDenom: int,   $TimeExp: int)
{
    amount : DataType;
}
Here we have a flattened structure that defines nine ints to describe the SI units for Length, Mass, and Time: a numerator, denominator, and exponent for each. Other units are omitted because games don't typically care about them.
Helpers make it easy to declare various types.
Meters :: #bake_arguments SIQuantity(
    LengthNum=1, LengthDenom=1, LengthExp=1,
    MassNum=0,   MassDenom=0,   MassExp=0,
    TimeNum=0,   TimeDenom=0,   TimeExp=0);

KilometersPerSecond2 :: #bake_arguments SIQuantity(
    LengthNum=1000, LengthDenom=1, LengthExp=1,
    MassNum=0,      MassDenom=0,   MassExp=0,
    TimeNum=1,      TimeDenom=1,   TimeExp=-2);

dist := Meters(float).{5.0};
accel := KilometersPerSecond2(float).{.0098};
Now we need to implement operators like multiply. Here's where Jai has a feature I've never seen before. I'm sure it exists in some language, but certainly not in C++ or Rust!
It's called #modify, and it lets you write code that computes a procedure's argument types, including the return type, at compile time. Let's do a basic example first.
mul_nums :: (a: $T, b: T) -> $R
#modify {
    if T == int {
        R = s64;
        return true;
    } else if T == float {
        R = float64;
        return true;
    }

    msg := tprint("mul_nums only supports int and float but was given [%]", T);
    return false, msg;
}
{
    result : R = cast(R)a * cast(R)b;
    return result;
}
What's going on here? We have what Jai calls a polymorphic procedure. $T is basically the same as a template or generic type. This function takes two arguments of type T and returns... something!
The return type R is determined programmatically. In this case I've said that int * int = s64 and float * float = float64. If T is any other type, it's rejected with a nice error message.
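For comparison, the closest stable-Rust analog I know of is a trait with an associated output type. This is a sketch, not part of the units library. The Widen trait is hypothetical, and unlike the Jai example it widens i32 to i64, because Jai's int is already 64-bit. The type mapping lives in impl blocks rather than in an if statement.

```rust
use std::ops::Mul;

/// Hypothetical trait mapping an input numeric type to a wider output type,
/// roughly the job the #modify block does with `if T == int { R = s64; }`.
pub trait Widen: Copy {
    type Out: Mul<Output = Self::Out>;
    fn widen(self) -> Self::Out;
}

impl Widen for i32 {
    type Out = i64;
    fn widen(self) -> i64 {
        self as i64
    }
}

impl Widen for f32 {
    type Out = f64;
    fn widen(self) -> f64 {
        self as f64
    }
}

/// Multiplies in the wider type, so i32 * i32 cannot overflow.
/// Calling this with a type that has no Widen impl (e.g. u8) is a compile error.
pub fn mul_nums<T: Widen>(a: T, b: T) -> T::Out {
    a.widen() * b.widen()
}
```

In Rust, an unsupported type is rejected by the trait bound ("the trait bound `u8: Widen` is not satisfied"). Jai lets you write a custom message instead.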
Let's try something slightly more complicated.
// Super basic struct that stores a fixed-size array
SimpleArray :: struct($DataType: Type, $Len: int) {
    data : [Len]DataType;
}

// Function to combine two SimpleArrays of different length into something new
combine_arrays :: (a: $T/SimpleArray, b: $U/SimpleArray) -> $R
#modify {
    // Get Type_Info (reflection)
    typeinfo_a := cast(*Type_Info_Struct)(T);
    typeinfo_b := cast(*Type_Info_Struct)(U);

    // Make sure storage types are same
    datatype_a := find_param(typeinfo_a, "DataType", Type);
    datatype_b := find_param(typeinfo_b, "DataType", Type);
    if datatype_a != datatype_b {
        msg := tprint("\nFATAL ERROR: Arrays must have same DataType.\n A: [%]\n B: [%]", datatype_a, datatype_b);
        return false, msg;
    }

    // Get lengths
    len_a := find_param(typeinfo_a, "Len", int);
    len_b := find_param(typeinfo_b, "Len", int);

    // Calculate new length
    new_len := len_a*2 + factorial(len_b) + 7;

    // Set output type
    R = #dynamic_specialize SimpleArray(DataType=datatype_a, Len=new_len);
    return true;
}
{
    result : R;
    // copy arrays into result or something
    return result;
}

main :: () {
    a : SimpleArray(int, 3);
    b : SimpleArray(int, 4);
    c := combine_arrays(a, b);

    // prints: SimpleArray(DataType=s64, Len=37)
    print("%\n", type_of(c));
}
Here I've defined SimpleArray, which stores a fixed-size array. The function combine_arrays takes two SimpleArrays of different fixed lengths and returns a new array with Len = len_a*2 + factorial(len_b) + 7. For lengths 3 and 4 that's 3*2 + 24 + 7 = 37, which matches the printed type.
This is a very silly operation, of course. The point here is that #modify allows for totally arbitrary logic. We can even call functions like factorial!
Now let's use #modify to determine the output units of operator *.
operator * :: (a: $T/SIQuantity, b: $U/SIQuantity) -> $RET
#modify {
    // Extract polymorphic constants from TypeInfo
    ti1 := cast(*Type_Info_Struct)(T);
    datatype_a := find_param(ti1, "DataType", Type);
    ln1 := find_param(ti1, "LengthNum", int);
    ld1 := find_param(ti1, "LengthDenom", int);
    le1 := find_param(ti1, "LengthExp", int);
    mn1 := find_param(ti1, "MassNum", int);
    md1 := find_param(ti1, "MassDenom", int);
    me1 := find_param(ti1, "MassExp", int);
    tn1 := find_param(ti1, "TimeNum", int);
    td1 := find_param(ti1, "TimeDenom", int);
    te1 := find_param(ti1, "TimeExp", int);

    ti2 := cast(*Type_Info_Struct)(U);
    datatype_b := find_param(ti2, "DataType", Type);
    ln2 := find_param(ti2, "LengthNum", int);
    ld2 := find_param(ti2, "LengthDenom", int);
    le2 := find_param(ti2, "LengthExp", int);
    mn2 := find_param(ti2, "MassNum", int);
    md2 := find_param(ti2, "MassDenom", int);
    me2 := find_param(ti2, "MassExp", int);
    tn2 := find_param(ti2, "TimeNum", int);
    td2 := find_param(ti2, "TimeDenom", int);
    te2 := find_param(ti2, "TimeExp", int);

    // Ensure we aren't mixing storage types
    if datatype_a != datatype_b {
        msg := tprint("\nFATAL ERROR: Quantities must have same DataType.\n T: [%]\n U: [%]", T, U);
        return false, msg;
    }

    // Ratios must be same OR one must be zero
    ln_ok := ln1 == ln2 || ln1 == 0 || ln2 == 0;
    ld_ok := ld1 == ld2 || ld1 == 0 || ld2 == 0;
    mn_ok := mn1 == mn2 || mn1 == 0 || mn2 == 0;
    md_ok := md1 == md2 || md1 == 0 || md2 == 0;
    tn_ok := tn1 == tn2 || tn1 == 0 || tn2 == 0;
    td_ok := td1 == td2 || td1 == 0 || td2 == 0;
    if !ln_ok || !ld_ok || !mn_ok || !md_ok || !tn_ok || !td_ok {
        msg := tprint("\nFATAL ERROR: Incompatible ratios.\n [%]\n [%]\n", T, U);
        return false, msg;
    }

    // Compute new exponents
    // If exponent is 0 then reduce ratio to 0
    le := le1 + le2;
    ln := ifx le == 0 then 0 else ln1|ln2;
    ld := ifx le == 0 then 0 else ld1|ld2;

    me := me1 + me2;
    mn := ifx me == 0 then 0 else mn1|mn2;
    md := ifx me == 0 then 0 else md1|md2;

    te := te1 + te2;
    tn := ifx te == 0 then 0 else tn1|tn2;
    td := ifx te == 0 then 0 else td1|td2;

    // Compute final return type
    RET = #dynamic_specialize SIQuantity(
        DataType=datatype_a,
        LengthNum=ln, LengthDenom=ld, LengthExp=le,
        MassNum=mn,   MassDenom=md,   MassExp=me,
        TimeNum=tn,   TimeDenom=td,   TimeExp=te);
    return true;
}
{
    result : RET;
    result.amount = a.amount * b.amount;
    return result;
}
Let's break this down.
First, we extract the constant arguments. I wish these could be accessed directly; hopefully a future update to the Jai compiler will allow this. The language is in beta, after all!
Second, we make sure the data types and ratios are compatible. Meters * Kilometers isn't allowed; I chose not to auto-convert units. A nice error message is given if invalid types are provided.
Third, we compute the new exponents by adding the input exponents: Meters * Meters = Meters². We also reduce ratios: if we compute KilometersPerHour * Hours = Kilometers, we want the numerator and denominator for Time to be zero.
Finally, we declare the final RET type.
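The #modify block is ordinary code, so its rules can be restated outside Jai. The sketch below expresses the same ratio check and exponent reduction as Rust const fns, which stable Rust can evaluate at compile time. It is a restatement for illustration, not a drop-in. Dim, Units, NONE, mul_dim, mul_units, KILOMETERS_PER_HOUR, and HOURS are all hypothetical names. Note the gap: the result is a value, not a type. Turning it into the output type of Mul would need unstable const-generics features.

```rust
/// One SI dimension: ratio numerator/denominator plus exponent.
/// A 0/0 ratio means "dimension absent", as in the Jai SIQuantity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Dim {
    pub num: i64,
    pub denom: i64,
    pub exp: i64,
}

/// Length, mass, and time: the same nine ints as SIQuantity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Units {
    pub length: Dim,
    pub mass: Dim,
    pub time: Dim,
}

/// Marker for an absent dimension.
pub const NONE: Dim = Dim { num: 0, denom: 0, exp: 0 };

/// Ratios must be equal, or one side must be zero (absent).
const fn ratio_ok(a: i64, b: i64) -> bool {
    a == b || a == 0 || b == 0
}

/// Multiply one dimension; None means incompatible ratios.
const fn mul_dim(a: Dim, b: Dim) -> Option<Dim> {
    if !ratio_ok(a.num, b.num) || !ratio_ok(a.denom, b.denom) {
        return None; // e.g. Meters * Kilometers: ambiguous output ratio
    }
    let exp = a.exp + b.exp;
    if exp == 0 {
        // The dimension cancelled out, so its ratio is reduced to 0/0 too.
        return Some(NONE);
    }
    // Same trick as the Jai code: after the check, the values are either
    // equal or one is zero, so a | b yields the shared or nonzero value.
    Some(Dim { num: a.num | b.num, denom: a.denom | b.denom, exp })
}

pub const fn mul_units(a: Units, b: Units) -> Option<Units> {
    match (mul_dim(a.length, b.length), mul_dim(a.mass, b.mass), mul_dim(a.time, b.time)) {
        (Some(length), Some(mass), Some(time)) => Some(Units { length, mass, time }),
        _ => None,
    }
}

/// Hypothetical unit constants for the KilometersPerHour * Hours example.
pub const KILOMETERS_PER_HOUR: Units = Units {
    length: Dim { num: 1000, denom: 1, exp: 1 },
    mass: NONE,
    time: Dim { num: 3600, denom: 1, exp: -1 },
};
pub const HOURS: Units = Units {
    length: NONE,
    mass: NONE,
    time: Dim { num: 3600, denom: 1, exp: 1 },
};

// Evaluated by the compiler: km/h * h must leave no time dimension behind.
const _: () = match mul_units(KILOMETERS_PER_HOUR, HOURS) {
    Some(u) => assert!(u.time.exp == 0 && u.time.num == 0 && u.length.num == 1000),
    None => panic!("incompatible units"),
};
```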
All of this happens at compile time. There's zero runtime overhead: no extra bytes of data are stored, and there's no runtime logic to check or convert units.
Here's what our original calc_ballistic_range function looks like in Jai:
calc_ballistic_range :: (
    speed: MetersPerSecond(float),
    gravity: MetersPerSecond2(float),
    initial_height: Meters(float)
) -> Meters(float) {
    d2r : float : cast(float)0.01745329252;
    angle : float : 45.0 * d2r;
    ang_cos := cos(angle);
    ang_sin := sin(angle);

    range : Meters(float) = (speed * ang_cos / gravity)
        * (speed * ang_sin + sqrt(speed * speed * ang_sin * ang_sin + 2.0 * gravity * initial_height));
    return range;
}
However, we can go one step further and be generic across units.
calc_ballistic_range2 :: (
    speed: $T/SIQuantity,
    gravity: $U/SIQuantity,
    initial_height: $V/SIQuantity
) -> $R
#modify {
    // Assume return type uses the numerator/denominator of speed
    // If gravity/height have different num/denom there will be a compile error
    ti := cast(*Type_Info_Struct)(T);
    dt := find_param(ti, "DataType", Type);
    ln := find_param(ti, "LengthNum", int);
    ld := find_param(ti, "LengthDenom", int);

    R = #dynamic_specialize SIQuantity(
        DataType=dt,
        LengthNum=ln, LengthDenom=ld, LengthExp=1,
        MassNum=0,    MassDenom=0,    MassExp=0,
        TimeNum=0,    TimeDenom=0,    TimeExp=0);
    return true;
}
{
    // Make sure units are correct
    #assert T.LengthExp == 1 && T.MassExp == 0 && T.TimeExp == -1;
    #assert U.LengthExp == 1 && U.MassExp == 0 && U.TimeExp == -2;
    #assert V.LengthExp == 1 && V.MassExp == 0 && V.TimeExp == 0;

    d2r : speed.DataType : xx 0.01745329252;
    angle :: 45.0 * d2r;
    ang_cos := cos(angle);
    ang_sin := sin(angle);

    range := (speed * ang_cos / gravity)
        * (speed * ang_sin + sqrt(speed * speed * ang_sin * ang_sin + 2.0 * gravity * initial_height));
    return range;
}

main :: () {
    s := MetersPerSecond(float).{5};
    g := MetersPerSecond2(float).{9.8};
    h := Meters(float).{1};
    range := calc_ballistic_range2(s, g, h);
    // 3.31951261

    s2 := KilometersPerSecond(float64).{.005};
    g2 := KilometersPerSecond2(float64).{.0098};
    h2 := Kilometers(float64).{0.001};
    range2 := calc_ballistic_range2(s2, g2, h2);
    // 0.0033195128187733592
}
This is pretty neat! If incompatible types are provided, a compile-time #assert or operator * error message will be given.
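As a sanity check on the two results quoted in the comments, here is the same formula in plain Rust f64 arithmetic with no unit types. The ballistic_range function is hypothetical and uses to_radians instead of the hand-written d2r constant. Scaling speed, gravity, and height all by 1/1000 (meters to kilometers) mathematically scales the range by exactly 1/1000. The small difference from 3.31951261 in the last digits is presumably because the first Jai call uses 32-bit float.

```rust
/// Unitless ballistic range for a 45-degree launch from an initial height.
/// Only used to check the numbers printed by the Jai example.
pub fn ballistic_range(speed: f64, gravity: f64, initial_height: f64) -> f64 {
    let angle = 45.0_f64.to_radians();
    let (sin, cos) = angle.sin_cos();
    (speed * cos / gravity)
        * (speed * sin + (speed * speed * sin * sin + 2.0 * gravity * initial_height).sqrt())
}
```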
I admit this level of generality is of questionable value. But I tried to do this in Rust and it was utterly intractable due to the way generics work. The fact that it works so easily in Jai sparked my joy.
Full source is here: GitHub Gist
It's about 350 lines total. This is incomplete and doesn't have all the features you'd want. But it's a solid proof of life, imho.
Feature Requests
If I had two feature requests for the Jai language team, they would be these:
One, allow Type constants to be accessed in #modify without having to go through reflection via Type_Info. The Type of T isn't constant, but it is known. It's an ugly hoop to jump through.
Two, I wish polymorphic types could be constrained to #bake_arguments types. It'd make error checking cleaner and the API more obvious.
More Jai Super Powers
I'm going to assume most readers here aren't familiar with Jai. I suggest my previous post, Learning Jai via Advent of Code, for a decent intro.
Since this post is about Jai compile-time magic I want to quickly call out a few more Jai superpowers.
#run
The #run directive causes code to be executed at compile time.
// :: means compile-time constant
CONSTANT_X :: #run factorial(7);
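For comparison, the nearest stable-Rust counterpart to #run is a const fn evaluated in a const item. This is only a sketch. A const fn supports a restricted subset of Rust (for example, no heap allocation or I/O), whereas #run executes arbitrary Jai code. The factorial function here is a stand-in, since its definition isn't shown above.

```rust
/// Hypothetical factorial; `const fn` lets the compiler evaluate it.
pub const fn factorial(n: u64) -> u64 {
    let mut acc = 1;
    let mut i = 2;
    while i <= n {
        acc *= i;
        i += 1;
    }
    acc
}

/// Computed during compilation, like `CONSTANT_X :: #run factorial(7);`.
pub const CONSTANT_X: u64 = factorial(7);
```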
#insert
The #insert directive allows code to be dynamically inserted. It can insert both strings and typesafe Code, and it can be combined with #run.
get_code_string :: () -> string {
    return "x *= 3;";
}

get_code :: () -> Code #expand {
    return #code { `x *= 3; };
}

main :: () {
    x : int = 3;
    #insert #run get_code_string();
    #insert #run get_code();
    assert(x == 27);
}
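A Rust declarative macro can splice code into the caller in a similar way to #insert #run get_code(). However, macro_rules! works on token trees, so it cannot build code from a string computed at compile time the way #insert #run get_code_string() does. That would take a procedural macro. The triple! macro and demo function below are hypothetical.

```rust
/// Hypothetical macro that splices `x *= 3;` at the call site,
/// roughly what `#insert #run get_code();` does in the Jai example.
macro_rules! triple {
    ($x:ident) => {
        $x *= 3;
    };
}

pub fn demo() -> i32 {
    let mut x = 3;
    triple!(x); // expands to: x *= 3;
    triple!(x);
    x
}
```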
Abstract Syntax Tree
Compile-time code can also operate on the abstract syntax tree. Code can be converted to AST nodes, modified in place, and converted back:
// function we're going to #run at compile-time
comptime_modify :: (code: Code) -> Code {
    // convert Code to AST nodes
    root, expressions := compiler_get_nodes(code);

    // walk AST
    // multiply number literals by their factorial
    // 3 -> 3*factorial(3) -> 3*6 -> 18
    for expr : expressions {
        if expr.kind == .LITERAL {
            literal := cast(*Code_Literal) expr;
            if literal.value_type == .NUMBER {
                // Compute factorial
                fac := factorial(literal._s64);

                // Modify node in place
                literal._s64 *= fac;
            }
        }
    }

    // convert modified nodes back to Code
    modified : Code = compiler_get_code(root);
    return modified;
}
Reflection
Jai has robust run-time reflection. Since compile-time code is just plain old Jai code, you can also use reflection at compile time.
dump_struct :: (v: *$T) {
    ti := type_info(T);
    for member : ti.members {
        print("Name: [%] Type: [%] Offset: [%] Size: [%]\n",
            member.name, member.type.type, member.offset_in_bytes, member.type.runtime_size);
    }
}

main :: () {
    s : Stuff; // Stuff: a user-defined struct whose definition isn't shown
    dump_struct(*s);
}

// Output:
// Name: [num] Type: [INTEGER] Offset: [0] Size: [8]
// Name: [nums] Type: [ARRAY] Offset: [8] Size: [20]
// Name: [v] Type: [STRUCT] Offset: [28] Size: [12]
// Name: [s] Type: [STRING] Offset: [40] Size: [16]
It's easy to imagine how this can be used to do cool things like generate serialization code.
This was a super, super brief look at a set of compile-time capabilities. I suggest reading my previous post if you'd like more details on them and the rest of Jai.
Closing Thoughts
I've been a professional programmer for 17 years now. The vast majority of my career has been in game dev writing mostly C++ with a dash of C# and an assortment of scripting languages. I adore Rust and have made heroic pushes to add support for it in my day job.
Jai's compile-time capabilities are genuinely next-level, imho. They put C++ and Rust to absolute shame. C++ template metaprogramming and Rust macros are, ahem, not good in comparison. I say this not because I hate those languages! But because I use them daily and wish they were better.
Jai was created to solve "hard problems" in programming. Articulating what exactly constitutes a hard problem is, well, hard. I think Jai's assortment of compile-time capabilities are a "big feature for hard problems" and I've regularly found myself wishing I had them.
Jai strives to keep things in Jai. All the compile-time code is just regular Jai code. It's not crazy type-system algebra that requires a PhD to understand; it's just simple code. Similarly, the build system is just regular Jai. It's not a completely separate and insane hierarchy of scripting languages like Make and CMake.
Is a typesafe units library a thing worth having? Well, NASA could have saved a $327 million mission if they had one! But more importantly I learned some cool new tricks. If you've read this far hopefully you've learned a thing or two as well.
Thanks for reading.
Bonus Section
find_param
You may be wondering about the function find_param. Here's what it looks like:
find_param :: (ti: *Type_Info_Struct, name: string, $T: Type) -> T {
    placeholder: T;
    for param: ti.specified_parameters {
        if param.name == name {
            if param.type != type_info(T) {
                print("\nFATAL ERROR: find_param(%, %) expected type [%] but found [%]\n", ti.name, name, T, param.type.type);
                assert(false);
                return placeholder;
            }

            value := (cast(*T) *ti.constant_storage[param.offset_into_constant_storage]).*;
            return value;
        }
    }

    print("FATAL ERROR: Failed to find parameter [%]\n", name);
    assert(false);
    return placeholder;
}
Rust generic ballistic trajectory
Here is the hideous, god awful Rust code that tries to solve a ballistic trajectory generic across units.
fn calc_ballistic_range2<T, V, A, H>(
    speed: V,
    gravity: A,
    initial_height: H,
) -> impl Length<AmountType = T>
where
    T: Amount + num_traits::Num + num_traits::real::Real + Copy + PartialOrd + From<u8> + From<f32>,
    V: Copy + Velocity<AmountType = T> + Mul<Pure<T>> + Mul<V>,
    A: Copy + Acceleration<AmountType = T>,
    H: Copy + Length<AmountType = T>,
    Pure<T>: Mul<A>,
    MulOutput<V, Pure<T>>: Div<A>,
    MulOutput<V, V>: Mul<Pure<T>>,
    MulOutput3<V, V, Pure<T>>: Mul<Pure<T>>,
    MulOutput<Pure<T>, A>: Mul<H>,
    MulOutput4<V, V, Pure<T>, Pure<T>>: Add<MulOutput3<Pure<T>, A, H>>,
    AddOutput<MulOutput4<V, V, Pure<T>, Pure<T>>, MulOutput3<Pure<T>, A, H>>: Sqrt,
    SqrtOutput<AddOutput<MulOutput4<V, V, Pure<T>, Pure<T>>, MulOutput3<Pure<T>, A, H>>>:
        Add<MulOutput<V, Pure<T>>>,
    DivOutput<MulOutput<V, Pure<T>>, A>: Mul<
        AddOutput<
            SqrtOutput<
                AddOutput<MulOutput4<V, V, Pure<T>, Pure<T>>, MulOutput3<Pure<T>, A, H>>,
            >,
            MulOutput<V, Pure<T>>,
        >,
    >,
    MulOutput<
        DivOutput<MulOutput<V, Pure<T>>, A>,
        AddOutput<
            SqrtOutput<
                AddOutput<MulOutput4<V, V, Pure<T>, Pure<T>>, MulOutput3<Pure<T>, A, H>>,
            >,
            MulOutput<V, Pure<T>>,
        >,
    >: Length<AmountType = T>,
{
    let d2r = 0.01745329252;
    let angle: T = (45.0 * d2r).into();
    let cos = Pure::<T>::new(angle.cos());
    let sin = Pure::<T>::new(angle.sin());

    let range = (speed * cos / gravity)
        * ((speed * speed * sin * sin + Pure::<T>::new(2.into()) * gravity * initial_height).sqrt()
            + speed * sin);
    range
}
Goodness gracious. This is totally intractable and doesn't scale.
This is a case where I really wish Rust simply supported C++ style templates. I'm told that supporting both templates and generics/constraints is problematic. But yikes.