@@ -949,7 +949,9 @@ def forward_core_with_concat(
949949 c_mod : torch .Tensor ,
950950 context : Optional [torch .Tensor ] = None ,
951951 control = None ,
952+ transformer_options = {},
952953 ) -> torch .Tensor :
954+ patches_replace = transformer_options .get ("patches_replace" , {})
953955 if self .register_length > 0 :
954956 context = torch .cat (
955957 (
@@ -961,14 +963,25 @@ def forward_core_with_concat(
961963
962964 # context is B, L', D
963965 # x is B, L, D
966+ blocks_replace = patches_replace .get ("dit" , {})
964967 blocks = len (self .joint_blocks )
965968 for i in range (blocks ):
966- context , x = self .joint_blocks [i ](
967- context ,
968- x ,
969- c = c_mod ,
970- use_checkpoint = self .use_checkpoint ,
971- )
969+ if ("double_block" , i ) in blocks_replace :
970+ def block_wrap (args ):
971+ out = {}
972+ out ["txt" ], out ["img" ] = self .joint_blocks [i ](args ["txt" ], args ["img" ], c = args ["vec" ])
973+ return out
974+
975+ out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "vec" : c_mod }, {"original_block" : block_wrap })
976+ context = out ["txt" ]
977+ x = out ["img" ]
978+ else :
979+ context , x = self .joint_blocks [i ](
980+ context ,
981+ x ,
982+ c = c_mod ,
983+ use_checkpoint = self .use_checkpoint ,
984+ )
972985 if control is not None :
973986 control_o = control .get ("output" )
974987 if i < len (control_o ):
@@ -986,6 +999,7 @@ def forward(
986999 y : Optional [torch .Tensor ] = None ,
9871000 context : Optional [torch .Tensor ] = None ,
9881001 control = None ,
1002+ transformer_options = {},
9891003 ) -> torch .Tensor :
9901004 """
9911005 Forward pass of DiT.
@@ -1007,7 +1021,7 @@ def forward(
10071021 if context is not None :
10081022 context = self .context_embedder (context )
10091023
1010- x = self .forward_core_with_concat (x , c , context , control )
1024+ x = self .forward_core_with_concat (x , c , context , control , transformer_options )
10111025
10121026 x = self .unpatchify (x , hw = hw ) # (N, out_channels, H, W)
10131027 return x [:,:,:hw [- 2 ],:hw [- 1 ]]
@@ -1021,7 +1035,8 @@ def forward(
10211035 context : Optional [torch .Tensor ] = None ,
10221036 y : Optional [torch .Tensor ] = None ,
10231037 control = None ,
1038+ transformer_options = {},
10241039 ** kwargs ,
10251040 ) -> torch .Tensor :
1026- return super ().forward (x , timesteps , context = context , y = y , control = control )
1041+ return super ().forward (x , timesteps , context = context , y = y , control = control , transformer_options = transformer_options )
10271042