@@ -12,6 +12,7 @@
 from beartype import beartype
 
 from rotary_embedding_torch import RotaryEmbedding
+from PoPE_pytorch import PoPE, flash_attn_with_pope
 
 from einops import rearrange, pack, unpack, reduce, repeat
 from einops.layers.torch import Rearrange
@@ -84,6 +85,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         rotary_embed = None,
+        pope_embed = None,
         flash = True,
         add_value_residual = False
     ):
@@ -93,6 +95,9 @@ def __init__(
         dim_inner = heads * dim_head
 
         self.rotary_embed = rotary_embed
+        self.pope_embed = pope_embed
+
+        assert not (exists(rotary_embed) and exists(pope_embed)), 'cannot have both rotary and pope embeddings'
 
         self.attend = Attend(flash = flash, dropout = dropout)
 
@@ -124,11 +129,14 @@ def forward(self, x, value_residual = None):
             assert exists(value_residual)
             v = v.lerp(mix, value_residual)
 
-        if exists(self.rotary_embed):
+        if exists(self.pope_embed):
+            out = flash_attn_with_pope(q, k, v, pos_emb = self.pope_embed(q.shape[-2]), softmax_scale = self.scale)
+        elif exists(self.rotary_embed):
             q = self.rotary_embed.rotate_queries_or_keys(q)
             k = self.rotary_embed.rotate_queries_or_keys(k)
-
-        out = self.attend(q, k, v)
+            out = self.attend(q, k, v)
+        else:
+            out = self.attend(q, k, v)
 
         gates = self.to_gates(x)
         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
@@ -217,6 +225,7 @@ def __init__(
         ff_mult = 4,
         norm_output = True,
         rotary_embed = None,
+        pope_embed = None,
         flash_attn = True,
         linear_attn = False,
         add_value_residual = False,
@@ -234,7 +243,7 @@ def __init__(
             if linear_attn:
                 attn = LinearAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn, add_value_residual = add_value_residual)
             else:
-                attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, add_value_residual = add_value_residual)
+                attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, pope_embed = pope_embed, flash = flash_attn, add_value_residual = add_value_residual)
 
             ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
 
@@ -387,7 +396,8 @@ def __init__(
         match_input_audio_length = False, # if True, pad output tensor to match length of input tensor
         add_value_residual = True,
         num_residual_streams = 4,
-        num_residual_fracs = 1
+        num_residual_fracs = 1,
+        use_pope = False
     ):
         super().__init__()
 
@@ -407,8 +417,14 @@ def __init__(
             num_residual_fracs = num_residual_fracs
         )
 
-        time_rotary_embed = RotaryEmbedding(dim = dim_head)
-        freq_rotary_embed = RotaryEmbedding(dim = dim_head)
+        if use_pope:
+            time_pope_embed = PoPE(dim = dim_head, heads = heads)
+            freq_pope_embed = PoPE(dim = dim_head, heads = heads)
+            time_rotary_embed = freq_rotary_embed = None
+        else:
+            time_rotary_embed = RotaryEmbedding(dim = dim_head)
+            freq_rotary_embed = RotaryEmbedding(dim = dim_head)
+            time_pope_embed = freq_pope_embed = None
 
         linear_flash_attn = default(linear_flash_attn, flash_attn)
 
@@ -421,8 +437,8 @@ def __init__(
 
             self.layers.append(nn.ModuleList([
                 Transformer(depth = linear_transformer_depth, linear_attn = True, flash_attn = linear_flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs) if linear_transformer_depth > 0 else None,
-                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs),
-                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs)
+                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, pope_embed = time_pope_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs),
+                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, pope_embed = freq_pope_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs)
             ]))
 
         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
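
Below is a minimal usage sketch of the new `use_pope` flag added by this diff. The class name, import path, and hyperparameters are assumptions (taken from this repository's typical README usage) and are only illustrative; the one thing exercised here is `use_pope = True`, which swaps the per-axis `RotaryEmbedding` for `PoPE` positional embeddings inside the time and frequency transformers.

```python
import torch
from bs_roformer import BSRoformer  # assumed import path / class name for this repo

# illustrative hyperparameters; only `use_pope = True` is the new option from this diff
model = BSRoformer(
    dim = 512,
    depth = 6,
    time_transformer_depth = 1,
    freq_transformer_depth = 1,
    use_pope = True  # build PoPE embeddings instead of per-axis RotaryEmbedding
)

audio = torch.randn(2, 352800)  # (batch, raw audio samples)
stems = model(audio)            # same forward interface as before, assumed unchanged
```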