diff --git a/Project.toml b/Project.toml index 3108a1095f..4721c93a46 100644 --- a/Project.toml +++ b/Project.toml @@ -64,6 +64,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random123 = "74087812-796a-5b5d-8853-05524746bad3" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" @@ -91,6 +92,7 @@ ReactantOffsetArraysExt = "OffsetArrays" ReactantOneHotArraysExt = "OneHotArrays" ReactantPythonCallExt = "PythonCall" ReactantRandom123Ext = "Random123" +ReactantStaticArraysExt = "StaticArrays" ReactantSparseArraysExt = "SparseArrays" ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" @@ -149,6 +151,7 @@ Setfield = "1.1.2" Sockets = "1.10" SparseArrays = "1.10" SpecialFunctions = "2.4" +StaticArrays = "1" StableRNGs = "1.0.4" Statistics = "1.10" StructArrays = "0.7.2" diff --git a/ext/ReactantStaticArraysExt.jl b/ext/ReactantStaticArraysExt.jl new file mode 100644 index 0000000000..6568a66f05 --- /dev/null +++ b/ext/ReactantStaticArraysExt.jl @@ -0,0 +1,41 @@ +module ReactantStaticArraysExt + +using Reactant +import Reactant.TracedRArrayOverrides: overloaded_map, overloaded_mapreduce +import Reactant.TracedLinearAlgebra: overloaded_mul + +using StaticArrays: SArray, StaticArray + +const SAReact{Sz,T} = StaticArray{Sz,T} where {Sz<:Tuple,T<:Reactant.TracedRNumber} + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(FA::Type{SArray{S,T,N,L}}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(ndevices), + @nospecialize(runtime) +) where {S,T,N,L} + T_traced = Reactant.traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime) + return SArray{S,T_traced,N,L} +end + +function 
Reactant.materialize_traced_array(x::SAReact) + return x +end + +# We don't want to overload map on StaticArrays since it is likely better to just unroll things +overloaded_map(f, a::SAReact, rest::SAReact...) = f.(a, rest...) +overloaded_mapreduce(f, op, a::SAReact; kwargs...) = mapreduce(f, op, a; kwargs...) + +function overloaded_mul(A::SAReact, B::SAReact, alpha::Number=true, beta::Number=false) + # beta is not supported since it is zero by default in Reactant + # (similar is zero'd automatically for TracedRArrays) + C = A * B + if !(alpha isa Reactant.TracedRNumber) && isone(alpha) + return C + end + return C .* alpha +end + +end diff --git a/ext/ReactantStructArraysExt.jl b/ext/ReactantStructArraysExt.jl index 29cff1c39e..5923d60d4b 100644 --- a/ext/ReactantStructArraysExt.jl +++ b/ext/ReactantStructArraysExt.jl @@ -40,8 +40,8 @@ function Base.copy( end function Reactant.broadcast_to_size(arg::StructArray{T}, rsize) where {T} - new = [broadcast_to_size(c, rsize) for c in components(arg)] - return StructArray{T}(NamedTuple(Base.propertynames(arg) .=> new)) + new = Tuple((broadcast_to_size(c, rsize) for c in components(arg))) + return StructArray{T}(new) end function Base.copyto!( @@ -53,12 +53,49 @@ function Base.copyto!( bc = Broadcast.preprocess(dest, bc) args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...) 
+ copyto!(dest, res) + + return dest +end +function Reactant.TracedRArrayOverrides._copyto!( + dest::StructArray, bc::Base.Broadcast.Broadcasted{<:AbstractReactantArrayStyle} +) + return copyto!(dest, bc) +end + +function Base.copyto!( + dest::Reactant.TracedRArray, bc::Broadcasted{StructArrayStyle{S,N}} +) where {S<:AbstractReactantArrayStyle,N} + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + isempty(dest) && return dest + + bc = Broadcast.preprocess(dest, bc) + args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...) return copyto!(dest, res) end +function alloc_sarr(bc, T) + # Short circuit for Complex since in Reactant they are just a regular number + T <: Complex && return similar(bc, T) + asa = Base.Fix1(alloc_sarr, bc) + if StructArrays.isnonemptystructtype(T) + return StructArrays.buildfromschema(asa, T) + else + return similar(bc, T) + end +end + +function Base.similar( + bc::Broadcasted{StructArrayStyle{S,N}}, ::Type{ElType} +) where {S<:AbstractReactantArrayStyle,N,ElType} + bc′ = convert(Broadcasted{S}, bc) + # It is possible that we have multiple broadcasted arguments + return alloc_sarr(bc′, ElType) +end + Base.@propagate_inbounds function StructArrays._getindex( x::StructArray{T}, I::Vararg{TracedRNumber{<:Integer}} ) where {T} @@ -67,14 +104,42 @@ Base.@propagate_inbounds function StructArrays._getindex( return createinstance(T, get_ith(cols, I...)...) end +setstruct(col, val, I) = @inbounds Reactant.@allowscalar col[I] = val +struct SetStruct{T} + I::T +end +(s::SetStruct)(col, val) = setstruct(col, val, s.I) +(s::SetStruct)(vals) = s(vals...) 
+ Base.@propagate_inbounds function Base.setindex!( s::StructArray{T,<:Any,<:Any,Int}, vals, I::TracedRNumber{TI} ) where {T,TI<:Integer} valsT = maybe_convert_elt(T, vals) - foreachfield((col, val) -> (@inbounds col[I] = val), s, valsT) + setter = SetStruct(I) + foreachfield(setter, s, valsT) return s end +const MRarr = Union{Reactant.AnyTracedRArray,Reactant.RArray} +getstruct(col, n, I) = @inbounds Reactant.@allowscalar col[n][I...] +struct GetStruct{C,Idx} + cols::C + I::Idx +end +(g::GetStruct)(n) = getstruct(g.cols, n, g.I) + +function StructArrays.get_ith(cols::NamedTuple{N,<:NTuple{K,<:MRarr}}, I...) where {N,K} + getter = GetStruct(cols, I) + ith = ntuple(getter, Val(K)) + return ith +end + +function StructArrays.get_ith(cols::NTuple{K,<:MRarr}, I...) where {K} + getter = GetStruct(cols, I) + ith = ntuple(getter, Val(K)) + return ith +end + Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(prev::Type{StructArray{ET,N,C,I}}), seen, @@ -90,41 +155,15 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( return StructArray{ET_traced,N,C_traced,index_type(fieldtypes(C_traced))} end -function Reactant.make_tracer( - seen, - @nospecialize(prev::StructArray{NT,N}), - @nospecialize(path), - mode; - track_numbers=false, - sharding=Reactant.Sharding.Sharding.NoSharding(), - runtime=nothing, - kwargs..., -) where {NT<:NamedTuple,N} - track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) - components = getfield(prev, :components) - if mode == TracedToTypes - push!(path, typeof(prev)) - for c in components - make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...) 
- end - return nothing - end - traced_components = make_tracer( - seen, - components, - append_path(path, 1), - mode; - track_numbers, - sharding, - runtime, - kwargs..., - ) - T_traced = traced_type(typeof(prev), Val(mode), track_numbers, sharding, runtime) - return StructArray{first(T_traced.parameters)}(traced_components) -end - @inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field) return Base.getfield(obj, field) end +# This is to tell StructArrays to leave these array types alone. +StructArrays.staticschema(::Type{<:Reactant.AnyTracedRArray}) = NamedTuple{()} +StructArrays.staticschema(::Type{<:Reactant.RArray}) = NamedTuple{()} +StructArrays.staticschema(::Type{<:Reactant.RNumber}) = NamedTuple{()} +# Even though RNumbers have fields, we want them to be treated as empty structs +StructArrays.isnonemptystructtype(::Type{<:Reactant.RNumber}) = false +StructArrays.isnonemptystructtype(::Type{<:Reactant.TracedRArray}) = false end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8ccf84d541..67cebabd51 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -383,7 +383,6 @@ end function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest - bc = Broadcast.preprocess(dest, bc) args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 6d4b83d33c..27f0221507 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -1115,6 +1115,13 @@ function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_re return idx_ref[] < L_ref[] end +struct RefFillVector{T} + data::T +end + +Base.getindex(rv::RefFillVector, i) = rv.data[] +Base.broadcastable(x::RefFillVector) = x + function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F} args = args_ref[] fn = fn_ref[] @@ -1129,14 +1136,24 @@ function 
__elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) wh return nothing end +scalar_arg(arg) = arg isa Base.RefValue || !(arg isa AbstractArray) + +flattenarg(arg) = ReactantCore.materialize_traced_array(vec(arg)) +flattenarg(arg::Ref) = RefFillVector(arg) + function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs} - @assert allequal(size.(args)) "All args must have the same size" - L = length(first(args)) + non_ref_args = [arg for arg in args if !scalar_arg(arg)] + if !isempty(non_ref_args) + @assert allequal(size.(non_ref_args)) "All args must have the same size" + end + out_size = isempty(non_ref_args) ? () : size(first(non_ref_args)) + L = isempty(non_ref_args) ? 1 : length(first(non_ref_args)) # flattening the tensors makes the auto-batching pass work nicer - flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args] + flat_args = [flattenarg(arg) for arg in args] # This wont be a mutating function so we can safely execute it once - res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...)) + scalar_seed_args = [@allowscalar(arg[1]) for arg in flat_args] + res_tmp = @allowscalar(f(scalar_seed_args...)) # TODO: perhaps instead of this logic, we should have # `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)` @@ -1146,7 +1163,12 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs} else typeof(res_tmp) end - result = similar(first(flat_args), T_res, L) + + # Before we selected the output container based on the first argument + # That doesn't work for cases when StructArrays are involved + # Since this is essentially a broadcast I'm reusing this machinery + bc = Base.Broadcast.Broadcasted(f, Tuple(args)) + result = similar(bc, T_res) ind_var = Ref(0) f_ref = Ref(f) @@ -1160,7 +1182,7 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs} (ind_var, f_ref, result_ref, args_ref, limit_ref), ) - return 
ReactantCore.materialize_traced_array(reshape(result, size(first(args)))) + return ReactantCore.materialize_traced_array(reshape(result, out_size)) end function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} diff --git a/test/Project.toml b/test/Project.toml index 71e63147b1..55d12e5a2e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -40,6 +40,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" diff --git a/test/integration/structarrays.jl b/test/integration/structarrays.jl index eeb7798b03..d63b4cb365 100644 --- a/test/integration/structarrays.jl +++ b/test/integration/structarrays.jl @@ -1,4 +1,4 @@ -using StructArrays, Reactant, Test +using StructArrays, StaticArrays, Reactant, LinearAlgebra, Test @testset "StructArray to_rarray and make_tracer" begin x = StructArray(; @@ -69,3 +69,28 @@ end @test component_ra ≈ component end end + +@testset "structarray with static array broadcasting" begin + trel(x) = tr.(x) + s = StructArray{SMatrix{2,2,Float64,4}}(( + fill(1.0, 4), fill(2.0, 4), fill(3.0, 4), fill(4.0, 4) + )) + sr = Reactant.to_rarray(s) + out = @jit(trel(sr)) + @test out ≈ trel(s) + @test out isa ConcreteRArray + @test @jit(sum(sr)) ≈ sum(s) +end + +@testset "structarray with complex numbers" begin + s = randn(64) + + elcom(x) = complex(x, x) + sr = Reactant.to_rarray(s) + out = @jit(elcom.(sr)) + @test out ≈ elcom.(s) + @test out isa ConcreteRArray + @test @jit(sum(sr)) ≈ sum(s) +end + +