Skip to content

Commit 8ce2a10

Browse files
Optimizations to --fast and scaled fp8.
1 parent f82314f commit 8ce2a10

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

comfy/ops.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,12 @@ def fp8_linear(self, input):
250250
if dtype not in [torch.float8_e4m3fn]:
251251
return None
252252

253+
tensor_2d = False
254+
if len(input.shape) == 2:
255+
tensor_2d = True
256+
input = input.unsqueeze(1)
257+
258+
253259
if len(input.shape) == 3:
254260
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
255261
w = w.t()
@@ -272,7 +278,11 @@ def fp8_linear(self, input):
272278
if isinstance(o, tuple):
273279
o = o[0]
274280

281+
if tensor_2d:
282+
return o.reshape(input.shape[0], -1)
283+
275284
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
285+
276286
return None
277287

278288
class fp8_ops(manual_cast):
@@ -316,7 +326,11 @@ def forward_comfy_cast_weights(self, input):
316326
return out
317327

318328
weight, bias = cast_bias_weight(self, input)
319-
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
329+
330+
if weight.numel() < input.numel(): #TODO: optimize
331+
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
332+
else:
333+
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
320334

321335
def convert_weight(self, weight, inplace=False, **kwargs):
322336
if inplace:

0 commit comments

Comments (0)