Skip to content

Commit 475254f

Browse files
committed
avoid allocating new array when reduce / repeat is a no-op
1 parent 53a17f6 commit 475254f

5 files changed

Lines changed: 27 additions & 7 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Einops"
22
uuid = "e3ce28c8-8bfb-4704-8add-e3e7f14b55c9"
33
authors = ["Anton Oresten <antonoresten@gmail.com> and contributors"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,4 +14,4 @@ ChainRulesCore = "1.7"
1414
EllipsisNotation = "1"
1515
OMEinsum = "0.8"
1616
TupleTools = "1"
17-
julia = "1.9"
17+
julia = "1.9"

src/reduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ function Base.reduce(f::Function, x::AbstractArray, (left, right)::ArrowPattern;
4444
allunique(extract(Symbol, right)) || throw(ArgumentError("Right names $(right) are not unique"))
4545
left_names, right_names = extract(Symbol, left), extract(Symbol, right)
4646
expanded = reshape_in(x, left; context...)
47-
reduced_dims::NTuple{length(left_names)-length(right_names),Int}, permutation = @ignore_derivatives begin
47+
dims::NTuple{length(left_names)-length(right_names),Int}, permutation = @ignore_derivatives begin
4848
isempty(setdiff(right_names, left_names)) || throw(ArgumentError("All dimension names on right side of pattern must be present on left side: $(setdiff(right_names, left_names))"))
4949
reduced_dim_names = setdiff(left_names, right_names)
5050
reduced_dims = ntuple(i -> findfirst(isequal(reduced_dim_names[i]), left_names)::Int, length(left_names) - length(right_names))
5151
reduced_left_names = intersect(left_names, right_names)
5252
permutation = permutation_mapping(ntuple(i -> reduced_left_names[i], length(right_names)), right_names)
5353
reduced_dims, permutation
5454
end
55-
reduced = f(expanded, dims=reduced_dims)
56-
dropped = dropdims(reduced, dims=reduced_dims)
55+
reduced = isempty(dims) ? expanded : f(expanded; dims)
56+
dropped = dropdims(reduced; dims)
5757
permuted = _permutedims(dropped, permutation)
5858
collapsed = reshape_out(permuted, right)
5959
return collapsed

src/repeat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function Base.repeat(x::AbstractArray, (left, right)::ArrowPattern; context...)
5151
expanded = reshape_in(x, left; context_info...)
5252
permuted = _permutedims(expanded, permutation)
5353
reshaped = reshape(permuted, prerepeat_shape(size(expanded), left_names, right_names))
54-
repeated = repeat(reshaped, repeats...)
54+
repeated = all(isone, repeats) ? reshaped : repeat(reshaped, repeats...)
5555
collapsed = reshape_out(repeated, right)
5656
return collapsed
5757
end

test/reduce.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,14 @@ using Test, Statistics
139139
# Can't drop dimensions in rearrange - use reduce instead
140140
@test reduce(sum, x, (:a, :b, :c, :d, :e) --> (:b, :d)) |> size == (100, 50)
141141
end
142+
143+
@testset "no-op allocation optimization" begin
144+
x = rand(2, 3, 4)
145+
146+
y = reduce(sum, x, (:a, :b, :c) --> (:a, :b, :c))
147+
@test pointer(x) == pointer(y)
148+
149+
y = reduce(sum, x, (:a, :b, :c) --> ((:a, :b), :c))
150+
@test pointer(x) == pointer(y)
151+
end
142152
end

test/repeat.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Einops
2-
using Test
2+
using Test, Statistics
33

44
@testset "Repeat Operations" begin
55
@testset "basic repetitions" begin
@@ -91,4 +91,14 @@ using Test
9191
x = ["a" "b"; "c" "d"]
9292
@test repeat(x, (:a, :b) --> (:a, :b, :c), c=2) == cat(x, x, dims=3)
9393
end
94+
95+
@testset "no-op allocation optimization" begin
96+
x = rand(2, 3, 4)
97+
98+
y = repeat(x, (:a, :b, :c) --> (:a, :b, :c, 1))
99+
@test pointer(x) == pointer(y)
100+
101+
y = repeat(x, (:a, :b, :c) --> ((:a, :b), :c, 1))
102+
@test pointer(x) == pointer(y)
103+
end
94104
end

0 commit comments

Comments
 (0)