@@ -612,7 +612,9 @@ def forward(
612612 return_info = False ,
613613 ** kwargs
614614 ):
615+ patches_replace = kwargs .get ("transformer_options" , {}).get ("patches_replace" , {})
615616 batch , seq , device = * x .shape [:2 ], x .device
617+ context = kwargs ["context" ]
616618
617619 info = {
618620 "hidden_states" : [],
@@ -643,9 +645,19 @@ def forward(
643645 if self .use_sinusoidal_emb or self .use_abs_pos_emb :
644646 x = x + self .pos_emb (x )
645647
648+ blocks_replace = patches_replace .get ("dit" , {})
646649 # Iterate over the transformer layers
647- for layer in self .layers :
648- x = layer (x , rotary_pos_emb = rotary_pos_emb , global_cond = global_cond , ** kwargs )
650+ for i , layer in enumerate (self .layers ):
651+ if ("double_block" , i ) in blocks_replace :
652+ def block_wrap (args ):
653+ out = {}
654+ out ["img" ] = layer (args ["img" ], rotary_pos_emb = args ["pe" ], global_cond = args ["vec" ], context = args ["txt" ])
655+ return out
656+
657+ out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "vec" : global_cond , "pe" : rotary_pos_emb }, {"original_block" : block_wrap })
658+ x = out ["img" ]
659+ else :
660+ x = layer (x , rotary_pos_emb = rotary_pos_emb , global_cond = global_cond , context = context )
649661 # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
650662
651663 if return_info :
@@ -874,7 +886,6 @@ def forward(
874886 mask = None ,
875887 return_info = False ,
876888 control = None ,
877- transformer_options = {},
878889 ** kwargs ):
879890 return self ._forward (
880891 x ,
0 commit comments