Commit a5404bf

Merge pull request #52 from lucidrains/pope

pope

2 parents d63ebb6 + de1926e

5 files changed: 124 additions & 20 deletions
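In short, this merge threads an optional PoPE positional embedding (from the PoPE_pytorch package) through both Roformer variants as a drop-in alternative to rotary embeddings, behind a new use_pope flag, and adds a CI workflow plus tests. A minimal usage sketch, mirroring the new tests further down:

    import torch
    from bs_roformer import BSRoformer

    # use_pope = True builds PoPE positional embeddings for the time and freq
    # transformers in place of rotary embeddings (the two are mutually exclusive)
    model = BSRoformer(
        dim = 512,
        depth = 1,
        time_transformer_depth = 1,
        freq_transformer_depth = 1,
        use_pope = True
    )

    audio = torch.randn(1, 1, 44100)  # (batch, channels, samples)
    out = model(audio)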

File tree:
- .github/workflows/test.yml
- bs_roformer/bs_roformer.py
- bs_roformer/mel_band_roformer.py
- setup.py
- tests/test_roformer.py

.github/workflows/test.yml

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+name: test
+on: [push, pull_request]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+
+    steps:
+    - uses: actions/checkout@v4
+
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v5
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install .
+        pip install pytest
+
+    - name: Test with pytest
+      run: |
+        pytest tests/test_roformer.py

bs_roformer/bs_roformer.py

Lines changed: 26 additions & 10 deletions
@@ -12,6 +12,7 @@
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding
+from PoPE_pytorch import PoPE, flash_attn_with_pope

from einops import rearrange, pack, unpack

@@ -73,6 +74,7 @@ def __init__(
        dim_head = 64,
        dropout = 0.,
        rotary_embed = None,
+        pope_embed = None,
        flash = True,
        learned_value_residual_mix = False
    ):
@@ -82,6 +84,9 @@
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed
+        self.pope_embed = pope_embed
+
+        assert not (exists(rotary_embed) and exists(pope_embed)), 'cannot have both rotary and pope embeddings'

        self.attend = Attend(flash = flash, dropout = dropout)

@@ -111,11 +116,14 @@ def forward(self, x, value_residual = None):
            assert exists(value_residual)
            v = v.lerp(value_residual, mix)

-        if exists(self.rotary_embed):
+        if exists(self.pope_embed):
+            out = flash_attn_with_pope(q, k, v, pos_emb = self.pope_embed(q.shape[-2]), softmax_scale = self.scale)
+        elif exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)
-
-        out = self.attend(q, k, v)
+            out = self.attend(q, k, v)
+        else:
+            out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
@@ -137,6 +145,7 @@ def __init__(
        ff_mult = 4,
        norm_output = True,
        rotary_embed = None,
+        pope_embed = None,
        flash_attn = True,
        add_value_residual = False,
        num_residual_streams = 1,
@@ -150,7 +159,7 @@ def __init__(

        for _ in range(depth):
            self.layers.append(ModuleList([
-                init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, learned_value_residual_mix = add_value_residual)),
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, pope_embed = pope_embed, flash = flash_attn, learned_value_residual_mix = add_value_residual)),
                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout))
            ]))

@@ -306,7 +315,8 @@ def __init__(
        multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size = 147,
        multi_stft_normalized = False,
-        multi_stft_window_fn: Callable = torch.hann_window
+        multi_stft_window_fn: Callable = torch.hann_window,
+        use_pope = False
    ):
        super().__init__()

@@ -328,18 +338,24 @@ def __init__(
            num_residual_streams = num_residual_streams,
            num_residual_fracs = num_residual_fracs,
            mc_hyper_conn_sinkhorn_iters = mc_hyper_conn_sinkhorn_iters,
-            norm_output = False,
+            norm_output = False
        )

-        time_rotary_embed = RotaryEmbedding(dim = dim_head)
-        freq_rotary_embed = RotaryEmbedding(dim = dim_head)
+        if use_pope:
+            time_pope_embed = PoPE(dim = dim_head, heads = heads)
+            freq_pope_embed = PoPE(dim = dim_head, heads = heads)
+            time_rotary_embed = freq_rotary_embed = None
+        else:
+            time_rotary_embed = RotaryEmbedding(dim = dim_head)
+            freq_rotary_embed = RotaryEmbedding(dim = dim_head)
+            time_pope_embed = freq_pope_embed = None

        for layer_index in range(depth):
            is_first = layer_index == 0

            self.layers.append(nn.ModuleList([
-                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, add_value_residual = not is_first, **transformer_kwargs),
-                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, add_value_residual = not is_first, **transformer_kwargs)
+                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, pope_embed = time_pope_embed, add_value_residual = not is_first, **transformer_kwargs),
+                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, pope_embed = freq_pope_embed, add_value_residual = not is_first, **transformer_kwargs)
            ]))

        self.final_norm = RMSNorm(dim)
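As a quick illustration of the new code path, the PoPE pieces can be exercised on their own. A minimal sketch (not part of the commit), assuming (batch, heads, seq, dim_head) tensors and that self.scale holds the usual dim_head ** -0.5:

    import torch
    from PoPE_pytorch import PoPE, flash_attn_with_pope

    heads, dim_head = 8, 64

    # one PoPE module per axis, exactly as the constructor above builds them
    pope = PoPE(dim = dim_head, heads = heads)

    q = torch.randn(2, heads, 1024, dim_head)
    k = torch.randn(2, heads, 1024, dim_head)
    v = torch.randn(2, heads, 1024, dim_head)

    # mirror of the new branch in Attention.forward
    pos_emb = pope(q.shape[-2])  # positional embedding for the sequence length
    out = flash_attn_with_pope(q, k, v, pos_emb = pos_emb, softmax_scale = dim_head ** -0.5)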

bs_roformer/mel_band_roformer.py

Lines changed: 25 additions & 9 deletions
@@ -12,6 +12,7 @@
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding
+from PoPE_pytorch import PoPE, flash_attn_with_pope

from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange
@@ -84,6 +85,7 @@ def __init__(
        dim_head = 64,
        dropout = 0.,
        rotary_embed = None,
+        pope_embed = None,
        flash = True,
        add_value_residual = False
    ):
@@ -93,6 +95,9 @@
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed
+        self.pope_embed = pope_embed
+
+        assert not (exists(rotary_embed) and exists(pope_embed)), 'cannot have both rotary and pope embeddings'

        self.attend = Attend(flash = flash, dropout = dropout)

@@ -124,11 +129,14 @@ def forward(self, x, value_residual = None):
            assert exists(value_residual)
            v = v.lerp(mix, value_residual)

-        if exists(self.rotary_embed):
+        if exists(self.pope_embed):
+            out = flash_attn_with_pope(q, k, v, pos_emb = self.pope_embed(q.shape[-2]), softmax_scale = self.scale)
+        elif exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)
-
-        out = self.attend(q, k, v)
+            out = self.attend(q, k, v)
+        else:
+            out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
@@ -217,6 +225,7 @@ def __init__(
        ff_mult = 4,
        norm_output = True,
        rotary_embed = None,
+        pope_embed = None,
        flash_attn = True,
        linear_attn = False,
        add_value_residual = False,
@@ -234,7 +243,7 @@ def __init__(
            if linear_attn:
                attn = LinearAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn, add_value_residual = add_value_residual)
            else:
-                attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, add_value_residual = add_value_residual)
+                attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, pope_embed = pope_embed, flash = flash_attn, add_value_residual = add_value_residual)

            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

@@ -387,7 +396,8 @@ def __init__(
        match_input_audio_length = False, # if True, pad output tensor to match length of input tensor
        add_value_residual = True,
        num_residual_streams = 4,
-        num_residual_fracs = 1
+        num_residual_fracs = 1,
+        use_pope = False
    ):
        super().__init__()

@@ -407,8 +417,14 @@ def __init__(
            num_residual_fracs = num_residual_fracs
        )

-        time_rotary_embed = RotaryEmbedding(dim = dim_head)
-        freq_rotary_embed = RotaryEmbedding(dim = dim_head)
+        if use_pope:
+            time_pope_embed = PoPE(dim = dim_head, heads = heads)
+            freq_pope_embed = PoPE(dim = dim_head, heads = heads)
+            time_rotary_embed = freq_rotary_embed = None
+        else:
+            time_rotary_embed = RotaryEmbedding(dim = dim_head)
+            freq_rotary_embed = RotaryEmbedding(dim = dim_head)
+            time_pope_embed = freq_pope_embed = None

        linear_flash_attn = default(linear_flash_attn, flash_attn)

@@ -421,8 +437,8 @@

            self.layers.append(nn.ModuleList([
                Transformer(depth = linear_transformer_depth, linear_attn = True, flash_attn = linear_flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs) if linear_transformer_depth > 0 else None,
-                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs),
-                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs)
+                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, pope_embed = time_pope_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs),
+                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, pope_embed = freq_pope_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs)
            ]))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
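One behavioral consequence worth noting: the new assert makes rotary and PoPE embeddings mutually exclusive per attention layer. A small sketch of that guard (not part of the commit; the bs_roformer.mel_band_roformer import path is assumed from the file layout above):

    import pytest
    from rotary_embedding_torch import RotaryEmbedding
    from PoPE_pytorch import PoPE
    from bs_roformer.mel_band_roformer import Attention  # assumed import path

    # constructing an Attention with both positional schemes should trip the assert
    with pytest.raises(AssertionError):
        Attention(
            dim = 512,
            rotary_embed = RotaryEmbedding(dim = 64),
            pope_embed = PoPE(dim = 64, heads = 8)
        )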

setup.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@
setup(
  name = 'BS-RoFormer',
  packages = find_packages(exclude=[]),
-  version = '1.0.6',
+  version = '1.1.0',
  license='MIT',
  description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
  author = 'Phil Wang',
@@ -22,6 +22,7 @@
    'einops>=0.8.0',
    'hyper-connections>=0.4.4',
    'librosa',
+    'PoPE-pytorch>=0.0.15',
    'rotary-embedding-torch>=0.3.6',
    'torch>=2.0',
  ],

tests/test_roformer.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+import pytest
+from bs_roformer import BSRoformer, MelBandRoformer
+from PoPE_pytorch import PoPE
+
+@pytest.mark.parametrize('use_pope', [True, False])
+def test_bs_roformer(use_pope):
+    model = BSRoformer(
+        dim = 512,
+        depth = 1,
+        time_transformer_depth = 1,
+        freq_transformer_depth = 1,
+        use_pope = use_pope
+    )
+
+    dummy_audio = torch.randn(1, 1, 44100)
+    out = model(dummy_audio)
+
+    assert out.shape[0] == dummy_audio.shape[0]
+    assert abs(out.shape[-1] - dummy_audio.shape[-1]) < 1024
+
+    # verify pope presence
+    has_pope = any(isinstance(m, PoPE) for m in model.modules())
+    assert has_pope == use_pope
+
+@pytest.mark.parametrize('use_pope', [True, False])
+def test_mel_band_roformer(use_pope):
+    model = MelBandRoformer(
+        dim = 512,
+        depth = 1,
+        time_transformer_depth = 1,
+        freq_transformer_depth = 1,
+        use_pope = use_pope
+    )
+
+    dummy_audio = torch.randn(1, 1, 44100)
+    out = model(dummy_audio)
+
+    assert out.shape[0] == dummy_audio.shape[0]
+    assert abs(out.shape[-1] - dummy_audio.shape[-1]) < 1024
+
+    # verify pope presence
+    has_pope = any(isinstance(m, PoPE) for m in model.modules())
+    assert has_pope == use_pope