-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathReactantStaticArraysExt.jl
More file actions
41 lines (33 loc) · 1.32 KB
/
ReactantStaticArraysExt.jl
File metadata and controls
41 lines (33 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Package extension connecting StaticArrays.jl to Reactant's tracing machinery.
module ReactantStaticArraysExt

using Reactant
import Reactant.TracedRArrayOverrides: overloaded_map, overloaded_mapreduce
import Reactant.TracedLinearAlgebra: overloaded_mul
using StaticArrays: SArray, StaticArray

# Any `StaticArray` (with static size tuple `Sz`) whose elements are traced Reactant
# scalars (`Reactant.TracedRNumber`).
const SAReact{Sz,T} = StaticArray{Sz,T} where {Sz<:Tuple,T<:Reactant.TracedRNumber}

# Traced type of an `SArray{S,T,N,L}`: the size `S`, rank `N` and length `L` parameters
# are kept unchanged, and only the element type `T` is replaced by the type Reactant
# traces `T` to (computed recursively with the same tracing arguments).
Base.@nospecializeinfer function Reactant.traced_type_inner(
    @nospecialize(FA::Type{SArray{S,T,N,L}}),
    seen,
    mode::Reactant.TraceMode,
    @nospecialize(track_numbers::Type),
    @nospecialize(ndevices),
    @nospecialize(runtime)
) where {S,T,N,L}
    T_traced = Reactant.traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime)
    return SArray{S,T_traced,N,L}
end

# A static array of traced scalars is returned unchanged; no conversion to another
# array type is performed.
function Reactant.materialize_traced_array(x::SAReact)
    return x
end

# We don't want to overload map on StaticArrays since it is likely better to just unroll
# things. Instead, `f` is broadcast element-wise over the static arrays.
overloaded_map(f, a::SAReact, rest::SAReact...) = f.(a, rest...)

# Forward to the generic `mapreduce`. Keyword arguments (e.g. `init`, `dims`) must be
# forwarded after `;` as keywords. Splatting them positionally would pass each
# `name => value` `Pair` as an extra positional collection argument.
overloaded_mapreduce(f, op, a::SAReact; kwargs...) = mapreduce(f, op, a; kwargs...)

# Compute `A * B` scaled by `alpha`, using the StaticArrays matrix product.
#
# - `alpha`: the product is returned unscaled only when `alpha` is an untraced value for
#   which `isone` holds. A traced `alpha` is always multiplied in.
# - `beta`: accepted for signature compatibility but not used. Beta is not supported
#   since it is zero by default in Reactant (`similar` is zero'd automatically for
#   TracedRArrays).
#   NOTE(review): a nonzero `beta` passed here is silently ignored; confirm that callers
#   never rely on it.
function overloaded_mul(A::SAReact, B::SAReact, alpha::Number=true, beta::Number=false)
    C = A * B
    if !(alpha isa Reactant.TracedRNumber) && isone(alpha)
        return C
    end
    return C .* alpha
end

end