We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
TiledFusedLogitsLoss
1 parent 3292e07, commit c4b1a8c (Copy full SHA for c4b1a8c)
1 file changed
deepspeed/runtime/sequence_parallel/ulysses_sp.py
@@ -1012,7 +1012,7 @@ def forward(
1012
with torch.enable_grad():
1013
args = (self, x_shard, y_shard)
1014
if mask is not None:
1015
- args.append(mask_shards[i])
+ args += (mask_shards[i], )
1016
output = fn(*args)
1017
output_shards.append(output)
1018
torch.autograd.backward(output, incoming_grad)
0 commit comments