Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/gemini_webapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ async def _fetch_user_status(self) -> None:
tier_flags, capability_flags
)

id_name_mapping = AvailableModel.build_model_id_name_mapping()
id_name_mapping = AvailableModel.build_model_id_name_mapping(
capacity, capacity_field
)

for model_data in models_list:
if isinstance(model_data, list):
Expand Down
84 changes: 61 additions & 23 deletions src/gemini_webapi/types/availablemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,33 +110,71 @@ def compute_capacity(tier_flags: list, capability_flags: list) -> tuple[int, int
return 1, 12 # Free accounts

@staticmethod
def build_model_id_name_mapping() -> dict[str, str]:
def build_model_id_name_mapping(
capacity: int = 1,
capacity_field: int = 12,
) -> dict[str, str]:
"""
Build a mapping from `model_id` to `model_name` for all registered models.
Build a mapping from `model_id` to `model_name` for all registered models,
picking the canonical name that matches the caller's account tier.

This uses the :class:`Model` enum to resolve hex identifiers to their
canonical names (e.g., "gemini-3-pro").
PLUS and ADVANCED tiers share `model_id` values with each other
(differentiated only by the `capacity` header), so the target tier is
chosen by the supplied `(capacity, capacity_field)` and the enum is
walked in tier-priority order. Any `model_id` not found in the primary
tier falls through to the next tier so the mapping stays complete.

Parameters
----------
capacity : `int`, optional
Account capacity as returned by :meth:`compute_capacity`. Defaults
to ``1`` (free tier) for backwards compatibility with callers that
do not yet pass tier info.
capacity_field : `int`, optional
Account capacity proto field. Defaults to ``12``.

Returns
-------
`dict[str, str]`
Mapping of internal hex `model_id` to canonical `model_name`
(e.g. ``"gemini-3-pro-plus"`` for a Plus-tier account).
"""

# Tier priority order — which `Model` family's name we prefer for a
# given (capacity, capacity_field). The primary tier matches the enum
# family whose header was built with this same capacity value, so the
# returned name is consistent with the header the account actually sends.
# Fallback tiers fill any `model_id` the primary does not cover (e.g.
# the `BASIC_*` ids that PLUS/ADVANCED do not share).
if capacity == 4 and capacity_field == 12:
tier_order = ("PLUS", "ADVANCED", "BASIC")
elif capacity == 2 and capacity_field in (12, 13):
tier_order = ("ADVANCED", "PLUS", "BASIC")
elif capacity == 1 and capacity_field == 13:
tier_order = ("PLUS", "ADVANCED", "BASIC")
else:
# capacity=1/field=12 — free tier — keeps the existing behaviour.
tier_order = ("BASIC", "PLUS", "ADVANCED")

result: dict[str, str] = {}
for member in Model:
if member is Model.UNSPECIFIED:
continue

header_value = member.model_header.get(MODEL_HEADER_KEY, "")
if not header_value:
continue

try:
parsed = json.loads(header_value)
model_id = get_nested_value(parsed, [4])
except json.JSONDecodeError:
continue

if model_id and model_id not in result:
# Use basic model name without tier suffix regardless of the actual tier
base_key = "BASIC_" + member.name.split("_", 1)[-1]
base_member = getattr(Model, base_key, member)
result[model_id] = base_member.model_name
for tier_prefix in tier_order:
for member in Model:
if member is Model.UNSPECIFIED:
continue
if not member.name.startswith(f"{tier_prefix}_"):
continue

header_value = member.model_header.get(MODEL_HEADER_KEY, "")
if not header_value:
continue

try:
parsed = json.loads(header_value)
model_id = get_nested_value(parsed, [4])
except json.JSONDecodeError:
continue

if model_id and model_id not in result:
result[model_id] = member.model_name

return result
81 changes: 81 additions & 0 deletions tests/test_available_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import unittest

from gemini_webapi.constants import Model, MODEL_HEADER_KEY
from gemini_webapi.types import AvailableModel
from gemini_webapi.utils import get_nested_value

import orjson as json


def _id_for(member: Model) -> str:
header = member.model_header.get(MODEL_HEADER_KEY, "")
return get_nested_value(json.loads(header), [4], "")


class TestBuildModelIdNameMapping(unittest.TestCase):
"""The mapping must return names whose tier matches the account's capacity,
because the caller uses the returned name to construct headers, and using a
wrong-tier name produces requests Google may reject or silently re-tier."""

def test_free_tier_primary_ids_resolve_to_basic_names(self):
mapping = AvailableModel.build_model_id_name_mapping(
capacity=1, capacity_field=12
)
self.assertEqual(mapping[_id_for(Model.BASIC_PRO)], "gemini-3-pro")
self.assertEqual(mapping[_id_for(Model.BASIC_FLASH)], "gemini-3-flash")
self.assertEqual(
mapping[_id_for(Model.BASIC_THINKING)], "gemini-3-flash-thinking"
)

def test_plus_tier_primary_ids_resolve_to_plus_names(self):
mapping = AvailableModel.build_model_id_name_mapping(
capacity=4, capacity_field=12
)
self.assertEqual(mapping[_id_for(Model.PLUS_PRO)], "gemini-3-pro-plus")
self.assertEqual(mapping[_id_for(Model.PLUS_FLASH)], "gemini-3-flash-plus")
self.assertEqual(
mapping[_id_for(Model.PLUS_THINKING)], "gemini-3-flash-thinking-plus"
)

def test_advanced_tier_primary_ids_resolve_to_advanced_names(self):
# Capacity 2 (field 12 or 13) is the "Advanced" capability — its
# model_ids happen to be shared with the Plus tier but the account's
# header uses capacity=2, so names must reflect that.
for field in (12, 13):
with self.subTest(capacity_field=field):
mapping = AvailableModel.build_model_id_name_mapping(
capacity=2, capacity_field=field
)
self.assertEqual(
mapping[_id_for(Model.ADVANCED_PRO)], "gemini-3-pro-advanced"
)
self.assertEqual(
mapping[_id_for(Model.ADVANCED_FLASH)], "gemini-3-flash-advanced"
)
self.assertEqual(
mapping[_id_for(Model.ADVANCED_THINKING)],
"gemini-3-flash-thinking-advanced",
)

def test_basic_only_ids_still_resolve_on_higher_tiers(self):
# BASIC_* have unique model_ids — if Google surfaces them to a Plus
# account (defensive case), the mapping must still cover them.
for capacity in (2, 4):
with self.subTest(capacity=capacity):
mapping = AvailableModel.build_model_id_name_mapping(
capacity=capacity, capacity_field=12
)
self.assertIn(_id_for(Model.BASIC_PRO), mapping)
self.assertIn(_id_for(Model.BASIC_FLASH), mapping)
self.assertIn(_id_for(Model.BASIC_THINKING), mapping)

def test_default_args_preserve_legacy_basic_mapping(self):
default = AvailableModel.build_model_id_name_mapping()
explicit = AvailableModel.build_model_id_name_mapping(
capacity=1, capacity_field=12
)
self.assertEqual(default, explicit)


if __name__ == "__main__":
unittest.main()