JAX-inspired automatic differentiation and JIT compilation library for OCaml
Rune brings JAX-like capabilities to OCaml, enabling high-performance numerical computation with automatic differentiation, multi-device support (CPU, CUDA, Metal), and JIT compilation.
- N-dimensional tensor operations (arithmetic, linear algebra, etc.)
- Automatic differentiation: grad, grads, value_and_grad, value_and_grads
- Functional API for pure computations
- Multi-device backends: CPU, CUDA, Metal
- Random tensor initialization: rand
- JIT compilation to accelerate operations on GPU backends
- Seamless interop with Nx for data loading and visualization
open Rune
(* Define a simple function: sum of squares *)
let f x = sum (mul x x)
(* Create input tensor *)
let x = create Float32 [|3;3|] (Array.init 9 float_of_int)
(* Compute gradient of f at x *)
let grad_x = grad f x
(* Print gradient *)
let () = print grad_x
See the examples/ directory for:
- 01-mlp: training a simple MLP with value_and_grads
- xx-higher-derivative: computing higher-order derivatives
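As a sanity check on the sum-of-squares example above, the gradient can be worked out by hand. This is ordinary calculus rather than anything specific to Rune, and it assumes only that sum adds up every element and mul multiplies elementwise:

```latex
f(x) = \sum_{i,j} x_{ij}^{2}
\qquad\Longrightarrow\qquad
\frac{\partial f}{\partial x_{ij}} = 2\,x_{ij}
```

The input there holds the values 0 through 8, so the gradient should hold 0, 2, 4, ..., 16, in the same layout as x.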
See the Raven monorepo README for contribution guidelines.
ISC License. See LICENSE for details.