-
Notifications
You must be signed in to change notification settings - Fork 90
feat(text-metrics): split text_score pair #647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feat/vlm-pr-3b-oneig-alignment
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,372 @@ | ||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``). | ||
|
|
||
| OneIG composite: ``oneig_text_score`` / ``ocr_text_score``. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import abstractmethod | ||
| from typing import Any, Literal | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from pruna.engine.utils import set_to_best_available_device | ||
| from pruna.evaluation.metrics.metric_stateful import StatefulMetric | ||
| from pruna.evaluation.metrics.metric_text_score_utils import ( | ||
| levenshtein, | ||
| normalize_text_simple, | ||
| oneig_mean_text_score, | ||
| oneig_per_sample_contributions, | ||
| ) | ||
| from pruna.evaluation.metrics.registry import MetricRegistry | ||
| from pruna.evaluation.metrics.result import MetricResult | ||
| from pruna.evaluation.metrics.utils import ( | ||
| SINGLE, | ||
| get_call_type_for_single_metric, | ||
| metric_data_processor, | ||
| ) | ||
| from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm | ||
| from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response | ||
|
|
||
| # Instruction sent to the VLM together with every image in ``_BaseVLMOCRTextMetric.update``. | ||
| # It asks for a verbatim transcription (misspellings kept, no preamble or markdown) and for the | ||
| # fixed reply "No text recognized" when the image contains no text. | ||
| # NOTE(review): whether that fixed reply is mapped to an empty string downstream is not visible here - confirm. | ||
| OCR_PROMPT = ( | ||
| "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, " | ||
| "and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, " | ||
| "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. " | ||
| "IMPORTANT: Do NOT correct spelling errors or typos. If a word is misspelled in the image " | ||
| "(e.g. 'Teclhology' instead of 'Technology'), reproduce it exactly as it appears, including the misspelling. " | ||
| "If no text is recognized, reply with exactly: No text recognized" | ||
| ) | ||
|
|
||
|
|
||
| class _BaseVLMOCRTextMetric(StatefulMetric): | ||
| """ | ||
| Shared VLM OCR over rendered images with ground truth in ``text_content``. | ||
|
|
||
| Subclasses implement how OCR and GT strings are scored and aggregated. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| *args : Any | ||
| Additional positional arguments (unused; registry compatibility). | ||
| vlm : BaseVLM | None, optional | ||
| Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. | ||
| vlm_type : {'litellm', 'transformers'}, optional | ||
| VLM backend. Default is ``'litellm'``. | ||
| model_name : str | None, optional | ||
| Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not | ||
| provided (e.g. ``openai/gpt-4o``). | ||
| vlm_kwargs : dict, optional | ||
| Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, | ||
| set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. | ||
| structured_output : bool, optional | ||
| Use structured generation (litellm pydantic; transformers outlines when applicable). | ||
| Default is True. | ||
| device : str | torch.device | None, optional | ||
| Device for transformers VLM. | ||
| api_key : str | None, optional | ||
| API key for litellm. | ||
| call_type : str, optional | ||
| Call type for the metric. | ||
| **kwargs : Any | ||
| Additional arguments. | ||
|
|
||
| Examples | ||
| -------- | ||
| OCR metrics call ``get_vlm`` directly (not ``StatefulVLMMeanScoresMetric``). Same | ||
| ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import torch | ||
|
|
||
| from pruna.evaluation.metrics import TextScoreMetric | ||
|
|
||
| hosted = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o") | ||
| local = TextScoreMetric( | ||
| vlm_type="transformers", | ||
| model_name="HuggingFaceTB/SmolVLM-256M-Instruct", | ||
| device="cpu", | ||
| vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, | ||
| ) | ||
|
|
||
| Use ``OneIGTextScoreMetric`` the same way for ``oneig_text_score`` / ``ocr_text_score``. | ||
| """ | ||
|
|
||
| default_call_type: str = "y_gt" | ||
|
|
||
def __init__(
    self,
    *args: Any,
    vlm: BaseVLM | None = None,
    vlm_type: Literal["litellm", "transformers"] = "litellm",
    model_name: str | None = None,
    vlm_kwargs: dict | None = None,
    structured_output: bool = True,
    device: str | torch.device | None = None,
    api_key: str | None = None,
    call_type: str = SINGLE,
    **kwargs: Any,
) -> None:
    # ``*args`` and ``**kwargs`` are accepted for signature compatibility but not used here.
    super().__init__(device=device)
    self.device = set_to_best_available_device(device)

    # An omitted (or empty) ``vlm_kwargs`` forwards no extra backend options.
    backend_options = vlm_kwargs or {}
    self.vlm = get_vlm(
        vlm=vlm,
        vlm_type=vlm_type,
        model_name=model_name,
        device=device,
        api_key=api_key,
        structured_output=structured_output,
        **backend_options,
    )

    # ``TextOutput`` is passed as the response format only when structured output is requested.
    if structured_output:
        self.response_format = TextOutput
    else:
        self.response_format = None

    self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
    # Copy the class-level flag defined by the concrete subclass onto the instance.
    self.higher_is_better = type(self).higher_is_better
|
|
||
@abstractmethod
def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
    """
    Fold one ground-truth / OCR pair into the metric state.

    Parameters
    ----------
    text_gt : str
        Ground-truth text for the sample.
    ocr_text : str
        Text extracted from the rendered image by the VLM.
    """
|
|
||
@abstractmethod
def _compute_result_value(self) -> float:
    """
    Produce the scalar that ``compute`` reports as ``MetricResult.result``.

    Returns
    -------
    float
        Aggregated score over everything accumulated so far.
    """
|
|
||
| def update(self, x: list[Any] | torch.Tensor, gt: list[str], outputs: torch.Tensor) -> None: | ||

Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I really like that call to VLM for text metrics now live in a shared update function! |
||
| """ | ||
| Run OCR on outputs and score against ``text_content`` (or string list) auxiliaries. | ||
|
|
||
| Each image is sent to the VLM on its own with ``OCR_PROMPT``; the extracted text and the | ||
| sample's ground truth are then handed to ``_accumulate_sample``. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| x : List[Any] | torch.Tensor | ||
| Batch prompts or metadata. | ||
| gt : list of dict or list of str | ||
| Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with | ||
| newlines), or plain strings per batch item. | ||
| outputs : torch.Tensor | ||
| Rendered images. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| If a sample has no ground-truth text, i.e. its auxiliary is neither a string nor a | ||
| dict with a non-``None`` ``'text_content'`` entry. | ||
| """ | ||
| inputs = metric_data_processor(x, gt, outputs, self.call_type) | ||
| images = _process_images(inputs[0]) | ||
| # Fallback when no list/tuple of auxiliaries is available: one empty dict per image, | ||
| # so the first sample in the loop below raises ``ValueError``. | ||
| auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images) | ||

Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for the call type assigned, i think the len(inputs) check would always be longer than 1! |
||
| # One VLM request per image. NOTE(review): the ground truth is only validated after the | ||
| # OCR call returns, so a sample with missing ``text_content`` still costs one VLM request. | ||
| for i, image in enumerate(images): | ||
| responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format) | ||
| # An empty response list falls back to an empty string before text extraction. | ||
| raw = responses[0] if responses else "" | ||
| ocr_text = get_text_from_response(raw) | ||
| aux = auxiliaries[i] if i < len(auxiliaries) else {} | ||
| text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None) | ||
| if isinstance(text_gt, list): | ||
| # The generator variable ``x`` is local to the expression; it does not rebind the ``x`` argument. | ||
| text_gt = "\n".join(str(x) for x in text_gt) | ||
| if text_gt is None: | ||
| raise ValueError( | ||
| f"{self.metric_name} requires 'text_content' in auxiliaries. " | ||
| "Use a benchmark that provides it (e.g. LongTextBench, OneIG)." | ||
| ) | ||
| self._accumulate_sample(text_gt, ocr_text) | ||
|
|
||
def compute(self) -> MetricResult:
    """
    Reduce the accumulated per-sample state to one reported score.

    Returns
    -------
    MetricResult
        Result named ``metric_name`` that carries this instance's ``__dict__`` and the
        subclass-provided scalar converted to ``float``.
    """
    aggregated = float(self._compute_result_value())
    return MetricResult(self.metric_name, self.__dict__, aggregated)
|
|
||
|
|
||
| @MetricRegistry.register("ocr_levenshtein") | ||
| @MetricRegistry.register("text_score") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A bit of a stupid question, but why do we need this metric? It implements the edit score, which already exists in the OneIG text score. |
||
class TextScoreMetric(_BaseVLMOCRTextMetric):
    """
    Mean normalized character accuracy between OCR output and ground truth (higher is better).

    Each sample contributes ``1 - min(1, edit_distance / max(len(gt), 1))`` computed on lightly
    normalized strings, so every per-sample score lies in [0, 1].

    Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy).

    Only light normalization is applied (not the full OneIG preprocess). See
    :class:`OneIGTextScoreMetric` for the OneIG composite ``ocr_text_score``.

    Parameters
    ----------
    *args : Any
        Additional positional arguments (unused; registry compatibility).
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {'litellm', 'transformers'}, optional
        VLM backend. Default is ``'litellm'``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
    """

    scores: list[float]
    higher_is_better: bool = True
    metric_name: str = "text_score"

    def __init__(
        self,
        *args: Any,
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict[str, Any] | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        # Everything VLM-related is handled by the shared base class.
        base_options: dict[str, Any] = {
            "vlm": vlm,
            "vlm_type": vlm_type,
            "model_name": model_name,
            "vlm_kwargs": vlm_kwargs,
            "structured_output": structured_output,
            "device": device,
            "api_key": api_key,
            "call_type": call_type,
        }
        super().__init__(*args, **base_options, **kwargs)
        # One accuracy value per processed sample.
        self.add_state("scores", [])

    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        """Append the normalized character accuracy of one ground-truth / OCR pair."""
        reference = normalize_text_simple(text_gt)
        hypothesis = normalize_text_simple(ocr_text)
        distance = levenshtein(hypothesis, reference)
        # Guard against an empty reference so the division is always defined.
        error_rate = distance / max(float(len(reference)), 1.0)
        # Clamp: an OCR string much longer than the reference cannot push the score below 0.
        accuracy = 1.0 - min(1.0, error_rate)
        self.scores.append(accuracy)

    def _compute_result_value(self) -> float:
        """Return the mean per-sample accuracy, or 0.0 when nothing was accumulated."""
        return float(np.mean(self.scores)) if self.scores else 0.0
|
|
||
|
|
||
@MetricRegistry.register("ocr_text_score")
@MetricRegistry.register("oneig_text_score")
class OneIGTextScoreMetric(_BaseVLMOCRTextMetric):
    """
    OCR then OneIG-style composite text score (higher is better).

    Registry: ``ocr_text_score`` (descriptive) and ``oneig_text_score`` (protocol).

    Aggregates edit distance, completion rate, and word/char accuracy like
    ``OneIG-Benchmark/scripts/text/text_score.py``. Per-sample contributions are collected
    during ``update`` and combined once in ``compute``.

    Parameters
    ----------
    *args : Any
        Additional positional arguments (forwarded to :class:`_BaseVLMOCRTextMetric`).
    language_mode : {'EN', 'ZH'}, optional
        Selects ``MAX_EDIT_DISTANCE`` (100 vs 50) for the composite.
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {'litellm', 'transformers'}, optional
        VLM backend. Default is ``'litellm'``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
    """

    edit_distances: list[float]
    completion_ratios: list[float]
    match_counts: list[int]
    gt_totals: list[int]

    higher_is_better: bool = True
    metric_name: str = "oneig_text_score"

    def __init__(
        self,
        *args: Any,
        language_mode: Literal["EN", "ZH"] = "EN",
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict[str, Any] | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            *args,
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            vlm_kwargs=vlm_kwargs,
            structured_output=structured_output,
            device=device,
            api_key=api_key,
            call_type=call_type,
            **kwargs,
        )
        self.language_mode = language_mode
        # Four parallel per-sample lists; index ``i`` in each belongs to the same sample.
        self.add_state("edit_distances", [])
        self.add_state("completion_ratios", [])
        self.add_state("match_counts", [])
        self.add_state("gt_totals", [])

    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        """Record the four OneIG per-sample contributions for one ground-truth / OCR pair."""
        ed, cr, mcount, gtot = oneig_per_sample_contributions(text_gt, ocr_text)
        self.edit_distances.append(ed)
        self.completion_ratios.append(cr)
        self.match_counts.append(mcount)
        self.gt_totals.append(gtot)

    def _compute_result_value(self) -> float:
        """
        Combine the accumulated contributions into the composite text score.

        Returns
        -------
        float
            Composite score (last element returned by ``oneig_mean_text_score``), or 0.0
            when no sample has been accumulated yet.
        """
        # Same empty-state convention as ``TextScoreMetric``: report 0.0 instead of handing
        # empty lists to the aggregation helper.
        if not self.edit_distances:
            return 0.0
        *_, text_score = oneig_mean_text_score(
            self.edit_distances,
            self.completion_ratios,
            self.match_counts,
            self.gt_totals,
            self.language_mode,
        )
        return text_score
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see now where we are using this metric. I am not so sure that LongText-Bench uses an edit-distance-based metric, though, if I am not wrong 🥹 — perhaps we should remove this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/X-Omni-Team/X-Omni/blob/main/textbench/summary_scores.py They are using a word-accuracy metric rather than a character-accuracy one (like the edit distance).