diff --git a/src/gemini_webapi/client.py b/src/gemini_webapi/client.py index 7e86abe..5a56e76 100644 --- a/src/gemini_webapi/client.py +++ b/src/gemini_webapi/client.py @@ -426,7 +426,9 @@ async def _fetch_user_status(self) -> None: tier_flags, capability_flags ) - id_name_mapping = AvailableModel.build_model_id_name_mapping() + id_name_mapping = AvailableModel.build_model_id_name_mapping( + capacity, capacity_field + ) for model_data in models_list: if isinstance(model_data, list): diff --git a/src/gemini_webapi/types/availablemodel.py b/src/gemini_webapi/types/availablemodel.py index 7cd01d1..6084e78 100644 --- a/src/gemini_webapi/types/availablemodel.py +++ b/src/gemini_webapi/types/availablemodel.py @@ -110,33 +110,71 @@ def compute_capacity(tier_flags: list, capability_flags: list) -> tuple[int, int return 1, 12 # Free accounts @staticmethod - def build_model_id_name_mapping() -> dict[str, str]: + def build_model_id_name_mapping( + capacity: int = 1, + capacity_field: int = 12, + ) -> dict[str, str]: """ - Build a mapping from `model_id` to `model_name` for all registered models. + Build a mapping from `model_id` to `model_name` for all registered models, + picking the canonical name that matches the caller's account tier. - This uses the :class:`Model` enum to resolve hex identifiers to their - canonical names (e.g., "gemini-3-pro"). + PLUS and ADVANCED tiers share `model_id` values with each other + (differentiated only by the `capacity` header), so the target tier is + chosen by the supplied `(capacity, capacity_field)` and the enum is + walked in tier-priority order. Any `model_id` not found in the primary + tier falls through to the next tier so the mapping stays complete. + + Parameters + ---------- + capacity : `int`, optional + Account capacity as returned by :meth:`compute_capacity`. Defaults + to ``1`` (free tier) for backwards compatibility with callers that + do not yet pass tier info. + capacity_field : `int`, optional + Account capacity proto field. Defaults to ``12``. + + Returns + ------- + `dict[str, str]` + Mapping of internal hex `model_id` to canonical `model_name` + (e.g. ``"gemini-3-pro-plus"`` for a Plus-tier account). """ + # Tier priority order — which `Model` family's name we prefer for a + # given (capacity, capacity_field). The primary tier matches the enum + # family whose header was built with this same capacity value, so the + # returned name is consistent with the header the account actually sends. + # Fallback tiers fill any `model_id` the primary does not cover (e.g. + # the `BASIC_*` ids that PLUS/ADVANCED do not share). + if capacity == 4 and capacity_field == 12: + tier_order = ("PLUS", "ADVANCED", "BASIC") + elif capacity == 2 and capacity_field in (12, 13): + tier_order = ("ADVANCED", "PLUS", "BASIC") + elif capacity == 1 and capacity_field == 13: + tier_order = ("PLUS", "ADVANCED", "BASIC") + else: + # capacity=1/field=12 — free tier — keeps the existing behaviour. + tier_order = ("BASIC", "PLUS", "ADVANCED") + result: dict[str, str] = {} - for member in Model: - if member is Model.UNSPECIFIED: - continue - - header_value = member.model_header.get(MODEL_HEADER_KEY, "") - if not header_value: - continue - - try: - parsed = json.loads(header_value) - model_id = get_nested_value(parsed, [4]) - except json.JSONDecodeError: - continue - - if model_id and model_id not in result: - # Use basic model name without tier suffix regardless of the actual tier - base_key = "BASIC_" + member.name.split("_", 1)[-1] - base_member = getattr(Model, base_key, member) - result[model_id] = base_member.model_name + for tier_prefix in tier_order: + for member in Model: + if member is Model.UNSPECIFIED: + continue + if not member.name.startswith(f"{tier_prefix}_"): + continue + + header_value = member.model_header.get(MODEL_HEADER_KEY, "") + if not header_value: + continue + + try: + parsed = json.loads(header_value) + model_id = get_nested_value(parsed, [4]) + except json.JSONDecodeError: + continue + + if model_id and model_id not in result: + result[model_id] = member.model_name return result diff --git a/tests/test_available_model.py b/tests/test_available_model.py new file mode 100644 index 0000000..82204f3 --- /dev/null +++ b/tests/test_available_model.py @@ -0,0 +1,81 @@ +import unittest + +from gemini_webapi.constants import Model, MODEL_HEADER_KEY +from gemini_webapi.types import AvailableModel +from gemini_webapi.utils import get_nested_value + +import orjson as json + + +def _id_for(member: Model) -> str: + header = member.model_header.get(MODEL_HEADER_KEY, "") + return get_nested_value(json.loads(header), [4], "") + + +class TestBuildModelIdNameMapping(unittest.TestCase): + """The mapping must return names whose tier matches the account's capacity, + because the caller uses the returned name to construct headers, and using a + wrong-tier name produces requests Google may reject or silently re-tier.""" + + def test_free_tier_primary_ids_resolve_to_basic_names(self): + mapping = AvailableModel.build_model_id_name_mapping( + capacity=1, capacity_field=12 + ) + self.assertEqual(mapping[_id_for(Model.BASIC_PRO)], "gemini-3-pro") + self.assertEqual(mapping[_id_for(Model.BASIC_FLASH)], "gemini-3-flash") + self.assertEqual( + mapping[_id_for(Model.BASIC_THINKING)], "gemini-3-flash-thinking" + ) + + def test_plus_tier_primary_ids_resolve_to_plus_names(self): + mapping = AvailableModel.build_model_id_name_mapping( + capacity=4, capacity_field=12 + ) + self.assertEqual(mapping[_id_for(Model.PLUS_PRO)], "gemini-3-pro-plus") + self.assertEqual(mapping[_id_for(Model.PLUS_FLASH)], "gemini-3-flash-plus") + self.assertEqual( + mapping[_id_for(Model.PLUS_THINKING)], "gemini-3-flash-thinking-plus" + ) + + def test_advanced_tier_primary_ids_resolve_to_advanced_names(self): + # Capacity 2 (field 12 or 13) is the "Advanced" capability — its + # model_ids happen to be shared with the Plus tier but the account's + # header uses capacity=2, so names must reflect that. + for field in (12, 13): + with self.subTest(capacity_field=field): + mapping = AvailableModel.build_model_id_name_mapping( + capacity=2, capacity_field=field + ) + self.assertEqual( + mapping[_id_for(Model.ADVANCED_PRO)], "gemini-3-pro-advanced" + ) + self.assertEqual( + mapping[_id_for(Model.ADVANCED_FLASH)], "gemini-3-flash-advanced" + ) + self.assertEqual( + mapping[_id_for(Model.ADVANCED_THINKING)], + "gemini-3-flash-thinking-advanced", + ) + + def test_basic_only_ids_still_resolve_on_higher_tiers(self): + # BASIC_* have unique model_ids — if Google surfaces them to a Plus + # account (defensive case), the mapping must still cover them. + for capacity in (2, 4): + with self.subTest(capacity=capacity): + mapping = AvailableModel.build_model_id_name_mapping( + capacity=capacity, capacity_field=12 + ) + self.assertIn(_id_for(Model.BASIC_PRO), mapping) + self.assertIn(_id_for(Model.BASIC_FLASH), mapping) + self.assertIn(_id_for(Model.BASIC_THINKING), mapping) + + def test_default_args_preserve_legacy_basic_mapping(self): + default = AvailableModel.build_model_id_name_mapping() + explicit = AvailableModel.build_model_id_name_mapping( + capacity=1, capacity_field=12 + ) + self.assertEqual(default, explicit) + + +if __name__ == "__main__": + unittest.main()