-
Notifications
You must be signed in to change notification settings - Fork 64
Better StructArray & StaticArray support #2546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 37 commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
e03f422
Draft to figure out better StructArray support
ptiede 7c425b8
Simplify and generalize structarray type conversion
wsmoses bb28db8
Merge remote-tracking branch 'upstream/satc' into ptiede-structarrays
ptiede f3080b6
Start adding StaticArray support
ptiede 5832eb4
Add StaticArray support and tweak elem_apply_while_loop to select cor…
ptiede 228187c
Revert tracing.jl
ptiede a3c4ac3
Remove info debug
ptiede 71cec58
Merge branch 'main' into ptiede-structarrays
ptiede 19205c8
Remove get_ith
ptiede 78ff71c
Add _copyto!
ptiede 47f7918
Merge branch 'main' into ptiede-structarrays
ptiede b1b3fc3
format
ptiede cc65c11
Fix broken test and add new tests
ptiede 611d095
format
ptiede 3ac9140
add StaticArrays
ptiede 399258e
Add LinearAlgebra
ptiede b29cc9a
Remove unused function
ptiede aa68406
Merge branch 'main' into ptiede-structarrays
ptiede 9dce679
Reuse the known destination for while loop if possible
ptiede 80f1031
Merge branch 'main' into ptiede-structarrays
ptiede a2ea5cf
Update ext/ReactantStructArraysExt.jl
ptiede 957132b
Proposed improved support for SArrays
ptiede 9e6bb90
fix dumb mistake
ptiede baf5177
Add additional changes for StaticArrays
ptiede 358507b
Apply suggestions from code review
ptiede f754c6c
Cleanup
ptiede 1456ac1
Update
ptiede 9e95692
Update
ptiede 56bcab6
Format
ptiede f28be27
Fix for code review
ptiede 52e7b15
Add comments
ptiede 79516cb
So dumb
ptiede 1ece166
Merge branch 'main' into ptiede-structarrays
ptiede 2212e49
Correct comment in overloaded_mul function
ptiede b3d3661
Update to remove anonymous functions
ptiede 3ca3dc2
Update
ptiede 3dc5d64
Add a complex test
ptiede 6253791
Update
ptiede 446d605
Merge branch 'main' into ptiede-structarrays
avik-pal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| module ReactantStaticArraysExt | ||
|
|
||
| using Reactant | ||
| import Reactant.TracedRArrayOverrides: overloaded_map, overloaded_mapreduce | ||
| import Reactant.TracedLinearAlgebra: overloaded_mul | ||
|
|
||
| using StaticArrays: SArray, StaticArray | ||
|
|
||
| const SAReact{Sz,T} = StaticArray{Sz,T} where {Sz<:Tuple,T<:Reactant.TracedRNumber} | ||
|
|
||
| Base.@nospecializeinfer function Reactant.traced_type_inner( | ||
| @nospecialize(FA::Type{SArray{S,T,N,L}}), | ||
| seen, | ||
| mode::Reactant.TraceMode, | ||
| @nospecialize(track_numbers::Type), | ||
| @nospecialize(ndevices), | ||
| @nospecialize(runtime) | ||
| ) where {S,T,N,L} | ||
| T_traced = Reactant.traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime) | ||
| return SArray{S,T_traced,N,L} | ||
| end | ||
|
|
||
| function Reactant.materialize_traced_array(x::SAReact) | ||
| return x | ||
| end | ||
|
|
||
| # We don't want to overload map on StaticArrays since it is likely better to just unroll things | ||
| overloaded_map(f, a::SAReact, rest::SAReact...) = f.(a, rest...) | ||
| overloaded_mapreduce(f, op, a::SAReact; kwargs...) = mapreduce(f, op, a, kwargs...) | ||
|
|
||
| function overloaded_mul(A::SAReact, B::SAReact, alpha::Number=true, beta::Number=false) | ||
| # beta is not supported since it is zero by default in Reactant | ||
| # (similar is zero'd automatically for TracedRArrays) | ||
| C = A * B | ||
| if !(alpha isa Reactant.TracedRNumber) && isone(alpha) | ||
| return C | ||
| end | ||
| return C .* alpha | ||
| end | ||
|
|
||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be reviewed closely. I had to change this line because now
flat_argsmay have differing array types, e.g., one could be aStructArray. Before we decided the output entirely based on the first element, which could lead to errors, e.g.,2 .* sawould try to write aStructArrayto aTracedRArray.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh this is clever!