@@ -1115,28 +1115,50 @@ function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_re
11151115 return idx_ref[] < L_ref[]
11161116end
11171117
1118+ struct RefFillVector{T}  # Wrapper whose indexing ignores the index and always yields the wrapped value
1119+ data:: T  # wrapped value; it is read via `data[]`, so it must support zero-argument indexing
1120+ end
1121+
1122+ Base. getindex (rv:: RefFillVector , i) = rv. data[]  # `i` is ignored: every position returns the same `rv.data[]`
1123+ Base. broadcastable (x:: RefFillVector ) = x  # handed to broadcasting as-is; presumably so it is not iterated/collected - confirm
1124+
11181125function __elem_apply_loop_body (idx_ref, fn_ref:: F , res_ref, args_ref, L_ref) where {F}  # One loop step: apply `fn` to element `idx` of every arg and store it in `res`; `L_ref` is not read here
11191126 args = args_ref[]  # unwrap the argument collection
11201127 fn = fn_ref[]  # unwrap the function to apply
11211128 res = res_ref[]  # unwrap the result container
11221129 idx = idx_ref[] + 1  # position handled by this step: stored counter plus one
11231130
1124- scalar_args = [@allowscalar (arg[idx]) for arg in args]  # (removed) gathered element `idx` of each arg into a Vector
1131+ scalar_args = ntuple (length (args)) do i  # (added) gathers element `idx` of each arg into a tuple instead
1132+ arg = args[i]  # i-th argument
1133+ return @allowscalar (arg[idx])  # its element at `idx`; NOTE(review): @allowscalar presumably permits scalar indexing here - confirm
1134+ end
11251135 @allowscalar res[idx] = fn (scalar_args... )  # call `fn` on the gathered scalars and write the result at `idx`
11261136
11271137 idx_ref[] = idx  # persist the advanced counter
11281138 res_ref[] = res  # store the result container back into its Ref
11291139 return nothing
11301140end
11311141
1132- function elem_apply_via_while_loop (f, args:: Vararg{Any,Nargs} ; dest= nothing ) where {Nargs}
1133- @assert allequal (size .(args)) " All args must have the same size"
1134- L = length (first (args))
1142+ function elem_apply_via_while_loop (f, args:: Vararg{Any,Nargs} ) where {Nargs}
1143+ scalar_arg (arg) = arg isa Base. RefValue || ! (arg isa AbstractArray)
1144+ non_ref_args = Tuple (arg for arg in args if ! scalar_arg (arg))
1145+ if ! isempty (non_ref_args)
1146+ @assert allequal (size .(non_ref_args)) " All args must have the same size"
1147+ end
1148+ out_size = isempty (non_ref_args) ? () : size (first (non_ref_args))
1149+ L = isempty (non_ref_args) ? 1 : length (first (non_ref_args))
11351150 # flattening the tensors makes the auto-batching pass work nicer
1136- flat_args = [ReactantCore. materialize_traced_array (vec (arg)) for arg in args]
1151+ flat_args = ntuple (Val (Nargs)) do i
1152+ arg = args[i]
1153+ scalar_arg (arg) ? RefFillVector (arg) : ReactantCore. materialize_traced_array (vec (arg))
1154+ end
11371155
11381156 # This won't be a mutating function, so we can safely execute it once
1139- res_tmp = @allowscalar (f ([@allowscalar (arg[1 ]) for arg in flat_args]. .. ))
1157+ scalar_seed_args = ntuple (Val (Nargs)) do i
1158+ arg = flat_args[i]
1159+ @allowscalar (arg[1 ])
1160+ end
1161+ res_tmp = @allowscalar (f (scalar_seed_args... ))
11401162
11411163 # TODO : perhaps instead of this logic, we should have
11421164 # `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)`
@@ -1150,13 +1172,8 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}; dest=nothing) whe
11501172 # Previously we selected the output container based on the first argument.
11511173 # That doesn't work for cases when StructArrays are involved.
11521174 # Since this is essentially a broadcast, I'm reusing the broadcast machinery here.
1153- if isnothing (dest)
1154- bc = Base. Broadcast. Broadcasted (f, Tuple (flat_args))
1155- result = similar (bc, T_res)
1156- else
1157- @assert size (dest) == size (first (args)) " dest must have the same size as the input args"
1158- result = dest
1159- end
1175+ bc = Base. Broadcast. Broadcasted (f, Tuple (args))
1176+ result = similar (bc, T_res)
11601177
11611178 ind_var = Ref (0 )
11621179 f_ref = Ref (f)
@@ -1170,7 +1187,7 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}; dest=nothing) whe
11701187 (ind_var, f_ref, result_ref, args_ref, limit_ref),
11711188 )
11721189
1173- return ReactantCore. materialize_traced_array (reshape (result, size ( first (args)) ))
1190+ return ReactantCore. materialize_traced_array (reshape (result, out_size ))
11741191end
11751192
11761193function elem_apply (f, args:: Vararg{Any,Nargs} ) where {Nargs}
0 commit comments