Skip to content

Commit 22535d0

Browse files
Skip layer guidance now works on stable audio model.
1 parent: 8986151; commit: 22535d0

1 file changed

Lines changed: 14 additions & 3 deletions

File tree

comfy/ldm/audio/dit.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -612,7 +612,9 @@ def forward(
612612
return_info = False,
613613
**kwargs
614614
):
615+
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
615616
batch, seq, device = *x.shape[:2], x.device
617+
context = kwargs["context"]
616618

617619
info = {
618620
"hidden_states": [],
@@ -643,9 +645,19 @@ def forward(
643645
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
644646
x = x + self.pos_emb(x)
645647

648+
blocks_replace = patches_replace.get("dit", {})
646649
# Iterate over the transformer layers
647-
for layer in self.layers:
648-
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
650+
for i, layer in enumerate(self.layers):
651+
if ("double_block", i) in blocks_replace:
652+
def block_wrap(args):
653+
out = {}
654+
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
655+
return out
656+
657+
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
658+
x = out["img"]
659+
else:
660+
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
649661
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
650662

651663
if return_info:
@@ -874,7 +886,6 @@ def forward(
874886
mask=None,
875887
return_info=False,
876888
control=None,
877-
transformer_options={},
878889
**kwargs):
879890
return self._forward(
880891
x,

0 commit comments

Comments
 (0)