@@ -250,6 +250,12 @@ def fp8_linear(self, input):
250250 if dtype not in [torch .float8_e4m3fn ]:
251251 return None
252252
253+ tensor_2d = False
254+ if len (input .shape ) == 2 :
255+ tensor_2d = True
256+ input = input .unsqueeze (1 )
257+
258+
253259 if len (input .shape ) == 3 :
254260 w , bias = cast_bias_weight (self , input , dtype = dtype , bias_dtype = input .dtype )
255261 w = w .t ()
@@ -272,7 +278,11 @@ def fp8_linear(self, input):
272278 if isinstance (o , tuple ):
273279 o = o [0 ]
274280
281+ if tensor_2d :
282+ return o .reshape (input .shape [0 ], - 1 )
283+
275284 return o .reshape ((- 1 , input .shape [1 ], self .weight .shape [0 ]))
285+
276286 return None
277287
278288class fp8_ops (manual_cast ):
@@ -316,7 +326,11 @@ def forward_comfy_cast_weights(self, input):
316326 return out
317327
318328 weight , bias = cast_bias_weight (self , input )
319- return torch .nn .functional .linear (input , weight * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), bias )
329+
330+ if weight .numel () < input .numel (): #TODO: optimize
331+ return torch .nn .functional .linear (input , weight * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), bias )
332+ else :
333+ return torch .nn .functional .linear (input * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), weight , bias )
320334
321335 def convert_weight (self , weight , inplace = False , ** kwargs ):
322336 if inplace :