[Bug] AssertionError: Weight too large to fit in bucket during update_weights in vllm_rollout (Qwen3-VL-8B) #5952

@Silas-11

Description
System Info

  • verl version: 0.7.1
  • Model: Qwen3-VL-8B-Instruct
  • Rollout engine: vllm
  • Hardware: NPU (Ascend)
  • Script: examples/grpo_trainer/run_qwen3_vl-8b_npu.sh

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running the official example script for Qwen3-VL-8B on NPU:

bash examples/grpo_trainer/run_qwen3_vl-8b_npu.sh

Relevant ROLLOUT_CONFIG in the script (default, no update_weights_bucket_megabytes set):

ROLLOUT_CONFIG="
    actor_rollout_ref.rollout.name=${ENGINE} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.max_num_batched_tokens=20000 \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True"

Full stack trace:

File ".../verl/single_controller/base/decorator.py", line 433, in async_inner
return await func(*args, **kwargs)
File ".../verl/workers/fsdp_workers.py", line 1738, in update_weights
await self.rollout_mode()
File ".../verl/workers/fsdp_workers.py", line 847, in rollout_mode
await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
File ".../verl/workers/rollout/vllm_rollout/vllm_rollout.py", line 172, in update_weights
await sender.async_send_weights(weights)
File ".../verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py", line 136, in async_send_weights
assert offset + weight.nbytes <= self.bucket_size, (
AssertionError: Weight model.language_model.embed_tokens.weight (torch.Size([151936, 4096]), torch.float32) is too large to fit in the bucket. Please increase rollout.update_weights_bucket_megabytes (2048 MB).
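
A quick back-of-the-envelope check (using only the shape and dtype reported in the assertion message) confirms the tensor genuinely exceeds the default bucket:

```python
import math

# Shape and dtype taken from the assertion message above.
shape = (151936, 4096)
dtype_bytes = 4                          # torch.float32
nbytes = math.prod(shape) * dtype_bytes  # 2,489,319,424 bytes

bucket_bytes = 2048 * 1024 * 1024        # default update_weights_bucket_megabytes

print(f"tensor : {nbytes / 2**30:.2f} GiB")        # 2.32 GiB
print(f"bucket : {bucket_bytes / 2**30:.2f} GiB")  # 2.00 GiB
print(nbytes > bucket_bytes)                       # True

# Minimum bucket size that would fit this tensor: 2374 MB.
print(math.ceil(nbytes / 2**20))                   # 2374
```

So any `update_weights_bucket_megabytes` value of 2374 or above would fit this tensor; 4096 leaves comfortable headroom.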

Expected behavior

The official example script run_qwen3_vl-8b_npu.sh should run without error out of the box.

The default value of rollout.update_weights_bucket_megabytes (2048 MB) is too small to accommodate
the embed_tokens.weight tensor of Qwen3-VL-8B, which has shape [151936, 4096] in float32
(151936 × 4096 × 4 bytes ≈ 2.32 GiB, vs. a 2 GiB bucket). The script should either:

  1. Set a sufficiently large default for update_weights_bucket_megabytes in the example script
    (e.g. 4096 MB), or
  2. Dynamically size the bucket based on the largest weight tensor in the model.
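
Option 2 could be sketched roughly as follows. This is a hypothetical helper, not an API verl actually exposes; `bucket_megabytes_for` and `headroom_mb` are names invented here for illustration:

```python
import math

def bucket_megabytes_for(named_params, headroom_mb=64):
    """Pick an update_weights_bucket_megabytes value from the largest tensor.

    `named_params` yields (name, tensor) pairs, e.g. model.named_parameters().
    Hypothetical helper for illustration; verl does not provide this function.
    """
    largest_nbytes = max(t.nbytes for _, t in named_params)
    # Round up to whole MiB and add a small safety margin.
    return math.ceil(largest_nbytes / 2**20) + headroom_mb
```

The computed value could then be passed through the usual Hydra-style override, e.g. `actor_rollout_ref.rollout.update_weights_bucket_megabytes=<value>`.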

Workaround: manually adding the following line to ROLLOUT_CONFIG resolves the issue:
actor_rollout_ref.rollout.update_weights_bucket_megabytes=4096
