-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathTracedRArray.jl
More file actions
1390 lines (1187 loc) · 44.7 KB
/
TracedRArray.jl
File metadata and controls
1390 lines (1187 loc) · 44.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
module TracedRArrayOverrides
using Base: Broadcast
using Base.Broadcast: Broadcasted, AbstractArrayStyle, instantiate
using ..Reactant: Reactant, TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRVector
using ..Reactant: MLIR, unwrapped_eltype
using ..Ops: @opcall
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
using ReactantCore: ReactantCore
using GPUArraysCore: GPUArraysCore, @allowscalar
# Broadcasted "strictly precedes" test under a `Base.Order.Ordering`, used to build sort
# comparators. Each method peels off one layer of the ordering until a primitive
# comparison (`isless` or a user-supplied `lt`) is applied elementwise.
function __lt(::Base.Order.ForwardOrdering, a, b)
    return isless.(a, b)
end
# A reversed ordering is the wrapped ordering with the operands swapped.
function __lt(ord::Base.Order.ReverseOrdering, a, b)
    return __lt(ord.fwd, b, a)
end
# `By` maps both operands through its key function, then compares the keys.
function __lt(ord::Base.Order.By, a, b)
    key_a = ord.by.(a)
    key_b = ord.by.(b)
    return __lt(ord.order, key_a, key_b)
end
# `Lt` carries a user comparison that is applied elementwise.
function __lt(ord::Base.Order.Lt, a, b)
    return ord.lt.(a, b)
end
# Every TracedRArray is a traced value, whatever extra argument is supplied.
ReactantCore.is_traced(::TracedRArray, _) = true
ReactantCore.is_traced(::TracedRArray) = true

# Traced arrays report dense column-major strides and linear indexing.
function Base.strides(x::TracedRArray)
    return Base.size_to_strides(1, size(x)...)
end
Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear()
function Base.elsize(::Type{TracedRArray{T,N}}) where {T,N}
    return sizeof(T)
end

# This is required otherwise we will copy a tracedrarray each time
# we use it
function Base.convert(T::Type{<:TracedRArray}, x::AbstractArray)
    return Reactant.promote_to(T, x)
end

# Base.complex: complex arrays pass through untouched; real ones are promoted elementwise.
Base.complex(x::TracedRArray{<:Complex}) = x
function Base.complex(x::TracedRArray{<:Real})
    return complex.(x)
end
# Deep-copying a traced array yields `copy(x)`; `stackdict` memoizes the result so that
# repeated references to the same array map to a single copy.
function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict)
    haskey(stackdict, x) && return stackdict[x]::typeof(x)
    dup = copy(x)
    stackdict[x] = dup
    return dup
end
# Build a TracedRArray of a given eltype and rank by converting `x`.
function TracedRArray{T,N}(x::AbstractArray) where {T,N}
    return convert(TracedRArray{T,N}, x)
end
# Collect every element, in linear order, into a tuple.
Base.Tuple(x::TracedRArray) = ntuple(i -> getindex(x, i), length(x))
# The shape is stored on the traced value itself.
Base.size(x::TracedRArray) = x.shape
# Extent along dimension `i`; dimensions beyond the rank count as singletons.
function Base.size(x::TracedRArray, i::Integer)
    return i > ndims(x) ? 1 : x.shape[i]
end
# Extent along a *traced* dimension index. The shape is embedded as a constant tensor
# and indexed at run time; dimensions beyond the rank count as singletons.
function Base.size(x::TracedRArray, i::TracedRNumber{<:Integer})
    beyond_rank = i > ndims(x)
    extents = @opcall(constant([x.shape...]))
    return @allowscalar ifelse(beyond_rank, 1, getindex(extents, i))
end
# `collect` and `copy` both return a new TracedRArray handle over the same MLIR value.
Base.collect(x::TracedRArray) = copy(x)
function Base.copy(A::TracedRArray{T,N}) where {T,N}
    return TracedRArray{T,N}((), A.mlir_data, size(A))
end
# `similar` for traced arrays produces a zero-filled traced tensor with the requested
# (unwrapped) eltype and shape.
function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N}
    ET = unwrapped_eltype(T)
    filled = @opcall fill(zero(ET), dims)
    return filled::TracedRArray{ET,N}
end
function Base.similar(::Type{<:TracedRArray{T}}, dims::Dims{N}) where {T,N}
    filled = @opcall fill(zero(T), dims)
    return filled::TracedRArray{T,N}
end
# Wrapped traced arrays print as `WrapperType(parent)`; the parent is omitted when the
# array is its own parent.
function Base.show(io::IOty, X::AnyTracedRArray) where {IOty<:Union{IO,IOContext}}
    print(io, Core.Typeof(X), "(")
    inner = parent(X)
    if inner !== X
        Base.show(io, inner)
    end
    return print(io, ")")
end
# Plain TracedRArrays print their eltype, rank, trace paths and size.
function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}}
    return print(io, "TracedRArray{", T, ",", N, "N}(", X.paths, ", size=", size(X), ")")
end
# `==` / `!=` for two traced arrays of the same eltype and rank. Each entry holds:
# (Julia operator, stablehlo op, comparison direction, reduction that merges the
#  elementwise results, value returned when the shapes differ).
for (jlop, _, _, merge, on_mismatch) in (
    (:(Base.:(==)), :compare, "EQ", :all, false),
    (:(Base.:(!=)), :compare, "NE", :any, true),
)
    @eval function $jlop(
        @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
    ) where {T,N}
        # As in Base, arrays of different shape are never equal. Without this check,
        # broadcasting would expand e.g. a 1×3 against a 3×1 array (and possibly report
        # them equal), or throw a DimensionMismatch for incompatible shapes.
        size(lhs) == size(rhs) || return $(on_mismatch)
        elems = $(jlop).(lhs, rhs)
        return N == 0 ? elems : $(merge)(elems)
    end
end
# Override _parentsmatch to avoid pointer comparisons during tracing.
# Two TracedRArrays alias only when they are the very same object.
Base._parentsmatch(A::TracedRArray, B::TracedRArray) = A === B
# A TracedRArray and a regular Array can never share memory, so they never alias.
# Without this, the default DenseArray/StridedArray methods call pointer() which
# isn't defined for TracedRArray, causing "conversion to pointer not defined"
# errors when @views creates SubArray wrappers that trigger broadcast alias checking.
# The DenseArray and StridedArray variants resolve method ambiguities with Base.
for Other in (AbstractArray, DenseArray, StridedArray)
    @eval Base._parentsmatch(::TracedRArray, ::$Other) = false
    @eval Base._parentsmatch(::$Other, ::TracedRArray) = false
end
# Reshapes of traced arrays (or of views into them) alias when their parents do; this
# method is more specific than the StridedArray ones above.
function Base._parentsmatch(
    A::Base.ReshapedArray{
        <:TracedRNumber,
        <:Any,
        <:Union{TracedRArray,SubArray{<:TracedRNumber,<:Any,<:TracedRArray}},
    },
    B::Base.ReshapedArray{
        <:TracedRNumber,
        <:Any,
        <:Union{TracedRArray,SubArray{<:TracedRNumber,<:Any,<:TracedRArray}},
    },
)
    parent_a, parent_b = parent(A), parent(B)
    return Base._parentsmatch(parent_a, parent_b)
end
# Seed value for a traced reduction with operator `op` over eltype `T`, returned as a
# traced number. `min`/`max` (including the FastMath variants) start from the opposite
# end of the type's range. For the Reactant Float8 formats the seed is taken from the
# corresponding Float16 value and converted to `T`.
for (op_union, extreme) in (
    (Union{typeof(Base.min),typeof(Base.FastMath.min_fast)}, typemax),
    (Union{typeof(Base.max),typeof(Base.FastMath.max_fast)}, typemin),
)
    @eval function __default_init(::Type{T}, ::$op_union) where {T}
        return Reactant.promote_to(TracedRNumber{T}, $extreme(T))
    end
    @eval function __default_init(T::Type{<:Reactant.ReactantFloat8}, ::$op_union)
        return Reactant.promote_to(TracedRNumber{T}, $extreme(Float16))
    end
end
# Any other operator: use Base's empty-reduction value.
function __default_init(::Type{T}, op::F) where {T,F}
    seed = Base.reduce_empty(Base.BottomRF(op), T)
    return Reactant.promote_to(TracedRNumber, seed)
end
function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
    seed16 = __default_init(Float16, op)
    return Reactant.promote_to(TracedRNumber{T}, seed16)
end
# Fallback `mapreduce` for arbitrary containers: lift the map through
# `unwrapped_broadcast`, reduce the lifted result, then rewrap it.
function overloaded_mapreduce(
    @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
)
    mapped, new_dims, rewrap = unwrapped_broadcast(f, A, dims)
    if typeof(mapped) == typeof(A)
        # The container type did not change, so the optimized traced dispatches do not
        # apply; unroll the reduction as a left fold instead.
        @assert dims isa Colon "dims not supported for mapreduce currently."
        return foldl(op, mapped; init)
    end
    reduced = overloaded_mapreduce(identity, op, mapped; dims=new_dims, init)
    return rewrap(reduced)
end
# Plain arrays of Reactant primitive values are promoted to a TracedRArray first.
function overloaded_mapreduce(
    @nospecialize(f),
    @nospecialize(op),
    @nospecialize(A::AbstractArray{<:Reactant.ReactantPrimitive});
    kwargs...,
)
    traced = TracedUtils.promote_to(TracedRArray, A)
    return overloaded_mapreduce(f, op, traced; kwargs...)
end
# Wrapped traced arrays are materialized into a TracedRArray first.
function overloaded_mapreduce(
    @nospecialize(f), @nospecialize(op), @nospecialize(A::AnyTracedRArray); kwargs...
)
    dense = materialize_traced_array(A)
    return overloaded_mapreduce(f, op, dense; kwargs...)
end
# Core traced `mapreduce`: apply `f` elementwise to `A`, then reduce with `op` over
# `dims` using a single `reduce` op.
#
# - `dims`: `:` (all dimensions), an `Int`, or a collection of dimension indices.
# - `init`: optional value combined with the reduced result through `op` afterwards.
#
# Returns a `TracedRNumber` when `dims` is `:`. Otherwise returns a rank-N traced array
# whose reduced dimensions are kept as singletons.
function overloaded_mapreduce(
    @nospecialize(f),
    @nospecialize(op),
    @nospecialize(A::TracedRArray{T,N});
    dims=:,
    init=Base._InitialValue(),
) where {T,N}
    original_dims = dims
    # Normalize `dims` to a sorted Vector{Int64} of the dimensions to reduce.
    dims isa Int && (dims = Int64[dims])
    dims isa Colon && (dims = collect(Int64, 1:N))
    dims isa Vector{Int64} || (dims = collect(Int64, dims))
    dims = sort(dims)

    # Choose the reduction seed from the eltype that `f` produces.
    op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
    reduce_init = __default_init(op_in_T, op)
    riT = unwrapped_eltype(typeof(reduce_init))

    reduce_input = materialize_traced_array(TracedUtils.elem_apply(f, A))
    if riT != op_in_T
        # The seed lives in a different (typically wider) type than `f`'s output; for
        # example, Base seeds `+` over `Bool` with an `Int`. Convert the *mapped* values
        # so that the reduction runs in the seed's type. Converting `A` before applying
        # `f` would change what `f` sees (e.g. truncating floats to integers before a
        # predicate like `x -> x > 0.5`), unlike Base, which applies `f` to the
        # original elements.
        op_in_T = riT
        reduce_input = materialize_traced_array(riT.(reduce_input))
    end
    reduce_init = Reactant.promote_to(TracedRNumber{op_in_T}, reduce_init)

    res = @opcall reduce(reduce_input, reduce_init, dims, op)
    # Fold a user-supplied `init` into the reduced result.
    (init isa Base._InitialValue || init === nothing) || (res = op.(res, init))

    if original_dims isa Colon
        @assert size(res) == () "expected size of result to be (), got $(size(res))"
        return TracedRNumber{unwrapped_eltype(res)}((), res.mlir_data)
    end
    if res isa TracedRNumber
        res = TracedRArray{unwrapped_eltype(res),0}((), res.mlir_data, ())
    end
    # Keep the reduced dimensions as singletons, as Base does for `dims` reductions.
    return @opcall reshape(res, [ifelse(i in dims, 1, size(A, i)) for i in 1:N])
end
# In-place reduction into `R`. Every dimension where `R` is a singleton but `A` is not
# is reduced, and the partial result is combined with `R`'s current contents via `op`.
function Base.mapreducedim!(
    @nospecialize(f),
    @nospecialize(op),
    @nospecialize(R::AnyTracedRArray{T,N}),
    A::Base.AbstractArrayOrBroadcasted,
) where {T,N}
    @assert length(size(R)) == length(size(A))
    # Index of each dimension to reduce; `nothing` where the extents already agree.
    axis_marks = map(enumerate(zip(size(R), size(A)))) do (d, (len_R, len_A))
        if len_R == len_A
            return nothing
        end
        @assert len_R == 1
        return d
    end
    partial = mapreduce(f, op, A; dims=filter(!isnothing, axis_marks))
    copyto!(R, op.(R, partial))
    return R
end
# Overwrite every element of a traced array with `x`, converted to the array's eltype.
function Base.fill!(A::AnyTracedRArray{T,N}, x) where {T,N}
    filled = Reactant.broadcast_to_size(T(x), size(A))
    set_mlir_data!(A, get_mlir_data(filled))
    return A
end
# A traced fill value is promoted to the array's eltype before broadcasting.
function Base.fill!(A::AnyTracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
    value = Reactant.promote_to(TracedRNumber{T}, x)
    filled = Reactant.broadcast_to_size(value, size(A))
    set_mlir_data!(A, get_mlir_data(filled))
    return A
end
# Filling a plain `Array` with a traced number is explicitly rejected.
function Base.fill!(A::Array{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
    return throw(MethodError(fill!, (A, x)))
end
# Broadcast style for traced arguments; `N` is the participating argument's rank.
struct AbstractReactantArrayStyle{N} <: AbstractArrayStyle{N} end
# Rank changes requested through `Val` keep the Reactant style.
AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}()

# Traced arrays (rank N), ranges whose elements are traced numbers (rank 1) and traced
# scalars (rank 0) all broadcast with the Reactant style.
function Broadcast.BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N}
    return AbstractReactantArrayStyle{N}()
end
function Broadcast.BroadcastStyle(::Type{<:AbstractRange{<:TracedRNumber}})
    return AbstractReactantArrayStyle{1}()
end
function Broadcast.BroadcastStyle(::Type{<:TracedRNumber})
    return AbstractReactantArrayStyle{0}()
end
# Broadcast destinations are zero-filled traced tensors of the broadcast's rank.
function Base.similar(
    ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:Reactant.ReactantPrimitive,N}
    @assert N isa Int
    out = @opcall fill(zero(unwrapped_eltype(T)), dims)
    return out::TracedRArray{T,N}
end
# Same, when the eltype is given as a traced number type.
function Base.similar(
    ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims
) where {T<:Reactant.ReactantPrimitive,N}
    @assert N isa Int
    out = @opcall fill(zero(T), dims)
    return out::TracedRArray{T,N}
end
# 0-d broadcasts produce a scalar rather than a 0-d array.
function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}})
    eltype_out = Broadcast.combine_eltypes(bc.f, bc.args)
    result = copyto!(similar(bc, eltype_out), bc)
    # 0D broadcast needs to unwrap results
    return result[CartesianIndex()]
end
# A `Broadcast.Extruded` reports the eltype of the container it wraps.
function Base.eltype(::Broadcast.Extruded{T}) where {T}
    return eltype(T)
end
# First element of a broadcast argument (looking through `Extruded`), read with scalar
# indexing allowed; used to probe a broadcast function's result type.
first_scalar(x) = @allowscalar first(x)
function first_scalar(x::Broadcast.Extruded)
    return first_scalar(x.x)
end
# we need to override the outer copy method to make sure we never fall back to scalar
# iteration (see, e.g., CUDA.jl#145)
function _copy(bc)
    # Broadcasting a Reactant primitive type (a conversion) is lowered via `TypeCast`.
    is_cast = bc.f isa Type && bc.f <: Reactant.ReactantPrimitive
    fn = is_cast ? TracedUtils.TypeCast{bc.f}() : bc.f
    ElType = Broadcast.combine_eltypes(fn, bc.args)
    # Special case a union{} return so we can see the better error message
    undetermined = ElType === Union{} || ElType == Any || ElType == TracedRNumber
    if undetermined
        # Probe the real result type by applying `fn` to the first elements.
        ElType = Core.Typeof(fn(map(first_scalar, bc.args)...))
    end
    if ElType == Any || ElType == Union{}
        throw(AssertionError("Failed to deduce eltype of broadcast of $fn, found $ElType"))
    end
    return copyto!(similar(bc, ElType), bc)
end
@noinline Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) = _copy(bc)
# In-place broadcast into `dest`: rebuild `bc` against `dest`'s axes and lower it with
# `_copyto!`.
function Base.materialize!(
    ::Style, dest, bc::Broadcasted
) where {Style<:AbstractReactantArrayStyle}
    bc_dest = instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))
    _copyto!(dest, bc_dest)
    return dest
end
# Keep it for ArrayConflict
function Base.copyto!(dest::AnyTracedRArray, bc::Broadcasted{Nothing})
    return _copyto!(dest, bc)
end
# Copy `n` consecutive elements (linear indexing) of `src`, starting at `sstart`, into
# `dest` starting at `dstart`. As in Base, `n == 0` is a no-op and a negative `n` throws
# an `ArgumentError`.
function Base.copyto!(
    dest::AnyTracedRArray{T},
    dstart::Integer,
    src::TracedRArray{T},
    sstart::Integer,
    n::Integer,
) where {T}
    # An empty copy must not emit zero-length slice/update operations.
    n == 0 && return dest
    n > 0 || throw(ArgumentError("tried to copy n=$n elements, but n should be nonnegative"))
    setindex!(dest, src[sstart:(sstart + n - 1)], dstart:(dstart + n - 1))
    return dest
end
# Different eltypes: convert `src` to `dest`'s eltype, then use the same-eltype method.
function Base.copyto!(
    dest::AnyTracedRArray{T}, dstart::Integer, src::TracedRArray, sstart::Integer, n::Integer
) where {T}
    converted = @opcall(convert(TracedRArray{T,ndims(src)}, src))
    return copyto!(dest, dstart, converted, sstart, n)
end
# Wrapped traced sources are materialized into a TracedRArray first.
function Base.copyto!(
    dest::AnyTracedRArray, dstart::Integer, src::AnyTracedRArray, sstart::Integer, n::Integer
)
    src_array = materialize_traced_array(src)
    return copyto!(dest, dstart, src_array, sstart, n)
end
# Whole-array copy with matching eltypes: take the first `length(dest)` elements of
# `src` in linear order, reshape them to `dest`'s shape and install them as `dest`'s
# data.
# NOTE(review): unlike Base, extra elements of a longer `src` are silently dropped
# rather than raising a BoundsError — confirm this is intended.
function Base.copyto!(dest::AnyTracedRArray{T}, src::AnyTracedRArray{T}) where {T}
    flat_src = materialize_traced_array(src)
    reshaped = reshape(flat_src[1:length(dest)], size(dest))
    TracedUtils.set_mlir_data!(dest, materialize_traced_array(reshaped).mlir_data)
    return dest
end
# Mixed eltypes: convert the materialized source to `dest`'s eltype first.
function Base.copyto!(dest::AnyTracedRArray{T}, src::AnyTracedRArray) where {T}
    materialized = materialize_traced_array(src)
    target_type = TracedRArray{T,ndims(materialized)}
    return copyto!(dest, @opcall(convert(target_type, materialized)))
end
# Host arrays are promoted to traced arrays before copying.
function Base.copyto!(dest::AnyTracedRArray, src::Array)
    traced_src = Reactant.promote_to(TracedRArray, src)
    return copyto!(dest, traced_src)
end
# Lower a broadcast into a traced destination. Each argument is materialized and
# broadcast to the full result size, `bc.f` is applied elementwise, and the result
# (converted to `dest`'s eltype) replaces `dest`'s data.
function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
    if axes(dest) != axes(bc)
        Broadcast.throwdm(axes(dest), axes(bc))
    end
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)
    full_size = size(bc)
    args = (Reactant.broadcast_to_size(Base.materialize(arg), full_size) for arg in bc.args)
    applied = TracedUtils.elem_apply(bc.f, args...)
    res = Reactant.promote_to(TracedRArray{unwrapped_eltype(dest),ndims(dest)}, applied)
    TracedUtils.set_mlir_data!(dest, res.mlir_data)
    return dest
end
# Same lowering for a plain `Array` of traced numbers; the traced result is written
# back element by element.
function _copyto!(dest::Array{<:TracedRNumber}, bc::Broadcasted)
    if axes(dest) != axes(bc)
        Broadcast.throwdm(axes(dest), axes(bc))
    end
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)
    full_size = size(bc)
    args = (Reactant.broadcast_to_size(Base.materialize(arg), full_size) for arg in bc.args)
    res = TracedUtils.elem_apply(bc.f, args...)
    for i in eachindex(dest)
        dest[i] = @allowscalar res[i]
    end
    return dest
end
# Unwrap a `Val{D}` into its plain value `D`; any other argument is returned unchanged.
function dispatch_val(x)
    return x
end
function dispatch_val(::Val{D}) where {D}
    return D
end
# Typed `[a; b; ...]` / `[a b ...]` over traced arrays: concatenate along dimension 1
# (vertical) or 2 (horizontal) with `Base._cat_t`.
for (catfn, dim) in ((:_typed_vcat, 1), (:_typed_hcat, 2))
    @eval @inline function Base.$catfn(
        ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
    ) where {T}
        return Base._cat_t(Val($dim), T, X...)
    end
end
# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
# generic implementation uses `typed_hcat` and `typed_vcat` which is alright
# Both the general row-tuple form and the single-row form are redirected to that generic
# method.
for RowsT in (Tuple{Vararg{Int}}, Tuple{Int})
    @eval @inline function Base.typed_hvcat(
        ::Type{T}, rows::$RowsT, as::TracedRArray...
    ) where {T}
        return invoke(
            Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
        )
    end
end
# N-dimensional block concatenation of traced arrays (`[a b;;; c d]` syntax).
function Base._typed_hvncat(
    ::Type{T}, dims::NTuple{N,Int}, row_first::Bool, a::TracedRArray, as::TracedRArray...
) where {T,N}
    return _typed_hvncat_internal(T, dims, row_first, a, as...)
end
# A single dimension: use Base's 1-d concatenation path.
function Base._typed_hvncat(
    ::Type{T}, dims::Tuple{Int}, ::Bool, a::TracedRArray, as::TracedRArray...
) where {T}
    len = first(dims)
    return Base._typed_hvncat_1d(T, len, Val(false), a, as...)
end
# No dimensions given: defer to Base's `Val(0)` form.
function Base._typed_hvncat(
    ::Type{T}, ::Tuple{}, ::Bool, a::TracedRArray, as::TracedRArray...
) where {T}
    return Base._typed_hvncat(T, Val(0), a, as...)
end
# Concatenate the blocks `as`, laid out on a grid of shape `dims`. The grid is filled
# column-major, or row-major when `row_first`. It is collapsed one dimension per pass:
# pass `d` concatenates each run of blocks along grid dimension 1 with
# `Base._cat_t(d, T, ...)`, which leaves a grid with one fewer dimension.
function _typed_hvncat_internal(
    ::Type{T}, dims::NTuple{N,Int}, row_first::Bool, as...
) where {T,N}
    grid = if row_first
        # Fill with the first two grid dimensions swapped, then swap them back.
        swapped_dims = [dims[2], dims[1], dims[3:end]...]
        permutedims(reshape(collect(as), swapped_dims...), [2, 1, 3:N...])
    else
        reshape(collect(as), dims)
    end
    for d in 1:N
        merged = Array{Any,N - d}(undef, size(grid)[2:end]...)
        columns = eachslice(grid; dims=Tuple(2:ndims(grid)), drop=true)
        for (slot, column) in zip(eachindex(merged), columns)
            # TODO: row_first affects the flattening?
            merged[slot] = Base._cat_t(d, T, column...)
        end
        grid = merged
    end
    return only(grid)
end
# Pad `x` with trailing singleton dimensions until it has `dims` dimensions (`dims` may
# be a `Val`). An array whose rank is already ≥ `dims` is returned unchanged.
function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
    target = dispatch_val(dims)
    if target ≤ N
        return x
    end
    padded = ntuple(d -> d ≤ N ? size(x, d) : 1, target)
    return reshape(x, padded)
end
# Concatenate traced arrays along dimension `dims` (an integer or a `Val`) with a single
# `stablehlo.concatenate`. All inputs are first promoted to a common eltype.
function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
    dims = dispatch_val(dims)
    @assert dims isa Integer "Support for non-integer dimensions is not implemented yet."
    # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
    X = maybe_expand_dims.(X, (dims,))
    shape = Base.cat_size_shape(Base.dims2cat(dims), X...)
    rank = length(shape)
    RT = unwrapped_eltype(Base.promote_eltype(T, X...))
    # convert to the target eltype
    X = map(Base.Fix1(Reactant.promote_to, TracedRArray{RT,rank}), X)
    # TODO: maybe we should do some conversion?
    concat_op = MLIR.Dialects.stablehlo.concatenate(
        collect(TracedUtils.get_mlir_data.(X));
        result_0=MLIR.IR.TensorType(collect(Int, shape), MLIR.IR.Type(RT)),
        dimension=dims - 1, # stablehlo expects this to be zero-indexed
    )
    return TracedRArray{RT,rank}((), MLIR.IR.result(concat_op, 1), shape)
end
# In-place `clamp!` for traced arrays, for every combination of plain and traced
# bounds. Both bounds are promoted to the array's eltype.
for minT in (Number, TracedRNumber), maxT in (Number, TracedRNumber)
    @eval function Base.clamp!(x::AnyTracedRArray, min::$(minT), max::$(maxT))
        ET = unwrapped_eltype(x)
        lo = Reactant.promote_to(TracedRNumber{ET}, min)
        hi = Reactant.promote_to(TracedRNumber{ET}, max)
        clamped = @opcall clamp(lo, materialize_traced_array(x), hi)
        TracedUtils.set_mlir_data!(x, clamped.mlir_data)
        return x
    end
end
# outer repeat
# Tile `x` `counts[k]` times along each dimension `k` (the `outer` form of `repeat`)
# using reshape + broadcast instead of scalar indexing. `counts` may use any integer type
# (e.g. from `repeat(x, Int32(2))`); the counts are converted to `Int` when stored in the
# shape vector.
function repeat_outer_overloaded(
    x::AnyTracedRArray{T,N}, counts::NTuple{N,Integer}
) where {T,N}
    # (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1)
    interleaved_size = ones(Int, 2N)
    interleaved_size[1:2:(2N)] .= size(x)
    x_interleaved = @opcall reshape(materialize_traced_array(x), interleaved_size...)
    # (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP)
    broadcast_target_size = interleaved_size
    broadcast_target_size[2:2:(2N)] .= counts
    x_broadcasted = Reactant.broadcast_to_size(x_interleaved, broadcast_target_size)
    # (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP)
    final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))
    return materialize_traced_array(reshape(x_broadcasted, final_size...))
end
# Route Base's `repeat(...; outer)` on traced arrays to `repeat_outer_overloaded`.
# Rank-1 and rank-2 methods are spelled out separately, presumably so they win dispatch
# over Base's own rank-specific `repeat_outer` methods.
for (R, CountsT) in ((1, Tuple{Any}), (2, NTuple{2,Any}))
    @eval function Base._RepeatInnerOuter.repeat_outer(
        x::AnyTracedRArray{T,$R}, counts::$CountsT
    ) where {T}
        return repeat_outer_overloaded(x, counts)
    end
end
function Base._RepeatInnerOuter.repeat_outer(
    x::AnyTracedRArray{T,N}, counts::NTuple{N,Any}
) where {T,N}
    return repeat_outer_overloaded(x, counts)
end
# inner repeat
# Repeat every element of `x` `counts[k]` times along dimension `k` (the `inner` form of
# `repeat`), using reshape + broadcast. `counts` may have more entries (M) than `x` has
# dimensions (N); the extra entries create trailing dimensions of the result.
function Base._RepeatInnerOuter.repeat_inner(
    x::AnyTracedRArray{T,N}, counts::NTuple{M,Any}
) where {T,N,M}
    P = max(N, M) # potentially padded
    # (d1, d2, ..., dP) -> (1, d1, 1, d2, 1, ..., 1, dP), with dk = 1 for k > N
    interleaved_size = ones(Int, 2P)
    interleaved_size[2:2:(2N)] .= size(x)
    x_interleaved = reshape(materialize_traced_array(x), interleaved_size...)
    # (1, d1, 1, d2, 1, ..., 1, dP) -> (r1, d1, r2, d2, ..., rP, dP)
    broadcast_target_size = interleaved_size
    # One odd slot per entry of `counts` (M of them). Filling only the first N slots
    # would throw a DimensionMismatch whenever M > N, e.g. `repeat(v; inner=(2, 3))`
    # for a vector `v`.
    broadcast_target_size[1:2:(2M)] .= counts
    x_broadcasted = Reactant.broadcast_to_size(x_interleaved, broadcast_target_size)
    # (r1, d1, r2, d2, ..., rP, dP) -> (d1*r1, d2*r2, ..., dP*rP)
    final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))
    return materialize_traced_array(reshape(x_broadcasted, final_size...))
end
# stack
# `stack(xs; dims)` for traced inputs. Each slice gets a singleton axis inserted at
# position `dims` (for `dims = :`, after its last dimension), and then all slices are
# concatenated along that axis.
function overloaded_stack(dims::Union{Integer,Colon}, xs)
    stack_dim = dims isa Colon ? nothing : dims
    slices = []
    ref_rank = nothing
    for x in first(unwrapped_broadcast(identity, xs, Colon()))
        rank = ndims(x)
        if ref_rank === nothing
            ref_rank = rank
        else
            @assert ref_rank == rank "All arrays must have the same number of \
                dimensions..."
        end
        stack_dim === nothing && (stack_dim = rank + 1)
        new_shape = ntuple(rank + 1) do i
            if i == stack_dim
                1
            elseif i < stack_dim
                size(x, i)
            else
                size(x, i - 1)
            end
        end
        push!(slices, materialize_traced_array(internal_stack_reshape(x, new_shape)))
    end
    return cat(slices...; dims=stack_dim)
end
# Reshape helper for `overloaded_stack`; a traced scalar is first viewed as a 0-d array.
internal_stack_reshape(x, new_shape) = reshape(x, new_shape)
function internal_stack_reshape(x::TracedRNumber{T}, new_shape) where {T}
    as_array = TracedRArray{T,0}((), x.mlir_data, ())
    return internal_stack_reshape(as_array, new_shape)
end
# sort
# The non-mutating `sort` copies the input and sorts the copy in place.
function Base.sort(x::AnyTracedRArray; alg=missing, kwargs...)
    y = copy(x)
    return sort!(y; alg, kwargs...)
end
function Base.sort(x::AnyTracedRVector; alg=missing, kwargs...)
    y = copy(x)
    return sort!(y; alg, kwargs...)
end
# In-place sort of a traced vector with one traced sort along dimension 1. The Base
# ordering keywords are combined into a single ordering for the comparator.
function Base.sort!(
    x::AnyTracedRVector;
    lt=isless,
    by=identity,
    rev::Bool=false,
    alg=missing,
    order=Base.Order.Forward,
)
    @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
    sort_order = Base.ord(lt, by, rev, order)
    cmp_fn = (a, b) -> __lt(sort_order, a, b)
    sorted = only(@opcall(sort(materialize_traced_array(x); comparator=cmp_fn, dimension=1)))
    set_mlir_data!(x, get_mlir_data(sorted))
    return x
end
# In-place sort of a traced array along the required dimension `dims`.
function Base.sort!(
    x::AnyTracedRArray;
    dims::Integer,
    lt=isless,
    by=identity,
    rev::Bool=false,
    alg=missing,
    order=Base.Order.Forward,
)
    @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
    sort_order = Base.ord(lt, by, rev, order)
    cmp_fn = (a, b) -> __lt(sort_order, a, b)
    sorted = only(@opcall(sort(materialize_traced_array(x); dimension=dims, comparator=cmp_fn)))
    set_mlir_data!(x, get_mlir_data(sorted))
    return x
end
# `sortperm` allocates an `Int` index array and fills it with `sortperm!`.
function Base.sortperm(x::AnyTracedRArray; alg=missing, kwargs...)
    perm = similar(x, Int)
    return sortperm!(perm, x; alg, kwargs...)
end
function Base.sortperm(x::AnyTracedRVector; alg=missing, kwargs...)
    perm = similar(x, Int)
    return sortperm!(perm, x; alg, dims=1, kwargs...)
end
# Write into `ix` the linear indices of `x` in sorted order along `dims`. `dims` may be
# omitted only for vectors.
function Base.sortperm!(
    ix::AnyTracedRArray{Int,N},
    x::AnyTracedRArray{<:Any,N};
    dims::Union{Integer,Nothing}=nothing,
    lt=isless,
    by=identity,
    rev::Bool=false,
    alg=missing,
    order=Base.Order.Forward,
) where {N}
    if dims === nothing
        @assert ndims(x) == 1
        dims = 1
    end
    @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
    sort_order = Base.ord(lt, by, rev, order)
    # The linear indices are sorted alongside the values; only the values are compared.
    cmp_fn = (a, b, _, _) -> __lt(sort_order, a, b)
    positions = @opcall constant(collect(LinearIndices(x)))
    _, perm = @opcall sort(
        materialize_traced_array(x), positions; dimension=dims, comparator=cmp_fn
    )
    set_mlir_data!(ix, get_mlir_data(perm))
    return ix
end
# Shared front end for the `partialsort*` overloads below. Returns the leading
# `maximum(k)` values of `x` in the requested order, together with their indices in `x`.
function _overloaded_partialsort(x::AnyTracedRVector, k, rev; kwargs...)
    if rev
        return overloaded_partialsort_descending(x, k; kwargs...)
    end
    return overloaded_partialsort_ascending(x, k; kwargs...)
end

# The `k`-th (or the `k` range of) smallest values of `x`, or the largest with `rev`.
function Base.partialsort(
    x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; rev=false, kwargs...
)
    values, _ = _overloaded_partialsort(x, k, rev; kwargs...)
    k isa Integer && return @allowscalar(values[k])
    return view(values, k)
end

# Like `partialsort`, but also writes the selected values into `x[k]`.
function Base.partialsort!(
    x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; rev=false, kwargs...
)
    values, _ = _overloaded_partialsort(x, k, rev; kwargs...)
    val = @allowscalar(values[k])
    @allowscalar setindex!(x, val, k)
    k isa Integer && return val
    return view(x, k)
end

# Indices into `x` of the `k`-th (or the `k` range of) elements in sorted order.
function Base.partialsortperm(
    x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; rev=false, kwargs...
)
    _, idxs = _overloaded_partialsort(x, k, rev; kwargs...)
    k isa Integer && return @allowscalar(idxs[k])
    return view(idxs, k)
end

# Like `partialsortperm`, but writes the selected indices into `ix[k]`. As in Base (and
# in `partialsort!` above), a range `k` returns a view into `ix` rather than a copy.
function Base.partialsortperm!(
    ix::AnyTracedRVector{Int},
    x::AnyTracedRVector,
    k::Union{Integer,OrdinalRange};
    rev=false,
    kwargs...,
)
    _, idxs = _overloaded_partialsort(x, k, rev; kwargs...)
    val = @allowscalar(idxs[k])
    @allowscalar setindex!(ix, val, k)
    k isa Integer && return @allowscalar(ix[k])
    return view(ix, k)
end
# Fully sort `x`, carrying its linear indices along, using the sort `comparator`
# ("does `a` come before `b`?"). Keeps the leading `maximum(k)` values and indices.
function _partialsort_via_sort(x::AnyTracedRVector, k, comparator)
    sorted_x, sorted_idxs = @opcall sort(
        materialize_traced_array(x),
        @opcall(constant(collect(LinearIndices(x))));
        dimension=1,
        comparator,
    )
    return (getindex(sorted_x, 1:maximum(k)), getindex(sorted_idxs, 1:maximum(k)))
end

# Opt-in path (`Reactant.LOWER_PARTIALSORT_TO_APPROX_TOP_K`) for float inputs: use
# `approx_top_k` with the given comparator and `init_val`.
function _partialsort_via_approx_top_k(x::AnyTracedRVector, k, comparator, init_val)
    result = @opcall approx_top_k(
        materialize_traced_array(x), maximum(k); comparator, dimension=1, init_val
    )
    return (getindex(result.values, 1:maximum(k)), getindex(result.indices, 1:maximum(k)))
end

# Leading `maximum(k)` largest values of `x` (under `by`/`lt`) and their indices.
function overloaded_partialsort_descending(
    x::AnyTracedRVector{T}, k::Union{Integer,OrdinalRange}; by=identity, lt=isless
) where {T}
    if lt !== isless || by !== identity
        # Descending: `a` comes first exactly when `b < a`. This is a strict ordering,
        # as sort comparators require (it is `false` for equal keys).
        return _partialsort_via_sort(x, k, (a, b, _, _) -> lt(by(b), by(a)))
    end
    if Reactant.LOWER_PARTIALSORT_TO_APPROX_TOP_K[] && T <: Reactant.ReactantFloat
        return _partialsort_via_approx_top_k(x, k, (a, b, _, _) -> a > b, typemin(T))
    end
    (; values, indices) = @opcall top_k(materialize_traced_array(x), maximum(k))
    return values, indices
end

# Leading `maximum(k)` smallest values of `x` (under `by`/`lt`) and their indices.
function overloaded_partialsort_ascending(
    x::AnyTracedRVector{T}, k::Union{Integer,OrdinalRange}; by=identity, lt=isless
) where {T}
    # Custom orderings need a real sort. So do unsigned inputs, because the negation
    # used for the `top_k` path below would wrap around.
    if lt !== isless || by !== identity || T <: Unsigned
        # Ascending: `a` comes first exactly when `a < b`.
        return _partialsort_via_sort(x, k, (a, b, _, _) -> lt(by(a), by(b)))
    end
    if Reactant.LOWER_PARTIALSORT_TO_APPROX_TOP_K[] && T <: Reactant.ReactantFloat
        return _partialsort_via_approx_top_k(x, k, (a, b, _, _) -> a < b, typemax(T))
    end
    # `top_k` selects the largest entries; the largest entries of -x are the smallest
    # entries of x.
    # NOTE(review): for signed integers `-typemin(T)` wraps to `typemin(T)`, so a
    # `typemin` element is ranked last here — confirm whether this edge case matters.
    (; values, indices) = @opcall top_k(
        @opcall(negate(materialize_traced_array(x))), maximum(k)
    )
    return @opcall(negate(values)), indices
end
# arg* functions
# `argmin(f, x)` / `argmax(f, x)` return the element of `x` at which `f` is extremal:
# locate the extremum of `f.(x)`, convert its linear index to a Cartesian one and read
# `x` there. The single-argument forms return the linear index reported by `findmin` /
# `findmax`.
for (argfn, findfn) in ((:argmin, :findmin), (:argmax, :findmax))
    @eval begin
        function Base.$argfn(f::F, x::AnyTracedRArray) where {F}
            best = $argfn(f.(x))
            idx = Reactant.TracedIndexing.scalar_index_to_cartesian(best, size(x))
            return @allowscalar x[idx...]
        end
        Base.$argfn(x::AnyTracedRArray; kwargs...) = $findfn(identity, x; kwargs...)[2]
    end
end
# find* functions
Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x)
Base.findlast(x::AnyTracedRArray) = findlast(identity, x)
# FIXME(#2236): we need to conditionally return `nothing` here if idx < 0
function Base.findfirst(f::Function, x::AnyTracedRArray)
    mask = materialize_traced_array(vec(f.(x)))
    first_idx = @opcall findfirst(mask)
    return TracedRNumber{Int}((), first_idx.mlir_data)
end
# FIXME(#2236): we need to conditionally return `nothing` here if idx < 0
function Base.findlast(f::Function, x::AnyTracedRArray)
    # Search the reversed mask from the front, then map that position back into `x`.
    reversed_mask = @opcall reverse(materialize_traced_array(vec(f.(x))); dimensions=[1])
    hit = @opcall findfirst(reversed_mask)
    pos_in_reversed = TracedRNumber{Int}((), hit.mlir_data)
    return length(x) + 1 - pos_in_reversed
end
# `findmin(x)` / `findmax(x)` without a function: vectors reduce along their only
# dimension, and other arrays forward the optional `dims` together with `identity`.
for findfn in (:findmin, :findmax)
    @eval begin
        Base.$findfn(x::AnyTracedRVector) = $findfn(identity, x; dims=1)
        function Base.$findfn(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
            return $findfn(identity, x; dims)
        end
    end
end
## To avoid scalar indexing and constructing an array of tuples, we return the linear index
## instead of the cartesian index

# Convert the positions `indices` returned by `top_k(...; dimension=dims)` (k = 1) into
# column-major linear indices of `x`: 1 + Σ_d (i_d - 1) * stride_d.
function _extremum_linear_indices(x::AnyTracedRArray, indices, dims::Integer)
    # Use the strides of `x`'s own shape. `strides(x)` would be wrong for non-contiguous
    # wrappers: a view such as `@view A[2, :]` reports its parent's strides, but its
    # linear indexing is dense column-major over `size(x)`.
    strds = Base.size_to_strides(1, size(x)...)
    # Off the reduced dimension, use the 0-based iota coordinate of each output slot.
    # On the reduced dimension, use the `top_k` position shifted down by 1 (the
    # accumulator below starts at 1).
    iotas = [@opcall(iota(Int64, [size(indices)...]; iota_dimension=i)) for i in 1:ndims(x)]
    iotas[dims] = @opcall subtract(indices, @opcall(fill(Int64(1), size(indices))))
    linear_indices = @opcall fill(Int64(1), size(indices))
    for d in eachindex(iotas)
        scaled = @opcall(multiply(iotas[d], @opcall(fill(Int64(strds[d]), size(iotas[d])))))
        linear_indices = @opcall add(linear_indices, scaled)
    end
    return linear_indices
end

# Order-reversing involution that turns `top_k` (largest first) into a smallest-first
# search. Machine integers use the bitwise complement, computed as `~0 - y` (that is,
# `-1 - y` for signed and `typemax - y` for unsigned types). Unlike negation it never
# wraps, so `0` of an unsigned type and `typemin` of a signed type are ordered
# correctly. Everything else (floats, `Bool`, ...) is negated.
function _reverse_order(y)
    ET = unwrapped_eltype(y)
    if ET <: Union{Signed,Unsigned}
        return @opcall subtract(@opcall(fill(~zero(ET), size(y))), y)
    end
    return @opcall negate(y)
end

# Minimum of `f.(x)` along `dims`, plus its linear index into `x`. Without `dims`,
# vectors reduce along dimension 1 and other arrays are flattened first. Vectors give a
# scalar pair; higher ranks give arrays with the reduced dimension kept.
function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
    if dims === nothing
        ndims(x) == 1 || return findmin(f, vec(x); dims=1)
        dims = 1
    end
    fx = _reverse_order(materialize_traced_array(f.(x)))
    (; values, indices) = @opcall top_k(fx, 1; dimension=dims)
    linear_indices = _extremum_linear_indices(x, indices, dims)
    values = _reverse_order(values)
    ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
    return (values, linear_indices)
end

# Maximum of `f.(x)` along `dims`, plus its linear index into `x`; otherwise behaves
# like `findmin` above.
function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
    if dims === nothing
        ndims(x) == 1 || return findmax(f, vec(x); dims=1)
        dims = 1
    end
    fx = materialize_traced_array(f.(x))
    (; values, indices) = @opcall top_k(fx, 1; dimension=dims)
    linear_indices = _extremum_linear_indices(x, indices, dims)
    ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
    return (values, linear_indices)
end
# `map(f, x, xs...)` for traced inputs. When every input is traced (or holds Reactant
# primitives and can be promoted), `f` is applied elementwise in one traced op. When no
# input can be, the map is unrolled. Mixing the two cases is rejected.
function overloaded_map(f, x::AbstractArray, xs::AbstractArray...)
    @assert allequal((axes(x), axes.(xs)...)) "Expected axes of all inputs to map to be \
                                              equal"
    all_inputs = (x, xs...)
    unroll_flags = map(all_inputs) do input
        !(input isa AnyTracedRArray || eltype(input) <: Reactant.ReactantPrimitive)
    end
    inputs = map(all_inputs) do input
        if input isa AnyTracedRArray
            Reactant.materialize_traced_array(input)
        elseif eltype(input) <: Reactant.ReactantPrimitive
            Reactant.promote_to(TracedRArray{eltype(input),ndims(input)}, input)
        else
            input
        end
    end
    @assert allequal(unroll_flags) "All inputs to `overloaded_map` must be \
                                   unrolled or none of them. Open an issue."
    if first(unroll_flags)
        length(inputs) == 1 && return unrolled_map(f, only(inputs))
        return unrolled_map(splat(f), zip(inputs...))
    end
    return TracedUtils.elem_apply(f, inputs...)
end
# In-place variant: compute the map, then copy the result into `y`.
function overloaded_map!(f, y::AnyTracedRArray, x::AbstractArray, xs::AbstractArray...)
    mapped = overloaded_map(f, x, xs...)
    copyto!(y, mapped)
    return y
end
# `mapslices` on traced arrays is lowered to `@opcall batch(f, A, dims)`; wrapped arrays
# are materialized first.
function Base.mapslices(f::F, A::AnyTracedRArray; dims) where {F}
    dense = materialize_traced_array(A)
    return mapslices(f, dense; dims)
end
function Base.mapslices(f::F, A::TracedRArray; dims) where {F}
    # Normalize `dims` to a vector of dimension indices.
    slice_dims = if dims isa Integer
        Int64[dims]
    elseif dims isa AbstractVector
        dims
    else
        collect(Int64, dims)
    end
    return @opcall batch(f, A, slice_dims)
end
# accumulate interface
## Taken from https://github.com/JuliaGPU/CUDA.jl/blob/a4a7af45f54f0e57f5912bb52db48e2d27cf7b4f/src/accumulate.jl#L201
# Cumulative `op` over a traced array. Without `dims`, multi-dimensional inputs are
# accumulated in linear order and reshaped back to their shape. The output eltype is
# derived by applying `op` to zeros of the input eltype (and of `init`'s eltype, when
# given). `init` is the only supported keyword.
function Base.accumulate(
    op, A::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing, kwargs...
)
    if dims === nothing && ndims(A) != 1
        # Forward `kwargs` (notably `init`) instead of silently dropping them; Base's
        # linear-order fallback honours `init` as well.
        return reshape(accumulate(op, A[:]; kwargs...), size(A)...)
    end

    nt = values(kwargs)
    # Base.promote_op was having issues
    if isempty(kwargs)
        zA = zero(unwrapped_eltype(A))
        out = similar(A, TracedRNumber{unwrapped_eltype(op(zA, zA))})
    elseif keys(nt) === (:init,)
        zA = zero(unwrapped_eltype(A))
        zI = zero(unwrapped_eltype(nt.init))
        out = similar(A, TracedRNumber{unwrapped_eltype(op(zA, zI))})
    else
        throw(
            ArgumentError(
                "accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))",
            ),
        )
    end

    return accumulate!(op, out, A; dims, kwargs...)
end
# Pairwise accumulation of traced vectors is a 1-d `accumulate!` of `B` into `A`.
function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
    result = accumulate!(op, A, B; dims=1)
    return result
end
@static if isdefined(Base, :_accumulate_promote_op)
function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
if init !== nothing
init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
end
return TracedRNumber{
unwrapped_eltype(
Base._accumulate_promote_op(op, Array{T,ndims(A)}(undef, size(A)); init)
),