diff --git a/CHANGELOG.md b/CHANGELOG.md index 4414d5c65..43f5b468d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +## [0.3.16] - 2026-05-12 + +### Changed +- Extend supported dependency ranges to allow `torch<2.12.0`, `peft<0.20.0`, and `pillow<12.3.0`. + +## [0.3.15] - 2026-03-31 + ### Added - Add ColQwen3.5 and BiQwen3.5 support (model + processor). Pretrained checkpoint: [athrael-soju/colqwen3.5-4.5B-v3](https://huggingface.co/athrael-soju/colqwen3.5-4.5B-v3). diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index e4845bf79..7d6e1232d 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -6,4 +6,11 @@ from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor from .qwen3_5 import BiQwen3_5, BiQwen3_5Processor, ColQwen3_5, ColQwen3_5Processor -from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor +from .qwen_omni import ( + BiQwen3Omni, + BiQwen3OmniProcessor, + ColQwen2_5Omni, + ColQwen2_5OmniProcessor, + ColQwen3Omni, + ColQwen3OmniProcessor, +) diff --git a/colpali_engine/models/qwen_omni/__init__.py b/colpali_engine/models/qwen_omni/__init__.py index 7dd081290..fba4e7887 100644 --- a/colpali_engine/models/qwen_omni/__init__.py +++ b/colpali_engine/models/qwen_omni/__init__.py @@ -1 +1,2 @@ +from .colqwen3_omni import BiQwen3Omni, BiQwen3OmniProcessor, ColQwen3Omni, ColQwen3OmniProcessor from .colqwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen_omni/colqwen3_omni/__init__.py b/colpali_engine/models/qwen_omni/colqwen3_omni/__init__.py new file mode 100644 index 000000000..b2653a1c2 --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen3_omni/__init__.py @@ -0,0 +1,10 @@ +from .configuration_bidirlm_omni import ( + 
class BidirLMOmniAudioConfig(PretrainedConfig):
    """Hyper-parameters of the BidirLM-Omni audio encoder.

    Every constructor argument is stored on the instance under its own name.
    ``num_hidden_layers`` is additionally set to the value of
    ``encoder_layers``. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "bidirlm_omni_audio"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        d_model=1280,
        dropout=0,
        attention_dropout=0,
        activation_function="gelu",
        activation_dropout=0,
        scale_embedding=False,
        initializer_range=0.02,
        max_source_positions=1500,
        n_window=100,
        output_dim=3584,
        n_window_infer=400,
        conv_chunksize=500,
        downsample_hidden_size=480,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Input features, convolutional down-sampling and positions.
        self.num_mel_bins = num_mel_bins
        self.downsample_hidden_size = downsample_hidden_size
        self.conv_chunksize = conv_chunksize
        self.max_source_positions = max_source_positions
        self.scale_embedding = scale_embedding

        # Transformer stack.
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.num_hidden_layers = encoder_layers  # same value, second name
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.activation_function = activation_function
        self.initializer_range = initializer_range

        # Dropout rates.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout

        # Attention windowing and width of the final projection.
        self.n_window = n_window
        self.n_window_infer = n_window_infer
        self.output_dim = output_dim
class BidirLMOmniVisionConfig(PretrainedConfig):
    """Hyper-parameters of the BidirLM-Omni vision encoder.

    Constructor arguments are stored on the instance under their own names.
    ``deepstack_visual_indexes`` falls back to ``[8, 16, 24]`` when it is
    ``None``; a fresh list is created per instance so the default is never
    shared. Extra keyword arguments go to ``PretrainedConfig.__init__``.
    """

    model_type = "bidirlm_omni_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=None,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Patch embedding geometry.
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.spatial_merge_size = spatial_merge_size
        self.num_position_embeddings = num_position_embeddings

        # Transformer stack.
        self.depth = depth
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range

        # Output projection and the block indexes that get a deepstack merger.
        self.out_hidden_size = out_hidden_size
        self.deepstack_visual_indexes = (
            [8, 16, 24] if deepstack_visual_indexes is None else deepstack_visual_indexes
        )
class BidirLMOmniTextConfig(PretrainedConfig):
    """Hyper-parameters of the shared BidirLM-Omni text encoder.

    All model attributes are assigned *before* ``PretrainedConfig.__init__``
    runs. ``num_key_value_heads`` falls back to ``num_attention_heads`` when
    ``None`` is passed, and ``is_causal`` is always set to ``False``.
    """

    model_type = "bidirlm_omni_text"
    base_config_key = "text_config"
    # mrope_section / mrope_interleaved are model-specific rope_scaling keys.
    # Listing them here stops validate_rope() from warning about them.
    ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"}

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clf_pooling="late",
        **kwargs,
    ):
        # Vocabulary and sequence limits.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings

        # Layer sizes.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers

        # Attention heads; None for the KV count means "same as query heads".
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.head_dim = head_dim

        # Activation, initialisation and normalisation.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        # Rotary embedding parameters.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Attention options and pooling mode.
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clf_pooling = clf_pooling
        self.is_causal = False

        # Per the original author's note: on transformers v5 the parent
        # constructor converts and validates the rope parameters itself
        # (honouring ignore_keys_at_rope_validation above), so the legacy
        # rope_config_validation() call is intentionally not made here.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class BidirLMOmniConfig(PretrainedConfig):
    """Top-level BidirLM-Omni config bundling audio, vision and text configs.

    Each sub-config argument may be ``None`` (use the sub-config's defaults),
    a ``dict`` (expanded as keyword arguments) or an already-built config
    object (stored as given). The special-token ids, ``text_weights_source``
    and ``clf_pooling`` are stored as passed.
    """

    model_type = "bidirlm_omni"
    ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"}
    sub_configs = {
        "audio_config": BidirLMOmniAudioConfig,
        "vision_config": BidirLMOmniVisionConfig,
        "text_config": BidirLMOmniTextConfig,
    }

    @staticmethod
    def _coerce_sub_config(value, config_cls):
        """Return ``value`` as a ``config_cls`` instance (dict / None / object)."""
        if isinstance(value, dict):
            return config_cls(**value)
        if value is None:
            return config_cls()
        return value

    def __init__(
        self,
        text_config=None,
        audio_config=None,
        vision_config=None,
        # Audio special tokens
        audio_token_id=151676,
        audio_start_token_id=151669,
        audio_end_token_id=151670,
        # Vision special tokens
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=True,
        text_weights_source="visual",
        # Classification / fine-tuning
        num_labels=1,
        problem_type=None,
        clf_pooling="late",
        **kwargs,
    ):
        # Sub-configs, built in the same order as before: audio, vision, text.
        self.audio_config = self._coerce_sub_config(audio_config, BidirLMOmniAudioConfig)
        self.vision_config = self._coerce_sub_config(vision_config, BidirLMOmniVisionConfig)
        self.text_config = self._coerce_sub_config(text_config, BidirLMOmniTextConfig)

        # Special-token ids.
        self.audio_token_id = audio_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id

        self.text_weights_source = text_weights_source
        self.clf_pooling = clf_pooling

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # num_labels / problem_type are assigned only now: PretrainedConfig's
        # num_labels is a property backed by id2label, which the parent
        # constructor initialises.
        self.num_labels = num_labels
        self.problem_type = problem_type
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split as ``[a, b]`` the result is ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head dimension.

    Input is ``(batch, num_key_value_heads, seq_len, head_dim)``; output is
    ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)`` with the
    copies of one head adjacent. With ``n_rep == 1`` the input tensor itself
    is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Scaled dot-product attention implemented with plain matmuls.

    Args:
        module: Provides ``num_key_value_groups`` (KV head repetition factor)
            and ``training`` (enables dropout).
        query, key, value: ``(batch, heads, seq_len, head_dim)`` tensors.
        attention_mask: Optional additive mask; only its first
            ``key_len`` entries along the last axis are used.
        scaling: Multiplier applied to the raw scores.
        dropout: Dropout probability on the attention weights.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` is
        ``(batch, seq_len, heads, head_dim)`` and contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos`` and ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so
    they broadcast against ``q`` and ``k``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k


def _get_feat_extract_output_lengths(input_lengths):
    """Return the length left after three stride-2 convolutions.

    Each Conv2d (kernel=3, stride=2, padding=1) maps a length ``L`` to
    ``floor((L - 1) / 2) + 1``; the formula is applied three times. Works on
    Python ints and on integer tensors alike.
    """
    length = input_lengths
    for _ in range(3):
        length = (length - 1) // 2 + 1
    return length
@use_kernel_forward_from_hub("RMSNorm")
class BidirLMOmniRMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-channel scale.

    The normalisation is computed in float32 and cast back to the input dtype
    before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise ``hidden_states`` over its last dimension."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class BidirLMOmniAudioAttention(nn.Module):
    """Non-causal multi-head self-attention over a packed audio sequence.

    The input is a 2-D ``(seq_length, embed_dim)`` tensor. Segment boundaries
    are described by ``cu_seqlens`` and forwarded to the selected attention
    backend as ``cu_seq_lens_q`` / ``cu_seq_lens_k``.
    """

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.num_key_value_groups = 1
        self.config = config
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_causal = False

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got embed_dim={self.embed_dim}, num_heads={self.num_heads})."
            )

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return a tensor of the same shape.

        Args:
            hidden_states: ``(seq_length, embed_dim)`` packed sequence.
            cu_seqlens: Cumulative segment boundaries. ``None`` is treated as
                a single segment covering the whole sequence (previously the
                advertised default crashed when sliced).
            attention_mask: Optional additive mask passed to the backend.
            **kwargs: Forwarded to the attention backend.

        Returns:
            ``(seq_length, embed_dim)`` tensor after the output projection.
        """
        seq_length, _ = hidden_states.size()

        # (seq, embed) -> (1, heads, seq, head_dim), the layout backends expect.
        query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)

        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        if cu_seqlens is None:
            # No boundaries supplied: the longest segment is the sequence itself.
            max_seqlen = seq_length
        else:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        return self.out_proj(attn_output)
class BidirLMOmniAudioEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: self-attention then a two-layer MLP.

    Both sub-blocks are wrapped in a residual connection. ``dropout`` and
    ``activation_dropout`` are stored from the config but are not applied in
    ``forward``.
    """

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BidirLMOmniAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """Run the layer and return a 1-tuple holding the new hidden states."""
        # Attention sub-block.
        attended = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attended

        # Feed-forward sub-block.
        expanded = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        hidden_states = hidden_states + self.fc2(expanded)

        # Keep float16 activations finite by clamping just below the dtype max.
        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = hidden_states.clamp(min=-limit, max=limit)

        return (hidden_states,)
class SinusoidsPositionEmbedding(nn.Module):
    """Fixed sinusoidal position table, always produced in float32.

    A ``(length, channels)`` table is registered as the non-persistent buffer
    ``positional_embedding``. ``forward`` uses that buffer only to learn which
    device the module lives on: the values are recomputed in float32 on every
    call, because the stored buffer may have been cast to a lower-precision
    dtype together with the rest of the model.
    """

    def __init__(self, length, channels, max_timescale=10000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        # Scalars needed to rebuild the table later.
        self.length = length
        self.channels = channels
        self.max_timescale = max_timescale
        # device=None builds the table on the current default device.
        self.register_buffer(
            "positional_embedding",
            self._recompute(length, device=None),
            persistent=False,
        )

    def _recompute(self, seqlen: int, device) -> torch.Tensor:
        """Build a float32 ``(seqlen, channels)`` table: sines then cosines."""
        half = self.channels // 2
        log_step = np.log(self.max_timescale) / (half - 1)
        inv_timescales = torch.exp(-log_step * torch.arange(half, dtype=torch.float32, device=device))
        positions = torch.arange(seqlen, dtype=torch.float32, device=device)
        angles = positions[:, None] * inv_timescales[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    def forward(self, seqlen: int):
        """Return the first ``seqlen`` rows, freshly computed in float32."""
        return self._recompute(seqlen, self.positional_embedding.device)
class BidirLMOmniAudioEncoder(PreTrainedModel):
    """Audio encoder: conv down-sampling, windowed transformer, projection.

    ``forward`` splits the input features into fixed-size chunks, passes them
    through three stride-2 ``Conv2d`` layers, projects to ``d_model``, adds
    sinusoidal positions, runs the encoder layers on the packed sequence and
    finally applies ``ln_post`` -> ``proj1`` -> activation -> ``proj2``.
    """

    config: BidirLMOmniAudioConfig
    main_input_name = "input_features"
    _no_split_modules = ["BidirLMOmniAudioEncoderLayer"]
    _supports_sdpa = True

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__(config)
        self.dropout = config.dropout

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.n_window = config.n_window
        self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
        self.layers = nn.ModuleList([BidirLMOmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.ln_post = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        # Three stride-2 convolutions (kernel 3, padding 1).
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        # In-features = conv channels * mel bins left after three (n + 1) // 2 halvings.
        self.conv_out = nn.Linear(
            config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
            config.d_model,
            bias=False,
        )
        self.proj1 = nn.Linear(config.d_model, config.d_model)
        self.act = ACT2FN[config.activation_function]
        self.proj2 = nn.Linear(config.d_model, config.output_dim)
        self.n_window_infer = self.config.n_window_infer
        self.conv_chunksize = self.config.conv_chunksize
        self.post_init()

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> Optional[torch.Tensor]:
        """Build an additive block-diagonal mask from ``cu_seqlens``.

        Entries inside each ``[cu_seqlens[i - 1], cu_seqlens[i])`` square are 0;
        every other entry is the minimum of the input dtype. Returns ``None``
        when the attention implementation is ``flash_attention_2``.
        """
        # NOTE(review): nothing in this class calls this helper; `forward`
        # hands the layers `cu_seqlens` only, with no attention mask. Confirm
        # whether non-flash backends are meant to attend across windows.
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, input_features, feature_lens=None, aftercnn_lens=None):
        """Encode packed audio features.

        Args:
            input_features: Features of all samples concatenated along the
                frame axis (assumes shape ``(num_mel_bins, total_frames)`` —
                TODO confirm against the processor).
            feature_lens: Integer tensor with the frame count of each sample.
                Despite the ``None`` default it is required: it is used
                unconditionally below.
            aftercnn_lens: Ignored; the value is recomputed from
                ``feature_lens`` on the first line.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` holds one row per
            valid post-convolution frame, projected by ``proj2``.
        """
        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
        # Number of (2 * n_window)-frame chunks needed per sample.
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

        # Start with every chunk at full size ...
        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        # ... then shrink the last chunk of each sample to the remainder
        # (cumsum(chunk_num) - 1 is the index of that last chunk).
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        # A remainder of 0 means the last chunk is actually full.
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        # Split along frames, pad the chunks to a common length, and put the
        # frame axis last.
        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
        # Boolean mask marking the valid (non-padding) post-conv frames of each chunk.
        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
            [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
            batch_first=True,
        )
        # Add the single input channel expected by Conv2d.
        padded_feature = padded_feature.unsqueeze(1)
        padded_embeds = []
        # Run the convolutions on at most `conv_chunksize` chunks at a time.
        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
            padded_embed = F.gelu(self.conv2d1(chunk))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
            padded_embeds.append(padded_embed)
        padded_embed = torch.cat(padded_embeds, dim=0)
        b, c, f, t = padded_embed.size()
        # (b, c, f, t) -> (b, t, c * f), then project to d_model.
        padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))

        # The position module recomputes its sinusoids in float32 on each call
        # (avoids precision loss from a bfloat16-cast buffer); cast afterwards.
        positional_embedding = self.positional_embedding(padded_embed.shape[1]).unsqueeze(0).to(padded_embed.dtype)
        padded_embed = padded_embed + positional_embedding
        # Drop padding frames; the result is a packed 2-D sequence.
        hidden_states = padded_embed[padded_mask_after_cnn]

        # Build attention-window boundaries: each sample is cut into windows of
        # `window_aftercnn` post-conv frames plus one shorter remainder window.
        cu_chunk_lens = [0]
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
        for cnn_len in aftercnn_lens:
            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
            remainder = cnn_len % window_aftercnn
            if remainder != 0:
                cu_chunk_lens += [remainder]
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)

        # Layers receive the window boundaries only (no attention mask).
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, cu_seqlens)
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.proj1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.proj2(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)
class BidirLMOmniVisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder (with biases)."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Return ``linear_fc2(act_fn(linear_fc1(hidden_state)))``."""
        expanded = self.linear_fc1(hidden_state)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)


class BidirLMOmniVisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a non-overlapping Conv3d.

    Kernel and stride are both ``[temporal_patch_size, patch_size,
    patch_size]``, so every patch yields exactly one ``hidden_size`` vector.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        patch_shape = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map flattened patches to ``(num_patches, embed_dim)`` embeddings.

        The input is viewed as ``(-1, in_channels, temporal_patch_size,
        patch_size, patch_size)`` and cast to the projection weight's dtype.
        """
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        patches = patches.to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class BidirLMOmniVisionRotaryEmbedding(nn.Module):
    """Rotary frequency table for the vision encoder, computed in float32.

    ``inv_freq`` is registered as a non-persistent buffer, but ``forward``
    reads only its device: the inverse frequencies are rebuilt in float32 on
    every call from the stored ``dim`` and ``theta``.
    """

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        # device=None places the buffer on the current default device.
        self.register_buffer("inv_freq", self._inverse_frequencies(device=None), persistent=False)

    def _inverse_frequencies(self, device) -> torch.Tensor:
        """Return ``1 / theta ** (2i / dim)`` for ``i`` in ``range(dim // 2)`` as float32."""
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        return 1.0 / (self.theta**exponents)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, dim // 2)`` outer product of positions and frequencies."""
        device = self.inv_freq.device
        positions = torch.arange(seqlen, dtype=torch.float32, device=device)
        return torch.outer(positions, self._inverse_frequencies(device))
class BidirLMOmniVisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size ** 2`` patch vectors and project them.

    ``hidden_size`` here is the merged width, ``config.hidden_size *
    spatial_merge_size ** 2``. With ``use_postshuffle_norm`` the LayerNorm
    runs on the merged vectors; otherwise it runs on the individual patch
    vectors before they are merged.
    """

    def __init__(self, config: BidirLMOmniVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        norm_width = self.hidden_size if use_postshuffle_norm else config.hidden_size
        self.norm = nn.LayerNorm(norm_width, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(num_merged, out_hidden_size)`` features."""
        if self.use_postshuffle_norm:
            # Merge first, then normalise the merged vectors.
            merged = self.norm(x.view(-1, self.hidden_size))
        else:
            # Normalise each patch vector; merging happens in the view below.
            merged = self.norm(x)
        merged = merged.view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to ``q`` and ``k`` in float32.

    ``cos`` and ``sin`` get a singleton axis inserted before the last
    dimension so they broadcast against ``q`` / ``k``. The results are cast
    back to the original dtypes of ``q`` and ``k`` respectively.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    q32, k32 = q.float(), k.float()
    cos32 = cos.unsqueeze(-2).float()
    sin32 = sin.unsqueeze(-2).float()
    q_rotated = q32 * cos32 + rotate_half(q32) * sin32
    k_rotated = k32 * cos32 + rotate_half(k32) * sin32
    return q_rotated.to(q_dtype), k_rotated.to(k_dtype)
class BidirLMOmniVisionAttention(nn.Module):
    """Non-causal multi-head self-attention over packed vision tokens.

    Tokens of several segments are concatenated in one ``(seq, dim)`` tensor;
    ``cu_seqlens`` holds the cumulative segment boundaries. With
    ``flash_attention_2`` the boundaries are handed to the backend; with any
    other implementation the sequence is split and each segment is attended
    separately.
    """

    def __init__(self, config: BidirLMOmniVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return attended features with the same ``(seq, dim)`` shape.

        ``position_embeddings`` is the ``(cos, sin)`` pair for the rotary
        embedding; it is unpacked unconditionally, so it must be provided.
        """
        num_tokens = hidden_states.shape[0]

        # Fused projection -> three (seq, heads, head_dim) tensors.
        fused = self.qkv(hidden_states).reshape(num_tokens, 3, self.num_heads, -1)
        query_states, key_states, value_states = fused.permute(1, 0, 2, 3).unbind(0)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim).
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        implementation = self.config._attn_implementation
        if implementation == "eager":
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        # Arguments common to both code paths.
        shared_kwargs = {
            "attention_mask": None,
            "scaling": self.scaling,
            "dropout": self.attention_dropout if self.training else 0.0,
            "is_causal": self.is_causal,
        }

        if implementation == "flash_attention_2":
            # The backend receives the segment boundaries directly.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                **shared_kwargs,
                **kwargs,
            )
        else:
            # Attend each segment on its own, then stitch the outputs back.
            segment_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            query_parts = torch.split(query_states, segment_lengths, dim=2)
            key_parts = torch.split(key_states, segment_lengths, dim=2)
            value_parts = torch.split(value_states, segment_lengths, dim=2)
            segment_outputs = []
            for q, k, v in zip(query_parts, key_parts, value_parts):
                segment_out = attention_interface(self, q, k, v, **shared_kwargs, **kwargs)[0]
                segment_outputs.append(segment_out)
            attn_output = torch.cat(segment_outputs, dim=1)

        attn_output = attn_output.reshape(num_tokens, -1).contiguous()
        return self.proj(attn_output)
class BidirLMOmniVisionBlock(GradientCheckpointingLayer):
    """Pre-norm vision transformer block: attention then MLP, both residual.

    ``attn_implementation`` is accepted by the constructor but not used by
    it; the attention module reads the implementation from ``config``.
    """

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = BidirLMOmniVisionAttention(config=config)
        self.mlp = BidirLMOmniVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the block and return the updated hidden states."""
        attended = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attended

        feed_forward = self.mlp(self.norm2(hidden_states))
        return hidden_states + feed_forward
def fast_pos_embed_interpolate(self, grid_thw):
    """Bilinearly resample the learned position table for each (t, h, w) grid.

    For every row of ``grid_thw`` the square ``num_grid_per_side`` x
    ``num_grid_per_side`` table in ``self.pos_embed`` is sampled on an
    ``h`` x ``w`` lattice, tiled ``t`` times, and reordered so that the
    ``spatial_merge_size`` x ``spatial_merge_size`` patches of one merge
    window are contiguous. The per-grid results are concatenated along dim 0.
    """
    side = self.num_grid_per_side
    last = side - 1
    frames, heights, widths = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

    # One flat bucket per bilinear corner, in the order
    # (floor, floor), (floor, ceil), (ceil, floor), (ceil, ceil).
    corner_indices = [[] for _ in range(4)]
    corner_weights = [[] for _ in range(4)]

    for height, width in zip(heights, widths):
        rows = torch.linspace(0, last, height)
        cols = torch.linspace(0, last, width)
        row_lo, col_lo = rows.int(), cols.int()
        row_hi = (row_lo + 1).clip(max=last)
        col_hi = (col_lo + 1).clip(max=last)

        # Fractional offsets, shaped to broadcast into an (h, w) lattice.
        row_frac = (rows - row_lo)[:, None]
        col_frac = (cols - col_lo)[None, :]

        # Row offsets into the flattened (side * side) table.
        row_lo_off = (row_lo * side)[:, None]
        row_hi_off = (row_hi * side)[:, None]

        lattice_indices = (
            row_lo_off + col_lo[None, :],
            row_lo_off + col_hi[None, :],
            row_hi_off + col_lo[None, :],
            row_hi_off + col_hi[None, :],
        )
        lattice_weights = (
            (1 - row_frac) * (1 - col_frac),
            (1 - row_frac) * col_frac,
            row_frac * (1 - col_frac),
            row_frac * col_frac,
        )
        for bucket, values in zip(corner_indices, lattice_indices):
            bucket.extend(values.flatten().tolist())
        for bucket, values in zip(corner_weights, lattice_weights):
            bucket.extend(values.flatten().tolist())

    table = self.pos_embed
    target_device = table.weight.device
    idx_tensor = torch.tensor(corner_indices, dtype=torch.long, device=target_device)
    weight_tensor = torch.tensor(corner_weights, dtype=table.weight.dtype, device=target_device)

    weighted = table(idx_tensor) * weight_tensor[:, :, None]
    blended = weighted[0] + weighted[1] + weighted[2] + weighted[3]
    per_grid = blended.split([h * w for h, w in zip(heights, widths)])

    merge = self.config.spatial_merge_size
    reordered = []
    for embeds, t, h, w in zip(per_grid, frames, heights, widths):
        tiled = embeds.repeat(t, 1)
        windows = tiled.view(t, h // merge, merge, w // merge, merge, -1)
        # (t, h_block, intra_h, w_block, intra_w, C) -> (t, h_block, w_block, intra_h, intra_w, C)
        reordered.append(windows.permute(0, 1, 3, 2, 4, 5).flatten(0, 4))
    return torch.cat(reordered)
class BidirLMOmniTextRotaryEmbedding(nn.Module):
    """Interleaved multi-axis rotary embedding (MRoPE) for the text encoder.

    Produces ``(cos, sin)`` tables from position ids that carry three axes
    (stacked on dim 0). 2-D position ids are broadcast to all three axes.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: "BidirLMOmniTextConfig", device=None):
        """Build the inverse-frequency buffer from ``config``.

        Args:
            config: text config; read for ``max_position_embeddings``,
                ``head_dim`` / ``hidden_size`` / ``num_attention_heads``,
                ``rope_theta`` and the optional ``rope_scaling`` dict.
            device: device on which ``inv_freq`` is created.
        """
        super().__init__()
        # `rope_scaling` may be missing or None; resolve it once so that
        # rope_type and mrope_section are read through the same guard.
        rope_scaling = getattr(config, "rope_scaling", None) or {}
        self.rope_type = rope_scaling.get("rope_type", "default")
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # transformers 5.x removed "default" from ROPE_INIT_FUNCTIONS.
        # For rope_type="default" (standard inv_freq, no scaling) we compute directly.
        if self.rope_type == "default" or self.rope_type not in ROPE_INIT_FUNCTIONS:
            inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device=device)
        else:
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq
        self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])

    def compute_default_rope_parameters(self, config=None, device=None):
        """Required by transformers 5.x _init_weights when rope_type='default'.

        Args:
            config: config to read; falls back to ``self.config`` when None.
            device: optional device for the returned tensor.

        Returns:
            ``(inv_freq, attention_scaling)`` where the scaling is always 1.0.
        """
        config = self.config if config is None else config
        # `head_dim` may be absent or explicitly None; derive it in both cases.
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        rope_theta = getattr(config, "rope_theta", 10000.0)
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim
        inv_freq = 1.0 / (rope_theta**exponents)
        return inv_freq, 1.0

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Interleave axis-1 and axis-2 frequencies into the axis-0 table.

        ``freqs`` is indexed by axis on dim 0. Every third channel starting at
        offset 1 (resp. 2), up to ``mrope_section[1] * 3`` (resp.
        ``mrope_section[2] * 3``), is overwritten with axis 1 (resp. 2).
        ``freqs[0]`` is a view, so ``freqs`` itself is modified in place.
        """
        freqs_t = freqs[0]
        for dim, offset in enumerate((1, 2), start=1):
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` cast to ``x.dtype`` for the given positions.

        Args:
            x: tensor used only for its device type and dtype.
            position_ids: either 2-D (broadcast to three axes) or already
                stacked with the three axes on dim 0.
        """
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast is disabled so the angle computation stays in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class BidirLMOmniTextMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``, all bias-free."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
# Common base class for the BidirLM-Omni models defined in this file. It adds
# no methods of its own; it only sets class-level attributes on top of the
# transformers `PreTrainedModel` base.
@auto_docstring
class BidirLMOmniPreTrainedModel(PreTrainedModel):
    # Config class associated with this model family.
    config: BidirLMOmniConfig
    # Attribute name under which head models store the backbone
    # (e.g. `self.model = BidirLMOmniModel(config)`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layer classes listed here are presumably kept on a single device when the
    # model is sharded — semantics come from transformers; TODO confirm.
    _no_split_modules = [
        "BidirLMOmniTextEncoderLayer",
        "BidirLMOmniAudioEncoderLayer",
        "BidirLMOmniVisionBlock",
    ]
    # Capability flags for the attention backends; how they are consumed is
    # defined by the transformers base class, not by this file.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
def __init__(self, config: BidirLMOmniTextConfig):
    """Create the token embedding, encoder stack, final norm and rotary table."""
    super().__init__(config)
    pad_id = getattr(config, "pad_token_id", None)
    width = config.hidden_size
    self.padding_idx = pad_id
    self.vocab_size = config.vocab_size

    # Submodules are registered in the same order as before so parameter
    # ordering and initialisation order are unchanged.
    self.embed_tokens = nn.Embedding(config.vocab_size, width, pad_id)
    encoder_stack = []
    for depth in range(config.num_hidden_layers):
        encoder_stack.append(BidirLMOmniTextEncoderLayer(config, depth))
    self.layers = nn.ModuleList(encoder_stack)
    self.norm = BidirLMOmniRMSNorm(width, eps=config.rms_norm_eps)
    self.rotary_emb = BidirLMOmniTextRotaryEmbedding(config)
    self.gradient_checkpointing = False
    self.post_init()
None: + if self.config._attn_implementation == "flash_attention_2": + # Flash attention computes cu_seqlens from a 2D mask internally; + # passing a 4D mask breaks the varlen path. + extended_attention_mask = attention_mask + else: + # Convert 1/0 mask to additive float mask (0.0 = attend, -inf = ignore). + # The old boolean expand (True→+1, False→+0) added to attn_weights was NOT + # a real mask: padding tokens still participated in softmax, corrupting + # embeddings of shorter sequences when batched with longer ones. + float_mask = attention_mask.to(dtype=inputs_embeds.dtype) + extended_attention_mask = (1.0 - float_mask)[:, None, None, :] * torch.finfo(inputs_embeds.dtype).min + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + # DeepStack: add visual features at intermediate layers + if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutput(last_hidden_state=hidden_states) + + def _deepstack_process( + self, + hidden_states: torch.Tensor, + visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor, + ) -> torch.Tensor: + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states[visual_pos_masks, :] = hidden_states[visual_pos_masks, :].clone() + visual_embeds + return hidden_states + + +# ═══════════════════════════════════════════════════════════════════════════ +# Top-level Omni model +# ═══════════════════════════════════════════════════════════════════════════ + + +@auto_docstring( + 
def get_audio_features(
    self,
    input_features: torch.FloatTensor,
    feature_attention_mask: Optional[torch.LongTensor] = None,
    audio_feature_lengths: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
    """Encode each audio clip with the audio tower and concatenate the results.

    Args:
        input_features: batch of audio features; each item is sliced on its
            last dimension to the clip's valid length before encoding.
        feature_attention_mask: optional mask whose row sums give the valid
            length of each clip. Takes precedence over
            ``audio_feature_lengths`` when both are given.
        audio_feature_lengths: optional per-clip valid lengths, used only
            when no mask is supplied.

    Returns:
        The ``last_hidden_state`` of every clip, concatenated along dim 0.

    Raises:
        ValueError: if neither ``feature_attention_mask`` nor
            ``audio_feature_lengths`` is provided.
    """
    if feature_attention_mask is not None:
        feature_lens = feature_attention_mask.sum(dim=1)
    elif audio_feature_lengths is not None:
        feature_lens = audio_feature_lengths
    else:
        # Previously this fell through to `None.sum(-1)` and raised an
        # opaque AttributeError.
        raise ValueError(
            "Either feature_attention_mask or audio_feature_lengths must be provided with input_features."
        )

    # Clips are encoded one at a time because each has its own valid length.
    audio_features = [
        self.audio_tower(
            input_feature[:, :feature_len],
            feature_lens=feature_len.unsqueeze(0),
        ).last_hidden_state
        for input_feature, feature_len in zip(input_features, feature_lens)
    ]
    return torch.cat(audio_features, dim=0)
def get_audio_placeholder_mask(
    self,
    input_ids: Optional[torch.LongTensor],
    inputs_embeds: torch.FloatTensor,
) -> torch.Tensor:
    """Return a boolean mask, shaped like ``inputs_embeds``, marking audio slots.

    With ``input_ids`` the slots are the positions equal to
    ``config.audio_token_id``. Without them, a position counts as an audio
    slot when its embedding equals the audio token's embedding in every channel.
    """
    audio_token_id = self.config.audio_token_id
    if input_ids is not None:
        is_audio = input_ids == audio_token_id
    else:
        token = torch.tensor(audio_token_id, dtype=torch.long, device=inputs_embeds.device)
        audio_embedding = self.get_input_embeddings()(token)
        is_audio = (inputs_embeds == audio_embedding).all(-1)

    # Broadcast the per-position flag across the embedding dimension.
    expanded = is_audio.unsqueeze(-1).expand_as(inputs_embeds)
    return expanded.to(inputs_embeds.device)
def get_rope_index(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Build 3-D MRoPE position ids. Returns (3, batch, seq_len).

    When ``input_ids`` and at least one vision grid are given, each sample is
    walked left to right: text runs get the same increasing index on all three
    axes, and every image / video-frame block gets (t, h, w) grid indices
    offset by the position reached so far. Otherwise a plain running index
    (from ``attention_mask`` when available) is broadcast to the three axes.

    Args:
        input_ids: token ids, indexed as (batch, seq_len).
        image_grid_thw: one (t, h, w) row per image, consumed in order.
        video_grid_thw: one (t, h, w) row per video; expanded below to one
            row per frame.
        attention_mask: entries equal to 1 mark tokens that receive positions;
            treated as all-ones in the vision path when omitted.
    """
    # NOTE(review): in the final fallback branch `input_ids.shape` is read
    # unconditionally, so calling with both input_ids and attention_mask set to
    # None fails — confirm callers always supply one of them.
    if video_grid_thw is not None:
        # Repeat each video row once per frame and set its temporal extent to 1,
        # so every frame is handled as its own single-frame grid.
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1

    spatial_merge_size = self.config.vision_config.spatial_merge_size
    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id
    vision_start_token_id = self.config.vision_start_token_id

    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        # Initialised to ones; positions where attention_mask != 1 keep this value.
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # Cursors into image_grid_thw / video_grid_thw; they advance across the
        # whole batch, not per sample.
        image_index, video_index = 0, 0
        for i, ids in enumerate(total_input_ids):
            # Only attended tokens take part in the position computation.
            ids = ids[attention_mask[i] == 1]
            # The token right after each vision-start marker tells whether the
            # block is an image or a video.
            vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1)
            vision_tokens = ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = ids.tolist()
            llm_pos_ids_list: list = []
            st = 0  # index of the first token not yet assigned a position
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                # Locate the next image / video placeholder at or after `st`;
                # len + 1 is a sentinel meaning "none left of this kind".
                ed_image = (
                    input_tokens.index(image_token_id, st)
                    if image_token_id in input_tokens and remain_images > 0
                    else len(input_tokens) + 1
                )
                ed_video = (
                    input_tokens.index(video_token_id, st)
                    if video_token_id in input_tokens and remain_videos > 0
                    else len(input_tokens) + 1
                )
                # Whichever placeholder comes first is processed next.
                if ed_image < ed_video:
                    t, h, w = image_grid_thw[image_index]
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = video_grid_thw[video_index]
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                # Grid size after spatial merging of patches.
                llm_grid_t = t.item()
                llm_grid_h = h.item() // spatial_merge_size
                llm_grid_w = w.item() // spatial_merge_size
                # Text run preceding this vision block: same index on all 3 axes.
                text_len = ed - st
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                # Vision block: per-token (t, h, w) grid coordinates, shifted
                # past the text run.
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w
            # Trailing text after the last vision block (or the whole sample
            # when it contains no vision block).
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(torch.arange(len(input_tokens) - st).view(1, -1).expand(3, -1) + st_idx)
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            # Scatter back into the attended slots of this sample.
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
        return position_ids

    # Text-only / audio-only path (no spatial position structure)
    if attention_mask is not None:
        # Running count of attended tokens, zero-based; masked slots are set to 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = (
            torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1)
        )
    # Same positions on all three axes.
    return position_ids.unsqueeze(0).expand(3, -1, -1)
**kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("Specify exactly one of input_ids or inputs_embeds.") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # ── Audio injection ──────────────────────────────────────────── + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ).to(inputs_embeds.device, inputs_embeds.dtype) + audio_mask = self.get_audio_placeholder_mask(input_ids, inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + + # ── Vision injection ─────────────────────────────────────────── + image_mask = video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_vision_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_vision_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # ── Assemble DeepStack masks / embeds ────────────────────────── + visual_pos_masks = deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + im = image_mask[..., 0] + vm = video_mask[..., 0] + visual_pos_masks = im | vm + image_mask_joint = im[visual_pos_masks] + video_mask_joint = vm[visual_pos_masks] 
# Omni encoder with a linear vocabulary head for masked-language modelling.
@auto_docstring
class BidirLMOmniForMaskedLM(BidirLMOmniPreTrainedModel):
    # The LM head weight is declared as tied to the text encoder's token embedding.
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.model = BidirLMOmniModel(config)
        text_config = config.text_config
        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> MaskedLMOutput:
        # Every modality input is forwarded unchanged to the backbone.
        backbone_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            **kwargs,
        )
        token_states = backbone_out.last_hidden_state
        logits = self.lm_head(token_states)

        # Loss is only computed when labels are supplied.
        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits, labels, vocab_size=self.config.text_config.vocab_size)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + input_features: Optional[torch.FloatTensor] = None, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + **kwargs, + ) -> SequenceClassifierOutput: + encoder_output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + pixel_values_videos=pixel_values_videos, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + **kwargs, + ) + last_hidden_state = encoder_output.last_hidden_state + + if self.clf_pooling == "bos": + pooled = last_hidden_state[:, 0] + elif self.clf_pooling == "mean": + if attention_mask is None: + pooled = last_hidden_state.mean(dim=1) + else: + pooled = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) + pooled = pooled / attention_mask.sum(dim=1, keepdim=True) + else: # "late" — project each token then mean-pool + pooled = last_hidden_state + + pooled = self.dense(pooled) + pooled = self.activation(pooled) + logits = self.classifier(pooled) + + if self.clf_pooling == "late": + if attention_mask is None: + logits = logits.mean(dim=1) + else: + logits = (logits * attention_mask.unsqueeze(-1)).sum(dim=1) + logits = logits / attention_mask.sum(dim=1, keepdim=True) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" 
+ elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_output.hidden_states, + attentions=encoder_output.attentions, + ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Token classification head +# ═══════════════════════════════════════════════════════════════════════════ + + +@auto_docstring +class BidirLMOmniForTokenClassification(BidirLMOmniPreTrainedModel): + def __init__(self, config: BidirLMOmniConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = BidirLMOmniModel(config) + self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + input_features: Optional[torch.FloatTensor] = None, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + **kwargs, + ) -> 
TokenClassifierOutput: + encoder_output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + pixel_values_videos=pixel_values_videos, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + **kwargs, + ) + logits = self.classifier(encoder_output.last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_output.hidden_states, + attentions=encoder_output.attentions, + ) + + +__all__ = [ + "BidirLMOmniPreTrainedModel", + "BidirLMOmniAudioEncoder", + "BidirLMOmniVisionModel", + "BidirLMOmniTextModel", + "BidirLMOmniModel", + "BidirLMOmniForMaskedLM", + "BidirLMOmniForSequenceClassification", + "BidirLMOmniForTokenClassification", +] diff --git a/colpali_engine/models/qwen_omni/colqwen3_omni/modeling_biqwen3_omni.py b/colpali_engine/models/qwen_omni/colqwen3_omni/modeling_biqwen3_omni.py new file mode 100644 index 000000000..b8175270e --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen3_omni/modeling_biqwen3_omni.py @@ -0,0 +1,101 @@ +from typing import ClassVar, Literal + +import torch + +from .configuration_bidirlm_omni import BidirLMOmniConfig +from .modeling_bidirlm_omni import BidirLMOmniModel + + +class BiQwen3Omni(BidirLMOmniModel): + """ + BiQwen3-Omni model wrapper for BidirLM-Omni checkpoints. + + The backbone is the BidirLM-Omni bidirectional encoder: Qwen3-style text and vision towers with an audio tower. + Representations are pooled to obtain a single vector representation. 
+ """ + + config_class = BidirLMOmniConfig + main_input_name: ClassVar[str] = "doc_input_ids" + _checkpoint_conversion_mapping = { + r"^model\.audio_tower": "audio_tower", + r"^model\.visual": "visual", + r"^model\.language_model": "language_model", + r"^model\.": "", + } + + def __init__(self, config: BidirLMOmniConfig, **kwargs): + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + attn_impl = kwargs.pop("attn_implementation", None) + use_cache = kwargs.pop("use_cache", None) + + super().__init__(config=config) + self.padding_side = "left" + self.post_init() + + if dtype is not None: + self.to(dtype=dtype) + if use_cache is not None: + self.config.use_cache = use_cache + if attn_impl is not None and hasattr(self, "set_attn_implementation"): + self.set_attn_implementation(attn_impl) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None) + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = "mean", + *args, + **kwargs, + ) -> torch.Tensor: + if "pixel_values" in kwargs and kwargs["pixel_values"].ndim == 3: + offsets = kwargs["image_grid_thw"].prod(dim=1) + kwargs["pixel_values"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], + dim=0, + ) + + if "pixel_values_videos" in kwargs and kwargs["pixel_values_videos"].ndim == 3: + offsets = kwargs["video_grid_thw"].prod(dim=1) + kwargs["pixel_values_videos"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values_videos"], offsets)], + dim=0, + ) + + model_dtype = next(self.parameters()).dtype + for key in ("pixel_values", "pixel_values_videos", "input_features"): + if key in kwargs and kwargs[key].is_floating_point() and kwargs[key].dtype != model_dtype: + kwargs[key] 
= kwargs[key].to(dtype=model_dtype) + + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) + + if pooling_strategy == "cls": + pooled_output = last_hidden_states[:, 0] + elif pooling_strategy == "last": + pooled_output = last_hidden_states[:, -1] + elif pooling_strategy == "mean": + mask = kwargs["attention_mask"].unsqueeze(-1) + pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + return pooled_output / pooled_output.norm(dim=-1, keepdim=True) + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size diff --git a/colpali_engine/models/qwen_omni/colqwen3_omni/modeling_colqwen3_omni.py b/colpali_engine/models/qwen_omni/colqwen3_omni/modeling_colqwen3_omni.py new file mode 100644 index 000000000..2419eaa67 --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen3_omni/modeling_colqwen3_omni.py @@ -0,0 +1,112 @@ +from typing import ClassVar + +import torch +from torch import nn + +from .configuration_bidirlm_omni import BidirLMOmniConfig +from .modeling_bidirlm_omni import BidirLMOmniModel + + +class ColQwen3Omni(BidirLMOmniModel): + """ + ColQwen3-Omni model wrapper for BidirLM-Omni checkpoints. + + The backbone is the BidirLM-Omni bidirectional encoder: Qwen3-style text and vision towers with an audio tower. + This class adds the Col-style projection head used for multi-vector retrieval. 
+ """ + + config_class = BidirLMOmniConfig + main_input_name: ClassVar[str] = "doc_input_ids" + _checkpoint_conversion_mapping = { + r"^base_model\.model\.custom_text_proj": "custom_text_proj", + r"^model\.audio_tower": "audio_tower", + r"^model\.visual": "visual", + r"^model\.language_model": "language_model", + r"^model\.": "", + } + + def __init__( + self, + config: BidirLMOmniConfig, + mask_non_image_embeddings: bool = False, + **kwargs, + ): + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + attn_impl = kwargs.pop("attn_implementation", None) + use_cache = kwargs.pop("use_cache", None) + + super().__init__(config=config) + + hidden_size = getattr(self.config, "hidden_size", None) + if hidden_size is None and hasattr(self.config, "text_config"): + hidden_size = getattr(self.config.text_config, "hidden_size", None) + if hidden_size is None: + raise ValueError(f"Unable to determine text hidden size for {type(self.config).__name__}.") + + self.dim = 128 + self.custom_text_proj = nn.Linear(hidden_size, self.dim) + self.padding_side = "left" + self.mask_non_image_embeddings = mask_non_image_embeddings + self.post_init() + + if dtype is not None: + self.to(dtype=dtype) + if use_cache is not None: + self.config.use_cache = use_cache + if attn_impl is not None and hasattr(self, "set_attn_implementation"): + self.set_attn_implementation(attn_impl) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = dict(getattr(super(), "_checkpoint_conversion_mapping", {})) + key_mapping.update(cls._checkpoint_conversion_mapping) + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward(self, *args, **kwargs) -> torch.Tensor: + if "pixel_values" in kwargs and kwargs["pixel_values"].ndim == 3: + offsets = kwargs["image_grid_thw"].prod(dim=1) + kwargs["pixel_values"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in 
zip(kwargs["pixel_values"], offsets)], + dim=0, + ) + + if "pixel_values_videos" in kwargs and kwargs["pixel_values_videos"].ndim == 3: + offsets = kwargs["video_grid_thw"].prod(dim=1) + kwargs["pixel_values_videos"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values_videos"], offsets)], + dim=0, + ) + + model_dtype = next(self.parameters()).dtype + for key in ("pixel_values", "pixel_values_videos", "input_features"): + if key in kwargs and kwargs[key].is_floating_point() and kwargs[key].dtype != model_dtype: + kwargs[key] = kwargs[key].to(dtype=model_dtype) + + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) + + proj = self.custom_text_proj(last_hidden_states) + proj = proj / proj.norm(dim=-1, keepdim=True) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + + return proj + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size diff --git a/colpali_engine/models/qwen_omni/colqwen3_omni/processing_bidirlm_omni.py b/colpali_engine/models/qwen_omni/colqwen3_omni/processing_bidirlm_omni.py new file mode 100644 index 000000000..6a4938328 --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen3_omni/processing_bidirlm_omni.py @@ -0,0 +1,405 @@ +from typing import Optional, Union + +import numpy as np +from transformers.audio_utils import AudioInput + +try: + from torchcodec.decoders import AudioDecoder +except ImportError: + AudioDecoder = None + +from transformers.feature_extraction_utils import BatchFeature +from 
transformers.image_utils import ImageInput +from transformers.processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, + VideosKwargs, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.utils import logging +from transformers.video_utils import VideoInput + +logger = logging.get_logger(__name__) + + +def _resample_audio(arr: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: + try: + import librosa + except ImportError as exc: + raise ImportError( + "`librosa` is required to resample audio inputs. Install it or pass audio with the processor's native " + f"sampling rate ({target_sr})." + ) from exc + return librosa.resample(arr, orig_sr=orig_sr, target_sr=target_sr) + + +# ── Kwargs classes ───────────────────────────────────────────────────────── + + +class BidirLMOmniVideosKwargs(VideosKwargs, total=False): + pass + + +class BidirLMOmniImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class BidirLMOmniProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: BidirLMOmniImagesKwargs + videos_kwargs: BidirLMOmniVideosKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "padding_side": "right", + "return_token_type_ids": False, + "return_mm_token_type_ids": False, + }, + "audio_kwargs": { + "sampling_rate": 16000, + "padding": True, + "return_attention_mask": True, + }, + "videos_kwargs": {"return_metadata": True}, + } + + +# ── Audio helpers ────────────────────────────────────────────────────────── + + +def _get_feat_extract_output_lengths(input_lengths): + """Computes the output length of the audio encoder's convolutional layers. + Three Conv2d layers each with kernel=3, stride=2, padding=1. 
+ Per-layer formula: floor((L - 1) / 2) + 1 + """ + length = (input_lengths - 1) // 2 + 1 + length = (length - 1) // 2 + 1 + length = (length - 1) // 2 + 1 + return length + + +# ── Processor ────────────────────────────────────────────────────────────── + + +class BidirLMOmniProcessor(ProcessorMixin): + attributes = ["image_processor", "video_processor", "feature_extractor", "tokenizer"] + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + image_processor=None, + video_processor=None, + feature_extractor=None, + tokenizer=None, + chat_template=None, + max_image_size: Optional[int] = None, + ): + super().__init__( + image_processor, + video_processor, + feature_extractor, + tokenizer, + chat_template=chat_template, + ) + + if max_image_size is not None and image_processor is not None: + max_pixels = max_image_size * max_image_size + image_processor.size["longest_edge"] = max_pixels + if image_processor.size["shortest_edge"] > max_pixels: + image_processor.size["shortest_edge"] = max_pixels + + # ── Vision tokens (from Qwen3VLProcessor) ───────────────────── + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) is not None + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) is not None + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + self.vision_start_token = ( + "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token + ) + self.vision_end_token = ( + "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token + ) + + # ── Audio tokens (from Qwen3ASRProcessor) 
───────────────────── + self.audio_token = "<|audio_pad|>" if not hasattr(tokenizer, "audio_token") else tokenizer.audio_token + self.audio_bos_token = ( + "<|audio_start|>" if not hasattr(tokenizer, "audio_bos_token") else tokenizer.audio_bos_token + ) + self.audio_eos_token = ( + "<|audio_end|>" if not hasattr(tokenizer, "audio_eos_token") else tokenizer.audio_eos_token + ) + + self.sampling_rate = self.feature_extractor.sampling_rate + + # ── __call__ ─────────────────────────────────────────────────────── + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: VideoInput = None, + audio: AudioInput = None, + **kwargs: Unpack[BidirLMOmniProcessorKwargs], + ) -> BatchFeature: + """ + Prepare inputs for the model. Processes text with the tokenizer, images with + the image processor, videos with the video processor, and audio with the + WhisperFeatureExtractor. + + Args: + images: PIL images, numpy arrays, or tensors. + text: Text sequences to encode. + videos: Video arrays (4D) or nested lists of frames. + audio: Audio numpy arrays. 
+ """ + if text is None: + raise ValueError("You need to specify a `text` input to process.") + + output_kwargs = self._merge_kwargs( + BidirLMOmniProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + # ── Image processing (from Qwen3VLProcessor) ────────────────── + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + # ── Video processing (from Qwen3VLProcessor) ────────────────── + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + else: + videos_inputs = {} + video_grid_thw = None + + # ── Audio processing (from Qwen3ASRProcessor) ───────────────── + if audio is not None: + pipeline_sr = output_kwargs["audio_kwargs"].get("sampling_rate", self.sampling_rate) + if not isinstance(audio, (list, tuple)): + audio = [audio] + audio = [self._normalize_audio(a, pipeline_sr) for a in audio] + + output_kwargs["audio_kwargs"]["sampling_rate"] = self.sampling_rate + output_kwargs["audio_kwargs"]["padding"] = True + output_kwargs["audio_kwargs"]["truncation"] = False + audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask") + audio_lengths = iter(_get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1))) + else: + audio_inputs = {} + audio_lengths = iter([]) + + # ── Token expansion ──────────────────────────────────────────── + if not isinstance(text, list): + text = [text] + + text = text.copy() + + # Image placeholder expansion + if image_grid_thw is not None: + merge_length = 
self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + # Video placeholder expansion + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + metadata = video_metadata[index] + if metadata.fps is None: + logger.warning_once( + "BiQwen3VL requires frame timestamps to construct prompts, but the `fps` of the input " + "video could not be inferred. Defaulting to `fps=24`." + ) + metadata.fps = 24 + curr_timestamp = self._calculate_timestamps( + metadata.frames_indices, + metadata.fps, + self.video_processor.merge_size, + ) + video_placeholder = "" + frame_seqlen = video_grid_thw[index][1:].prod() // merge_length + for frame_idx in range(video_grid_thw[index][0]): + curr_time = curr_timestamp[frame_idx] + video_placeholder += f"<{curr_time:.1f} seconds>" + video_placeholder += ( + self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token + ) + if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]: + text[i] = text[i].replace( + f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", + video_placeholder, + 1, + ) + else: + text[i] = text[i].replace(self.video_token, video_placeholder, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + # Audio placeholder expansion + text = self._replace_audio_special_tokens(text, audio_lengths) + + # ── Tokenize ────────────────────────────────────────────────── + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) + 
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature( + data={**text_inputs, **image_inputs, **videos_inputs, **audio_inputs}, + tensor_type=return_tensors, + ) + + # ── Audio token expansion ────────────────────────────────────────── + + def _replace_audio_special_tokens(self, text, audio_lengths): + """Replace audio placeholder tokens with the correct number of pad tokens.""" + processed_text = [] + for sample in text: + while self.audio_token in sample: + sample = sample.replace( + self.audio_token, + "<|audio_placeholder|>" * next(audio_lengths), + 1, + ) + sample = sample.replace("<|audio_placeholder|>", self.audio_token) + processed_text.append(sample) + return processed_text + + # ── Video timestamp calculation (from Qwen3VLProcessor) ──────────── + + def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size)) + timestamps = [idx / video_fps for idx in indices] + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size) + ] + return timestamps + + # ── Audio chunking helper (from Qwen3ASRProcessor) ───────────────── + + def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]: + """Splits token index list into chunks based on token value ranges.""" + + def _iter(): + i, start_idx = 0, 0 + current_chunk = 1 + while i < len(token_indices): + if 
token_indices[i] >= current_chunk * tokens_per_chunk: + yield (start_idx, i) + start_idx = i + current_chunk += 1 + i += 1 + yield (start_idx, len(token_indices)) + + return list(_iter()) + + # ── Post processing ──────────────────────────────────────────────── + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def _normalize_audio(self, a, pipeline_sr=None): + """Normalize a single audio item to a float32 numpy array at self.sampling_rate. + + Accepts: + - list[float] — plain Python list of samples + - np.ndarray — raw samples (resampled via pipeline_sr if it differs from self.sampling_rate) + - dict — HuggingFace datasets Audio dict {"array": ..., "sampling_rate": ...} + - AudioDecoder — datasets 4.x lazy decoder (torchcodec); None if torchcodec is not installed + """ + if isinstance(a, (list, np.ndarray)): + arr = np.asarray(a, dtype=np.float32) + if pipeline_sr and pipeline_sr != self.sampling_rate: + arr = _resample_audio(arr, orig_sr=pipeline_sr, target_sr=self.sampling_rate) + elif isinstance(a, dict): # HuggingFace datasets Audio dict: {"array": ..., "sampling_rate": ...} + arr = np.asarray(a["array"], dtype=np.float32) + src_sr = a.get("sampling_rate") + if src_sr and src_sr != self.sampling_rate: + arr = _resample_audio(arr, orig_sr=src_sr, target_sr=self.sampling_rate) + elif AudioDecoder is not None and isinstance(a, AudioDecoder): + samples = a.get_all_samples() + arr = samples.data.float().mean(dim=0).cpu().numpy() + src_sr = samples.sample_rate + if src_sr and src_sr != self.sampling_rate: + arr = _resample_audio(arr, orig_sr=src_sr, target_sr=self.sampling_rate) + else: + raise TypeError( + f"Unsupported audio type: {type(a).__name__}. 
" + "Expected a plain list, a numpy array, a HuggingFace datasets Audio dict, " + "or a torchcodec AudioDecoder." + ) + return arr + + def apply_chat_template(self, conversations, chat_template=None, **kwargs): + # Normalize audio in user turn content items to numpy arrays before the base + # class processes them. Accepts lists, numpy arrays, HF Audio dicts, or + # AudioDecoder objects. Only user turns can contain audio. + # Accept both single conversation (List[Dict]) and batch (List[List[Dict]]). + # Build a local batch view for safe traversal WITHOUT changing what is passed + # to super(), so super's return type (str for single, List[str] for batch) + # is preserved exactly as the caller expects. + batch = [conversations] if (conversations and isinstance(conversations[0], dict)) else conversations + for conv in batch: + for turn in conv: + if turn.get("role") != "user": + continue + for item in turn.get("content", []): + audio_val = item.get("audio") + if item.get("type") == "audio" and not isinstance(audio_val, np.ndarray): + item["audio"] = self._normalize_audio(audio_val) + return super().apply_chat_template(conversations, chat_template, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names + ["feature_attention_mask"])) + + +__all__ = ["BidirLMOmniProcessor"] diff --git a/colpali_engine/models/qwen_omni/colqwen3_omni/processing_biqwen3_omni.py b/colpali_engine/models/qwen_omni/colqwen3_omni/processing_biqwen3_omni.py new file mode 100644 index 000000000..1532b22ec --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen3_omni/processing_biqwen3_omni.py @@ -0,0 +1,28 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from .processing_colqwen3_omni import ColQwen3OmniProcessor + 
+ +class BiQwen3OmniProcessor(ColQwen3OmniProcessor): + """ + Processor for BiQwen3-Omni / BidirLM-Omni checkpoints. + """ + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/qwen_omni/colqwen3_omni/processing_colqwen3_omni.py b/colpali_engine/models/qwen_omni/colqwen3_omni/processing_colqwen3_omni.py new file mode 100644 index 000000000..8b13615ba --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen3_omni/processing_colqwen3_omni.py @@ -0,0 +1,166 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + +from .processing_bidirlm_omni import BidirLMOmniProcessor + + +class ColQwen3OmniProcessor(BaseVisualRetrieverProcessor, BidirLMOmniProcessor): + """ + Processor for ColQwen3-Omni / BidirLM-Omni checkpoints. 
# Interior of the omni multi-vector processor class. The `class` statement and the
# start of its docstring sit above this chunk (presumably ColQwen3OmniProcessor --
# confirm against the full file).

# Prompt text paired with every image document.
visual_prompt_prefix: ClassVar[str] = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
)
# Prompt text paired with every audio document.
audio_prompt_prefix: ClassVar[str] = (
    "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>Describe the sound.<|im_end|><|endoftext|>"
)
query_augmentation_token: ClassVar[str] = "<|endoftext|>"
# Placeholder token string that appears in `visual_prompt_prefix`.
image_token: ClassVar[str] = "<|image_pad|>"


def __init__(
    self,
    image_processor=None,
    video_processor=None,
    feature_extractor=None,
    tokenizer=None,
    chat_template=None,
    max_image_size: Optional[int] = None,
):
    """Hand every component to the parent constructor, then switch the tokenizer to left padding."""
    components = {
        "image_processor": image_processor,
        "video_processor": video_processor,
        "feature_extractor": feature_extractor,
        "tokenizer": tokenizer,
        "chat_template": chat_template,
        "max_image_size": max_image_size,
    }
    super().__init__(**components)
    self.tokenizer.padding_side = "left"


@classmethod
def from_pretrained(
    cls,
    *args,
    device_map: Optional[str] = None,
    **kwargs,
):
    """
    Load the processor, optionally capping the image pixel budget.

    Args:
        *args: Forwarded unchanged to the parent `from_pretrained`.
        device_map: Forwarded unchanged to the parent `from_pretrained`.
        **kwargs: Forwarded to the parent `from_pretrained`, except for the optional
            `max_num_visual_tokens` entry, which is removed first. When it is given,
            `image_processor.max_pixels` becomes
            `max_num_visual_tokens * (patch_size * merge_size) ** 2` and the same value
            is stored under `image_processor.size["longest_edge"]`.

    Raises:
        ValueError: If a token cap is requested but the image processor exposes no
            `patch_size` or no `merge_size`.
    """
    token_cap = kwargs.pop("max_num_visual_tokens", None)
    processor = super().from_pretrained(
        *args,
        device_map=device_map,
        **kwargs,
    )

    # Nothing to adjust unless the caller asked for a cap.
    if token_cap is None:
        return processor

    image_processor = processor.image_processor
    patch = getattr(image_processor, "patch_size", None)
    merge = getattr(image_processor, "merge_size", None)
    if not (patch is not None and merge is not None):
        raise ValueError("BidirLM-Omni image processor is missing `patch_size` or `merge_size`.")

    # Side length, in pixels, of the square that one merged visual token covers.
    side = patch * merge
    image_processor.max_pixels = token_cap * side * side
    image_processor.size["longest_edge"] = image_processor.max_pixels
    return processor


def process_images(
    self,
    images: List[Image.Image],
) -> Union[BatchFeature, BatchEncoding]:
    """
    Build a padded batch from a list of images.

    Each image is converted to RGB and paired with `visual_prompt_prefix`. The flat
    `pixel_values` returned by the processor call are then split per image (using the
    product of each `image_grid_thw` row as the row count) and zero-padded into one
    batch-first tensor.
    """
    rgb_images = [picture.convert("RGB") for picture in images]
    prompts = [self.visual_prompt_prefix] * len(rgb_images)

    batch = self(
        text=prompts,
        images=rgb_images,
        padding="longest",
        return_tensors="pt",
    )

    rows_per_image = batch["image_grid_thw"].prod(dim=1).tolist()
    per_image_rows = list(torch.split(batch["pixel_values"], rows_per_image))
    batch["pixel_values"] = torch.nn.utils.rnn.pad_sequence(per_image_rows, batch_first=True)
    return batch


def process_audios(self, audios) -> Union[BatchFeature, BatchEncoding]:
    """Build a padded batch from audio inputs, pairing each one with `audio_prompt_prefix`."""
    prompts = [self.audio_prompt_prefix] * len(audios)
    return self(
        text=prompts,
        audio=audios,
        padding="longest",
        return_tensors="pt",
    )


def process_videos(self, videos) -> Union[BatchFeature, BatchEncoding]:
    """
    Build a padded batch from video inputs.

    Each video gets a single-turn user conversation rendered through
    `apply_chat_template`. The flat `pixel_values_videos` are split per video (using
    the product of each `video_grid_thw` row) and zero-padded into one batch-first tensor.
    """
    prompts = []
    for clip in videos:
        user_turn = {
            "role": "user",
            "content": [
                {"type": "video", "video": clip},
                {"type": "text", "text": "Describe the video."},
            ],
        }
        prompts.append(self.apply_chat_template([user_turn], tokenize=False, add_generation_prompt=False))

    batch = self(
        text=prompts,
        videos=videos,
        padding="longest",
        return_tensors="pt",
    )

    rows_per_video = batch["video_grid_thw"].prod(dim=1).tolist()
    per_video_rows = list(torch.split(batch["pixel_values_videos"], rows_per_video))
    batch["pixel_values_videos"] = torch.nn.utils.rnn.pad_sequence(per_video_rows, batch_first=True)
    return batch


def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
    """Tokenize raw strings into a batch padded to the longest entry."""
    encoded = self(
        text=texts,
        return_tensors="pt",
        padding="longest",
    )
    return encoded


def score(
    self,
    qs: List[torch.Tensor],
    ps: List[torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
) -> torch.Tensor:
    """Score queries against passages by delegating to `score_multi_vector`."""
    similarity = self.score_multi_vector(qs, ps, device=device, **kwargs)
    return similarity


def get_n_patches(
    self,
    image_size: Tuple[int, int],
    spatial_merge_size: int,
) -> Tuple[int, int]:
    """
    Return `(n_patches_x, n_patches_y)` for an image of the given size.

    `image_size[0]` is passed to `smart_resize` as the width and `image_size[1]` as the
    height. The resized dimensions are floor-divided by the patch size and then by
    `spatial_merge_size`.
    """
    image_processor = self.image_processor
    patch = image_processor.patch_size
    pixel_bounds = image_processor.size

    resized_height, resized_width = smart_resize(
        width=image_size[0],
        height=image_size[1],
        factor=patch * image_processor.merge_size,
        min_pixels=pixel_bounds["shortest_edge"],
        max_pixels=pixel_bounds["longest_edge"],
    )

    patches_across = resized_width // patch // spatial_merge_size
    patches_down = resized_height // patch // spatial_merge_size
    return patches_across, patches_down


def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
    """Return a boolean tensor marking the positions where `input_ids` equals `image_token_id`."""
    token_ids = batch_images.input_ids
    return token_ids == self.image_token_id
independently # Note: Surely breaks in distributed training @@ -214,10 +213,16 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True) with torch.no_grad(): # feed only kwargs with 'doc_' prefix - doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) - if "neg_doc_input_ids" in inputs: - neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) + doc_outputs = model( + **{k[len(self.pos_prefix) :]: v for k, v in inputs.items() if k.startswith(self.pos_prefix)} + ) + query_outputs = model( + **{k[len(self.query_prefix) :]: v for k, v in inputs.items() if k.startswith(self.query_prefix)} + ) + if f"{self.neg_prefix}input_ids" in inputs: + neg_doc_outputs = model( + **{k[len(self.neg_prefix) :]: v for k, v in inputs.items() if k.startswith(self.neg_prefix)} + ) loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) return loss, None, None diff --git a/pyproject.toml b/pyproject.toml index c6b6681a6..04c67fb94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ dependencies = [ "numpy", - "peft>=0.18.0,<0.20.0", + "peft>=0.19.1,<0.20.0", "pillow>=10.0.0", "requests", "scipy", diff --git a/scripts/configs/qwen3/train_colqwen3omni_model.py b/scripts/configs/qwen3/train_colqwen3omni_model.py new file mode 100644 index 000000000..d2825d6c5 --- /dev/null +++ b/scripts/configs/qwen3/train_colqwen3omni_model.py @@ -0,0 +1,99 @@ +import argparse +import shutil +from pathlib import Path + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import TrainingArguments + +from colpali_engine.data.dataset import ColPaliEngineDataset +from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss +from colpali_engine.models import ColQwen3Omni, ColQwen3OmniProcessor +from 
def parse_args():
    """Declare this training script's command-line options and parse `sys.argv`."""
    # (flag, add_argument keyword arguments), kept in the order shown by --help.
    option_specs = (
        ("--output-dir", {"type": str, "required": True, "help": "where to write model + script copy"}),
        ("--lr", {"type": float, "default": 2e-4, "help": "learning rate"}),
        ("--tau", {"type": float, "default": 0.02, "help": "temperature for loss function"}),
        ("--trainer", {"type": str, "default": "hf", "choices": ["torch", "hf"], "help": "trainer to use"}),
        ("--loss", {"type": str, "default": "ce", "choices": ["ce", "pairwise"], "help": "loss function to use"}),
        ("--peft", {"action": "store_true", "help": "use PEFT for training"}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Choose the contrastive objective requested on the command line.
    if args.loss == "pairwise":
        loss_func = ColbertPairwiseCELoss(
            normalize_scores=False,
        )
    elif args.loss == "ce":
        loss_func = ColbertLoss(
            temperature=args.tau,
            normalize_scores=True,
            use_smooth_max=False,
            pos_aware_negative_filtering=False,
        )
    else:
        raise ValueError(f"Unknown loss function: {args.loss}")

    # Build each component in the order the config consumes it.
    processor = ColQwen3OmniProcessor.from_pretrained(
        pretrained_model_name_or_path="manu/colqwen3omni-base",
        max_num_visual_tokens=1024,
    )
    model = ColQwen3Omni.from_pretrained(
        pretrained_model_name_or_path="manu/colqwen3omni-base",
        torch_dtype=torch.bfloat16,
        use_cache=False,
        attn_implementation="flash_attention_2",
    )
    train_dataset = load_train_set()
    eval_dataset = ColPaliEngineDataset(
        load_dataset("vidore/colpali_train_set", split="test"), pos_target_column_name="image"
    )
    training_arguments = TrainingArguments(
        output_dir=None,
        overwrite_output_dir=True,
        num_train_epochs=5,
        per_device_train_batch_size=64,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        per_device_eval_batch_size=16,
        eval_strategy="steps",
        dataloader_num_workers=8,
        save_steps=500,
        logging_steps=10,
        eval_steps=100,
        warmup_steps=100,
        learning_rate=args.lr,
        save_total_limit=1,
    )

    # A LoRA adapter config is only built when --peft is passed.
    lora_config = None
    if args.peft:
        lora_config = LoraConfig(
            r=32,
            lora_alpha=32,
            lora_dropout=0.1,
            init_lora_weights="gaussian",
            bias="none",
            task_type="FEATURE_EXTRACTION",
            target_modules="(.*language_model.*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*custom_text_proj.*$)",
        )

    config = ColModelTrainingConfig(
        output_dir=args.output_dir,
        processor=processor,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        run_eval=True,
        loss_func=loss_func,
        tr_args=training_arguments,
        peft_config=lora_config,
    )

    # Keep a copy of this script next to the run's outputs.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    this_script = Path(__file__)
    shutil.copy(this_script, output_dir / this_script.name)

    if args.trainer == "hf":
        trainer = ColModelTraining(config)
    else:
        trainer = ColModelTorchTraining(config)
    trainer.train()
    trainer.save()
def synchronize(device: torch.device) -> None:
    """Wait for queued CUDA work on *device* to finish; do nothing for other device types."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def benchmark_forward(
    model: ColModernVBert,
    inputs: dict[str, torch.Tensor],
    device: torch.device,
    *,
    iterations: int,
    warmup: int,
) -> tuple[torch.Tensor, list[float]]:
    """
    Time repeated forward passes of *model* on *inputs*.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Keyword arguments for the forward pass.
        device: Device to synchronize on around each pass.
        iterations: Number of timed passes.
        warmup: Number of untimed passes run first.

    Returns:
        The output of the last forward pass and the per-iteration wall-clock
        durations in seconds (one entry per timed pass).

    Raises:
        RuntimeError: If neither a warmup nor a timed pass was executed.
    """
    output = None
    timings: list[float] = []

    with torch.inference_mode():
        for _ in range(warmup):
            output = model(**inputs)
            synchronize(device)

        # Drain anything still queued on the device before starting the clock, so
        # earlier work is not billed to the first timed pass when warmup == 0.
        synchronize(device)

        for _ in range(iterations):
            start = perf_counter()
            output = model(**inputs)
            # Wait for the pass to finish on the device before stopping the clock.
            synchronize(device)
            timings.append(perf_counter() - start)

    if output is None:
        raise RuntimeError("Benchmark did not run; use at least one warmup or iteration.")
    return output, timings


def print_stats(label: str, timings: list[float]) -> None:
    """
    Print mean, median, min and max of *timings* (seconds) in milliseconds.

    An empty *timings* list prints a short notice instead of raising from the
    statistics helpers.
    """
    if not timings:
        print(f"{label}: no timings recorded")
        return

    timings_ms = [elapsed * 1_000 for elapsed in timings]
    print(
        f"{label}: "
        f"mean={statistics.mean(timings_ms):.2f} ms, "
        f"median={statistics.median(timings_ms):.2f} ms, "
        f"min={min(timings_ms):.2f} ms, "
        f"max={max(timings_ms):.2f} ms"
    )


def main() -> None:
    """Benchmark text and image embedding latency of ColModernVBERT and print a report."""
    parser = argparse.ArgumentParser(description="Benchmark ColModernVBERT text and image embedding latency.")
    parser.add_argument("--iterations", type=int, default=50, help="Number of timed forward passes per input type.")
    parser.add_argument("--warmup", type=int, default=5, help="Number of untimed warmup forward passes per input type.")
    args = parser.parse_args()

    # Report bad values through argparse (usage message + exit status 2) rather
    # than an uncaught exception traceback.
    if args.iterations < 1:
        parser.error("--iterations must be at least 1")
    if args.warmup < 0:
        parser.error("--warmup must be non-negative")

    model_id = "ModernVBERT/colmodernvbert"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor = ColModernVBertProcessor.from_pretrained(model_id)
    model = ColModernVBert.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(device)
    model.eval()

    image = Image.open(
        hf_hub_download(
            "HuggingFaceTB/SmolVLM",
            "example_images/rococo.jpg",
            repo_type="space",
        )
    )
    text = "This is a text"

    text_inputs = processor.process_texts([text]).to(device)
    image_inputs = processor.process_images([image]).to(device)

    q_embeddings, text_timings = benchmark_forward(
        model,
        text_inputs,
        device,
        iterations=args.iterations,
        warmup=args.warmup,
    )
    corpus_embeddings, image_timings = benchmark_forward(
        model,
        image_inputs,
        device,
        iterations=args.iterations,
        warmup=args.warmup,
    )

    scores = processor.score(q_embeddings, corpus_embeddings)

    print(f"Device: {device}")
    print(f"Iterations: {args.iterations} timed, {args.warmup} warmup")
    print(f"Query embeddings shape: {tuple(q_embeddings.shape)}")
    print(f"Image embeddings shape: {tuple(corpus_embeddings.shape)}")
    print("Similarity scores:", scores)
    print_stats("Text embedding", text_timings)
    print_stats("Image embedding", image_timings)


if __name__ == "__main__":
    main()
# Checkpoint identifier passed to both `from_pretrained` calls below.
MODEL_NAME = "BidirLM/BidirLM-Omni-2.5B-Embedding"


def main() -> None:
    """Load the processor and model for `MODEL_NAME` on CPU and print a short summary."""
    load_options = {"device_map": "cpu", "torch_dtype": "auto"}

    loaded_processor = ColQwen3OmniProcessor.from_pretrained(MODEL_NAME)
    loaded_model = ColQwen3Omni.from_pretrained(MODEL_NAME, **load_options)

    processor_class = type(loaded_processor).__name__
    print(f"Processor: {processor_class}")
    model_class = type(loaded_model).__name__
    print(f"Model: {model_class}")
    print(f"Model type: {loaded_model.config.model_type}")
    print(f"Device: {loaded_model.device}")


if __name__ == "__main__":
    main()