Skip to content
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e03f422
Draft to figure out better StructArray support
ptiede Feb 25, 2026
7c425b8
Simplify and generalize structarray type conversion
wsmoses Feb 25, 2026
bb28db8
Merge remote-tracking branch 'upstream/satc' into ptiede-structarrays
ptiede Feb 25, 2026
f3080b6
Start adding StaticArray support
ptiede Feb 25, 2026
5832eb4
Add StaticArray support and tweak elem_apply_while_loop to select cor…
ptiede Feb 26, 2026
228187c
Revert tracing.jl
ptiede Feb 26, 2026
a3c4ac3
Remove info debug
ptiede Feb 26, 2026
71cec58
Merge branch 'main' into ptiede-structarrays
ptiede Feb 26, 2026
19205c8
Remove get_ith
ptiede Feb 26, 2026
78ff71c
Add _copyto!
ptiede Mar 1, 2026
47f7918
Merge branch 'main' into ptiede-structarrays
ptiede Mar 2, 2026
b1b3fc3
format
ptiede Mar 2, 2026
cc65c11
Fix broken test and add new tests
ptiede Mar 2, 2026
611d095
format
ptiede Mar 2, 2026
3ac9140
add StaticArrays
ptiede Mar 2, 2026
399258e
Add LinearAlgebra
ptiede Mar 2, 2026
b29cc9a
Remove unused function
ptiede Mar 20, 2026
aa68406
Merge branch 'main' into ptiede-structarrays
ptiede Mar 20, 2026
9dce679
Reuse the known destination for while loop if possible
ptiede Mar 22, 2026
80f1031
Merge branch 'main' into ptiede-structarrays
ptiede Mar 23, 2026
a2ea5cf
Update ext/ReactantStructArraysExt.jl
ptiede Mar 23, 2026
957132b
Proposed improved support for SArrays
ptiede Mar 23, 2026
9e6bb90
fix dumb mistake
ptiede Mar 23, 2026
baf5177
Add additional changes for StaticArrays
ptiede Mar 24, 2026
358507b
Apply suggestions from code review
ptiede Mar 24, 2026
f754c6c
Cleanup
ptiede Mar 24, 2026
1456ac1
Update
ptiede Mar 25, 2026
9e95692
Update
ptiede Mar 25, 2026
56bcab6
Format
ptiede Mar 25, 2026
f28be27
Fix for code review
ptiede Mar 25, 2026
52e7b15
Add comments
ptiede Mar 25, 2026
79516cb
So dumb
ptiede Mar 25, 2026
1ece166
Merge branch 'main' into ptiede-structarrays
ptiede Mar 31, 2026
2212e49
Correct comment in overloaded_mul function
ptiede Mar 31, 2026
b3d3661
Update to remove anonymous functions
ptiede Apr 1, 2026
3ca3dc2
Update
ptiede Apr 1, 2026
3dc5d64
Add a complex test
ptiede Apr 1, 2026
6253791
Update
ptiede Apr 1, 2026
446d605
Merge branch 'main' into ptiede-structarrays
avik-pal Apr 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Expand Down Expand Up @@ -91,6 +92,7 @@ ReactantOffsetArraysExt = "OffsetArrays"
ReactantOneHotArraysExt = "OneHotArrays"
ReactantPythonCallExt = "PythonCall"
ReactantRandom123Ext = "Random123"
ReactantStaticArraysExt = "StaticArrays"
ReactantSparseArraysExt = "SparseArrays"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantStatisticsExt = "Statistics"
Expand Down Expand Up @@ -149,6 +151,7 @@ Setfield = "1.1.2"
Sockets = "1.10"
SparseArrays = "1.10"
SpecialFunctions = "2.4"
StaticArrays = "1"
StableRNGs = "1.0.4"
Statistics = "1.10"
StructArrays = "0.7.2"
Expand Down
41 changes: 41 additions & 0 deletions ext/ReactantStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
module ReactantStaticArraysExt

using Reactant
import Reactant.TracedRArrayOverrides: overloaded_map, overloaded_mapreduce
import Reactant.TracedLinearAlgebra: overloaded_mul

using StaticArrays: SArray, StaticArray

const SAReact{Sz,T} = StaticArray{Sz,T} where {Sz<:Tuple,T<:Reactant.TracedRNumber}

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(FA::Type{SArray{S,T,N,L}}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type),
@nospecialize(ndevices),
@nospecialize(runtime)
) where {S,T,N,L}
T_traced = Reactant.traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime)
return SArray{S,T_traced,N,L}
end

function Reactant.materialize_traced_array(x::SAReact)
return x
end

# We don't want to overload map on StaticArrays since it is likely better to just unroll things
overloaded_map(f, a::SAReact, rest::SAReact...) = f.(a, rest...)
overloaded_mapreduce(f, op, a::SAReact; kwargs...) = mapreduce(f, op, a, kwargs...)

function overloaded_mul(A::SAReact, B::SAReact, alpha::Number=true, beta::Number=false)
# beta is not supported since it is zero by default in Reactant
# (similar is zero'd automatically for TracedRArrays)
C = A * B
if !(alpha isa Reactant.TracedRNumber) && isone(alpha)
return C
end
return C .* alpha
end

end
112 changes: 75 additions & 37 deletions ext/ReactantStructArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ function Base.copy(
end

function Reactant.broadcast_to_size(arg::StructArray{T}, rsize) where {T}
new = [broadcast_to_size(c, rsize) for c in components(arg)]
return StructArray{T}(NamedTuple(Base.propertynames(arg) .=> new))
new = Tuple((broadcast_to_size(c, rsize) for c in components(arg)))
return StructArray{T}(new)
end

function Base.copyto!(
Expand All @@ -53,12 +53,48 @@ function Base.copyto!(
bc = Broadcast.preprocess(dest, bc)

args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...)
copyto!(dest, res)

return dest
end

function Reactant.TracedRArrayOverrides._copyto!(
dest::StructArray, bc::Base.Broadcast.Broadcasted{<:AbstractReactantArrayStyle}
)
return copyto!(dest, bc)
end

function Base.copyto!(
dest::Reactant.TracedRArray, bc::Broadcasted{StructArrayStyle{S,N}}
) where {S<:AbstractReactantArrayStyle,N}
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

bc = Broadcast.preprocess(dest, bc)
args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...)
return copyto!(dest, res)
end

function alloc_sarr(bc, T)
# Short circuit for Complex since in Reactant they are just a regular number
T <: Complex && return similar(bc, T)
if StructArrays.isnonemptystructtype(T)
return StructArrays.buildfromschema(x -> alloc_sarr(bc, x), T)
else
return similar(bc, T)
end
end

function Base.similar(
bc::Broadcasted{StructArrayStyle{S,N}}, ::Type{ElType}
) where {S<:AbstractReactantArrayStyle,N,ElType}
bc′ = convert(Broadcasted{S}, bc)
# It is possible that we have multiple broadcasted arguments
return alloc_sarr(bc′, ElType)
end

Base.@propagate_inbounds function StructArrays._getindex(
x::StructArray{T}, I::Vararg{TracedRNumber{<:Integer}}
) where {T}
Expand All @@ -67,14 +103,42 @@ Base.@propagate_inbounds function StructArrays._getindex(
return createinstance(T, get_ith(cols, I...)...)
end

setstruct(col, val, I) = @inbounds Reactant.@allowscalar col[I] = val
struct SetStruct{T}
I::T
end
(s::SetStruct)(col, val) = setstruct(col, val, s.I)
(s::SetStruct)(vals) = s(vals...)

Base.@propagate_inbounds function Base.setindex!(
s::StructArray{T,<:Any,<:Any,Int}, vals, I::TracedRNumber{TI}
) where {T,TI<:Integer}
valsT = maybe_convert_elt(T, vals)
foreachfield((col, val) -> (@inbounds col[I] = val), s, valsT)
setter = SetStruct(I)
foreachfield(setter, s, valsT)
return s
end

const MRarr = Union{Reactant.AnyTracedRArray,Reactant.RArray}
getstruct(col, n, I) = @inbounds Reactant.@allowscalar col[n][I...]
struct GetStruct{C,Idx}
cols::C
I::Idx
end
(g::GetStruct)(n) = getstruct(g.cols, n, g.I...)

function StructArrays.get_ith(cols::NamedTuple{N,<:NTuple{K,<:MRarr}}, I...) where {N,K}
getter = GetStruct(cols, I)
ith = ntuple(getter, Val(K))
return ith
end

function StructArrays.get_ith(cols::NTuple{K,<:MRarr}, I...) where {K}
getter = GetStruct(cols, I)
ith = ntuple(getter, Val(K))
return ith
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(prev::Type{StructArray{ET,N,C,I}}),
seen,
Expand All @@ -90,41 +154,15 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
return StructArray{ET_traced,N,C_traced,index_type(fieldtypes(C_traced))}
end

function Reactant.make_tracer(
seen,
@nospecialize(prev::StructArray{NT,N}),
@nospecialize(path),
mode;
track_numbers=false,
sharding=Reactant.Sharding.Sharding.NoSharding(),
runtime=nothing,
kwargs...,
) where {NT<:NamedTuple,N}
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
components = getfield(prev, :components)
if mode == TracedToTypes
push!(path, typeof(prev))
for c in components
make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...)
end
return nothing
end
traced_components = make_tracer(
seen,
components,
append_path(path, 1),
mode;
track_numbers,
sharding,
runtime,
kwargs...,
)
T_traced = traced_type(typeof(prev), Val(mode), track_numbers, sharding, runtime)
return StructArray{first(T_traced.parameters)}(traced_components)
end

@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field)
return Base.getfield(obj, field)
end

# This is to tell StructArrays to leave these array types alone.
StructArrays.staticschema(::Type{<:Reactant.AnyTracedRArray}) = NamedTuple{()}
StructArrays.staticschema(::Type{<:Reactant.RArray}) = NamedTuple{()}
StructArrays.staticschema(::Type{<:Reactant.RNumber}) = NamedTuple{()}
# # Even though RNumbers we have fields we want them to be threated as empty structs
StructArrays.isnonemptystructtype(::Type{<:Reactant.RNumber}) = false
StructArrays.isnonemptystructtype(::Type{<:Reactant.TracedRArray}) = false
end
1 change: 0 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ end
function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

bc = Broadcast.preprocess(dest, bc)

args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
Expand Down
34 changes: 28 additions & 6 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,13 @@ function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_re
return idx_ref[] < L_ref[]
end

struct RefFillVector{T}
data::T
end

Base.getindex(rv::RefFillVector, i) = rv.data[]
Base.broadcastable(x::RefFillVector) = x

function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
args = args_ref[]
fn = fn_ref[]
Expand All @@ -1129,14 +1136,24 @@ function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) wh
return nothing
end

scalar_arg(arg) = arg isa Base.RefValue || !(arg isa AbstractArray)

flattenarg(arg) = ReactantCore.materialize_traced_array(vec(arg))
flattenarg(arg::Ref) = RefFillVector(arg)

function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
@assert allequal(size.(args)) "All args must have the same size"
L = length(first(args))
non_ref_args = [arg for arg in args if !scalar_arg(arg)]
if !isempty(non_ref_args)
@assert allequal(size.(non_ref_args)) "All args must have the same size"
end
out_size = isempty(non_ref_args) ? () : size(first(non_ref_args))
L = isempty(non_ref_args) ? 1 : length(first(non_ref_args))
# flattening the tensors makes the auto-batching pass work nicer
flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args]
flat_args = [flattenarg(arg) for arg in args]

# This wont be a mutating function so we can safely execute it once
res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...))
scalar_seed_args = [@allowscalar(arg[1]) for arg in flat_args]
res_tmp = @allowscalar(f(scalar_seed_args...))

# TODO: perhaps instead of this logic, we should have
# `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)`
Expand All @@ -1146,7 +1163,12 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
else
typeof(res_tmp)
end
result = similar(first(flat_args), T_res, L)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be reviewed closely. I had to change this line because now flat_args may have differing array types, e.g., one could be a StructArray. Before we decided the output entirely based on the first element, which could lead to errors, e.g., 2 .* sa would try to write a StructArray to a TracedRArray.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this is clever!


# Before we selected the output container based on the first argument
# That doesn't work for cases when StructArrays are involved
# Since this is essentially a broadcast I'm reusing this machinery
bc = Base.Broadcast.Broadcasted(f, Tuple(args))
result = similar(bc, T_res)

ind_var = Ref(0)
f_ref = Ref(f)
Expand All @@ -1160,7 +1182,7 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
(ind_var, f_ref, result_ref, args_ref, limit_ref),
)

return ReactantCore.materialize_traced_array(reshape(result, size(first(args))))
return ReactantCore.materialize_traced_array(reshape(result, out_size))
end

function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Expand Down
27 changes: 26 additions & 1 deletion test/integration/structarrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using StructArrays, Reactant, Test
using StructArrays, StaticArrays, Reactant, LinearAlgebra, Test

@testset "StructArray to_rarray and make_tracer" begin
x = StructArray(;
Expand Down Expand Up @@ -69,3 +69,28 @@ end
@test component_ra ≈ component
end
end

@testset "structarray with static array broadcasting" begin
trel(x) = tr.(x)
s = StructArray{SMatrix{2,2,Float64,4}}((
fill(1.0, 4), fill(2.0, 4), fill(3.0, 4), fill(4.0, 4)
))
sr = Reactant.to_rarray(s)
out = @jit(trel(sr))
@test out ≈ trel(s)
@test out isa ConcreteRArray
@test @jit(sum(sr)) ≈ sum(s)
end

@testset "structarray with complex numbers" begin
s = randn(64)

elcom(x) = complex(x, x)
sr = Reactant.to_rarray(s)
out = @jit(elcom.(sr))
@test out ≈ elcom.(s)
@test out isa ConcreteRArray
@test @jit(sum(sr)) ≈ sum(s)
end


Comment thread
avik-pal marked this conversation as resolved.
Loading