Skip to content

Commit 1456ac1

Browse files
committed
Update
1 parent f754c6c commit 1456ac1

2 files changed

Lines changed: 38 additions & 26 deletions

File tree

ext/ReactantStructArraysExt.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ function Base.copyto!(
5353
bc = Broadcast.preprocess(dest, bc)
5454

5555
args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
56-
@info "HERE"
57-
Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...; dest=dest)
56+
res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...)
57+
copyto!(dest, res)
5858

5959
return dest
6060
end
@@ -81,11 +81,11 @@ function Base.similar(
8181
bc::Broadcasted{StructArrayStyle{S,N}}, ::Type{ElType}
8282
) where {S<:AbstractReactantArrayStyle,N,ElType}
8383
bc′ = convert(Broadcasted{S}, bc)
84-
if StructArrays.isnonemptystructtype(ElType)
85-
StructArrays.buildfromschema(T -> similar(bc′, T), ElType)
86-
else
87-
similar(bc′, ElType)
88-
end
84+
alloc(::Type{T}) where {T} = (T <: Complex) ? similar(bc′, T) :
85+
(StructArrays.isnonemptystructtype(T) ?
86+
StructArrays.buildfromschema(alloc, T) :
87+
similar(bc′, T))
88+
return alloc(ElType)
8989
end
9090

9191
Base.@propagate_inbounds function StructArrays._getindex(
@@ -134,11 +134,6 @@ end
134134
return Base.getfield(obj, field)
135135
end
136136

137-
function Base.similar(
138-
::Base.Broadcast.Broadcasted{AbstractReactantArrayStyle}, ::Type{Eltype}, dims
139-
) where {Eltype}
140-
return similar(TracedRArray{Eltype}, dims)
141-
end
142137

143138
# This is to tell StructArrays to leave these array types alone.
144139
StructArrays.staticschema(::Type{<:Reactant.AnyTracedRArray}) = NamedTuple{()}

src/TracedUtils.jl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,28 +1115,50 @@ function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_re
11151115
return idx_ref[] < L_ref[]
11161116
end
11171117

1118+
struct RefFillVector{T}
1119+
data::T
1120+
end
1121+
1122+
Base.getindex(rv::RefFillVector, i) = rv.data[]
1123+
Base.broadcastable(x::RefFillVector) = x
1124+
11181125
function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
11191126
args = args_ref[]
11201127
fn = fn_ref[]
11211128
res = res_ref[]
11221129
idx = idx_ref[] + 1
11231130

1124-
scalar_args = [@allowscalar(arg[idx]) for arg in args]
1131+
scalar_args = ntuple(length(args)) do i
1132+
arg = args[i]
1133+
return @allowscalar(arg[idx])
1134+
end
11251135
@allowscalar res[idx] = fn(scalar_args...)
11261136

11271137
idx_ref[] = idx
11281138
res_ref[] = res
11291139
return nothing
11301140
end
11311141

1132-
function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}; dest=nothing) where {Nargs}
1133-
@assert allequal(size.(args)) "All args must have the same size"
1134-
L = length(first(args))
1142+
function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
1143+
scalar_arg(arg) = arg isa Base.RefValue || !(arg isa AbstractArray)
1144+
non_ref_args = Tuple(arg for arg in args if !scalar_arg(arg))
1145+
if !isempty(non_ref_args)
1146+
@assert allequal(size.(non_ref_args)) "All args must have the same size"
1147+
end
1148+
out_size = isempty(non_ref_args) ? () : size(first(non_ref_args))
1149+
L = isempty(non_ref_args) ? 1 : length(first(non_ref_args))
11351150
# flattening the tensors makes the auto-batching pass work nicer
1136-
flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args]
1151+
flat_args = ntuple(Val(Nargs)) do i
1152+
arg = args[i]
1153+
scalar_arg(arg) ? RefFillVector(arg) : ReactantCore.materialize_traced_array(vec(arg))
1154+
end
11371155

11381156
# This wont be a mutating function so we can safely execute it once
1139-
res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...))
1157+
scalar_seed_args = ntuple(Val(Nargs)) do i
1158+
arg = flat_args[i]
1159+
@allowscalar(arg[1])
1160+
end
1161+
res_tmp = @allowscalar(f(scalar_seed_args...))
11401162

11411163
# TODO: perhaps instead of this logic, we should have
11421164
# `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)`
@@ -1150,13 +1172,8 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}; dest=nothing) whe
11501172
# Before we selected the output container based on the first argument
11511173
# That doesn't work for cases when StructArrays are involved
11521174
# Since this is essentially a broadcast I'm reusing this machinery
1153-
if isnothing(dest)
1154-
bc = Base.Broadcast.Broadcasted(f, Tuple(flat_args))
1155-
result = similar(bc, T_res)
1156-
else
1157-
@assert size(dest) == size(first(args)) "dest must have the same size as the input args"
1158-
result = dest
1159-
end
1175+
bc = Base.Broadcast.Broadcasted(f, Tuple(args))
1176+
result = similar(bc, T_res)
11601177

11611178
ind_var = Ref(0)
11621179
f_ref = Ref(f)
@@ -1170,7 +1187,7 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}; dest=nothing) whe
11701187
(ind_var, f_ref, result_ref, args_ref, limit_ref),
11711188
)
11721189

1173-
return ReactantCore.materialize_traced_array(reshape(result, size(first(args))))
1190+
return ReactantCore.materialize_traced_array(reshape(result, out_size))
11741191
end
11751192

11761193
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}

0 commit comments

Comments
 (0)