Environment details
- Programming language: python
- OS: ubuntu-22.04
- Language runtime version: 3.12.7
- Package version: 1.75
Steps to reproduce
I'm using LLMs to build math expressions as trees. The schema is recursive, and the type of each node is a Union of several concrete types (e.g. operators +, -, *, ...).
I previously filed an issue about recursive types (#2181). As far as I can see, the fix implemented a cycle guard that replaces the schema of already-visited nodes with {} (i.e. any). That worked for simple recursive schemata without union types. But now that I use more complex schemata, Gemini no longer seems to know the actual schema, and it goes astray.
MWE
Install deps:
pip install google-genai pydantic
Run script:
GOOGLE_API_KEY=... python mwe_gemini_recursive_schema.py # schema only
GOOGLE_API_KEY=... python mwe_gemini_recursive_schema.py --live # + live calls
Script:
Note: It outputs the schema to /tmp/mwe_gemini_schema.json
from __future__ import annotations
import argparse
import copy
import json
import os
from typing import Annotated, List, Literal, Union
import google.genai as genai
import google.genai._transformers as _t
from pydantic import BaseModel, Field, RootModel, ValidationError
# --- Recursive expression-tree node union ---------------------------------
class UnaryOp(BaseModel):
    """A unary function: exp(x), log(x), ... — exactly one argument."""

    # NOTE(review): pydantic presumably copies this class docstring into the
    # JSON-schema "description" sent to Gemini, so it is kept byte-identical.
    # Literal tag distinguishing this variant from the other Node members.
    type: Literal["exp", "log", "sqrt", "abs"]
    # min_length and max_length together enforce the documented "exactly one
    # argument". With max_length alone, an empty `args` list also validated.
    # "Node" is a forward reference, resolved later by model_rebuild().
    args: Annotated[List["Node"], Field(min_length=1, max_length=1)]
class BinaryOp(BaseModel):
    """A binary function: x / y, x ^ y, ... — exactly two arguments."""

    # NOTE(review): pydantic presumably copies this class docstring into the
    # JSON-schema "description" sent to Gemini, so it is kept byte-identical.
    # Literal tag distinguishing this variant from the other Node members.
    type: Literal["/", "^", "atan"]
    # min_length and max_length together enforce the documented "exactly two
    # arguments". With max_length alone, zero or one operand also validated.
    # "Node" is a forward reference, resolved later by model_rebuild().
    args: Annotated[List["Node"], Field(min_length=2, max_length=2)]
class NAryOp(BaseModel):
    """An n-ary function: x + y + ..., x * y * ..., min/max — n arguments."""

    # NOTE(review): pydantic presumably copies this class docstring into the
    # JSON-schema "description" sent to Gemini; edit it with that in mind.
    # Literal tag distinguishing this variant from the other Node members.
    type: Literal["+", "-", "*", "min", "max"]
    # Operands. "Node" is a forward reference resolved by model_rebuild().
    # NOTE(review): this field has no length bounds, unlike the operand lists
    # of UnaryOp/BinaryOp, so an empty or single-element list also validates.
    # Confirm that this is intended.
    args: List["Node"]
class Const(BaseModel):
    """A numeric constant."""

    # Leaf node: tagged "real", carrying a literal float value.
    type: Literal["real"]
    value: float
class VarRef(BaseModel):
    """A reference to a variable by name."""

    # Leaf node tagged "variable". PROMPT asks the model to emit exactly this
    # shape: {"type": "variable", "name": ...}.
    type: Literal["variable"]
    name: str
class SharedRef(BaseModel):
    """A 1-indexed pointer into the tree's `shared` subexpression list (a DAG)."""

    # Leaf node tagged "node". It lets a subexpression be reused instead of
    # repeated, which turns the tree into a DAG.
    type: Literal["node"]
    # ge=1 rejects 0 and negative values, matching the 1-based indexing
    # stated in the docstring.
    index: Annotated[int, Field(ge=1)]
class Node(RootModel[Union[UnaryOp, BinaryOp, NAryOp, Const, VarRef, SharedRef]]):
    # Root model wrapping the union of every expression-node variant.
    # The operator models refer to it through the forward-reference string
    # "Node", which is why it also appears in the model_rebuild() loop.
    # Its recursive use inside `args` is the part of the schema that the
    # report says gets collapsed to {} after process_schema.
    # (No docstring on purpose: pydantic would presumably emit it as a schema
    # "description" and change what Gemini sees.)
    pass
# --- Outer union: a function may also be given in simpler closed forms -----
class Term(BaseModel):
    # One `coefficient * variable` summand of a Polynomial.
    # NOTE(review): a class docstring would presumably surface as a schema
    # "description", so documentation stays in comments here.
    coefficient: float
    variable: str
class Variable(BaseModel):
    """The function is a single bare variable."""

    # Literal tag selecting this branch of FunctionExpr.expr.
    type: Literal["Variable"]
    # Name of the variable the whole function reduces to.
    name: str
class Polynomial(BaseModel):
    """The function is a linear combination of variables plus a constant."""

    # Literal tag selecting this branch of FunctionExpr.expr.
    type: Literal["Polynomial"]
    # Additive constant of the linear combination.
    constant: float
    # Summands, each a `coefficient * variable` Term.
    terms: List[Term]
class Tree(BaseModel):
    """A general expression DAG: `root` plus a `shared` list of subexpressions."""

    # Literal tag selecting this branch of FunctionExpr.expr.
    type: Literal["Tree"]
    # Top of the expression. Because Node includes SharedRef, it may point
    # into `shared`.
    root: Node
    # Reusable subexpressions, addressed by SharedRef.index (1-based).
    shared: List[Node]
class FunctionExpr(BaseModel):
    """A named mathematical function in one of three representations."""

    # Identifier of the function, e.g. "f" or "g" in PROMPT.
    name: str
    # Outer union. Each member carries a distinct Literal `type` tag
    # ("Variable", "Polynomial", "Tree").
    expr: Union[Variable, Polynomial, Tree]
class Functions(BaseModel):
    """A batch of mathematical functions."""

    # Top-level response model. build_schema() derives the schema from it,
    # and live() validates each Gemini reply against it.
    items: List[FunctionExpr]
# The operator models reference "Node" before it exists, and Node references
# the operator models. Rebuild them explicitly, in this order, once every
# class above has been defined, so that those references resolve.
UnaryOp.model_rebuild()
BinaryOp.model_rebuild()
NAryOp.model_rebuild()
Node.model_rebuild()
# User prompt sent on every live trial. It steers the model towards the
# recursive "Tree" branch, whose operand (`args`) schemas the report says
# are collapsed to {}.
PROMPT = "".join(
    (
        "Express these two mathematical functions as expression trees ",
        '(use the "Tree" representation, with an operator tree in `root`):\n',
        " 1. f = exp(x) * y\n",
        " 2. g = exp(x)\n",
        'Use type="exp" (one argument) and type="*" (two arguments). ',
        'Reference variables with {"type": "variable", "name": ...}.',
    )
)
def build_schema() -> dict:
    """Return the ``Functions`` JSON schema after the SDK's schema processing.

    The pydantic-generated schema is deep-copied and then passed to the SDK's
    private ``process_schema`` transformer. The copy, as modified by that
    call, is what gets returned; the call's own return value is not used.
    This is the form the report treats as "what Gemini actually receives".
    """
    pydantic_schema = Functions.model_json_schema()
    processed = copy.deepcopy(pydantic_schema)
    _t.process_schema(processed, client=None)
    return processed
def show_schema() -> None:
    """Print a summary of the processed schema and save a full copy to disk.

    The summary covers:

    * the SDK version;
    * whether ``$ref`` / ``$defs`` survive processing;
    * every distinct ``items`` schema found under an array titled "Args"
      (the operand lists of the operator models).

    An ``items`` of ``{}`` is flagged, because it gives the model no
    structure for the operands. The pretty-printed schema is written to
    /tmp/mwe_gemini_schema.json.
    """
    print(f"google-genai version: {genai.__version__}\n")
    schema = build_schema()
    serialized = json.dumps(schema)
    has_ref, has_defs = "$ref" in serialized, "$defs" in serialized
    print(f"schema contains $ref: {has_ref} contains $defs: {has_defs}")

    def iter_args_items(root):
        # Iterative pre-order walk over nested dicts and lists. Children are
        # pushed in reverse so they are popped in their original order.
        pending = [root]
        while pending:
            current = pending.pop()
            if isinstance(current, dict):
                if current.get("title") == "Args" and current.get("type") == "array":
                    yield current.get("items")
                pending.extend(reversed(list(current.values())))
            elif isinstance(current, list):
                pending.extend(reversed(current))

    print("\nRecursive `args` array `items` as sent to Gemini:")
    already_printed = set()
    for items in iter_args_items(schema):
        # Canonical JSON (sorted keys) so identical sub-schemas print once.
        canonical = json.dumps(items, sort_keys=True)
        if canonical in already_printed:
            continue
        already_printed.add(canonical)
        flag = (
            " <-- collapsed to {} — NO structural guidance for operands"
            if items == {}
            else ""
        )
        print(f" items = {canonical}{flag}")

    with open("/tmp/mwe_gemini_schema.json", "w") as out:
        json.dump(schema, out, indent=2)
    print("\n(full schema written to /tmp/mwe_gemini_schema.json)")
def live(trials: int, model: str = "gemini-3.1-pro-preview") -> None:
    """Call Gemini ``trials`` times with the processed schema and validate replies.

    Each response's text is validated against the ``Functions`` model. The
    first invalid payload is written to /tmp/mwe_gemini_invalid.json, and its
    first 600 characters are echoed to stdout. A success rate is printed at
    the end.

    Args:
        trials: Number of ``generate_content`` calls to make. Values below 1
            are skipped, because no success rate can be computed for them
            (the original code divided by zero here).
        model: Gemini model name. Defaults to the model used in the report.
    """
    if trials < 1:
        print("\nNo trials requested — skipping live calls.")
        return
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("\nGOOGLE_API_KEY not set — skipping live calls.")
        return
    client = genai.Client(api_key=api_key)
    schema = build_schema()
    print(f"\nLive: {trials} calls to {model} (google-genai {genai.__version__})\n")
    ok = 0
    saved = False
    for i in range(1, trials + 1):
        resp = client.models.generate_content(
            model=model,
            contents=PROMPT,
            config={
                "response_mime_type": "application/json",
                "response_json_schema": schema,
            },
        )
        raw = resp.text
        if raw is None:
            # The SDK's `text` accessor can yield None (e.g. a response with
            # no text parts). Previously this crashed in f.write / slicing
            # below instead of being counted as a failed trial.
            print(f" trial {i:2d}: EMPTY — response contained no text")
            continue
        try:
            Functions.model_validate_json(raw)
        except ValidationError as e:
            print(f" trial {i:2d}: INVALID — {len(e.errors())} errors")
            if not saved:
                # Explicit UTF-8: corrupted payloads have contained CJK
                # tokens, which a non-UTF-8 locale default could not encode.
                with open("/tmp/mwe_gemini_invalid.json", "w", encoding="utf-8") as f:
                    f.write(raw)
                print(" raw payload written to /tmp/mwe_gemini_invalid.json")
                print(f" raw: {raw[:600]}")
                saved = True
        else:
            ok += 1
            print(f" trial {i:2d}: OK")
    print(f"\nSuccess rate: {ok}/{trials} ({ok / trials * 100:.0f}%) "
          f"on google-genai {genai.__version__}")
def main() -> None:
    """Parse CLI flags, print the processed schema, and optionally run live trials.

    ``--trials`` must be a positive integer. The live run reports a success
    rate as a ratio over the trial count, so 0 or a negative count is
    rejected before any work is done.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--live", action="store_true", help="make real Gemini API calls")
    ap.add_argument(
        "--trials",
        type=int,
        default=10,
        help="number of live calls to make (default: 10)",
    )
    args = ap.parse_args()
    if args.trials < 1:
        # ArgumentParser.error prints usage to stderr and exits with status 2.
        ap.error(f"--trials must be a positive integer, got {args.trials}")
    show_schema()
    if args.live:
        live(args.trials)


if __name__ == "__main__":
    main()
Results
-
Success rate: 0/10. Every response failed validation against the
exact schema that was sent:
Live: 10 calls to gemini-3.1-pro-preview (google-genai 1.75.0)
trial 1: INVALID — 46 errors
trial 2: INVALID — 36 errors
trial 3: INVALID — 43 errors
trial 4: INVALID — 46 errors
trial 5: INVALID — 43 errors
trial 6: INVALID — 46 errors
trial 7: INVALID — 46 errors
trial 8: INVALID — 1 errors
trial 9: INVALID — 26 errors
trial 10: INVALID — 36 errors
Success rate: 0/10 (0%) on google-genai 1.75.0
-
Failures include decoding corruption, not just wrong structure. One raw
payload contained a stray CJK token and an injected ```json fence
in the middle of the structured output:
{"items":[{"name":"f","expr":{"type":"Tree","root":{"type":"*",
"args":[{"type":"exp","args敛```json{" :"x"}]},"shared":[]}}]}
Environment details
Steps to reproduce
I'm using LLMs to build math expressions as trees. The schema is recursive, and the type of each node is a
Union of several concrete types (e.g. operators +, -, *, ...). For recursive types, I had filed an issue before (#2181), and the solution, as far as I could see, implemented a cycle guard that sets the schema of visited nodes to
{} (i.e. any). It worked for simple recursive schemata without union types. But now that I use more complex schemata, it seems Gemini doesn't know about the actual schema, and it goes astray.
MWE
Install deps:
Run script:
Script:
Note: It outputs the schema to
/tmp/mwe_gemini_schema.json
Results
Success rate: 0/10. Every response failed validation against the
exact schema that was sent:
Failures include decoding corruption, not just wrong structure. One raw
payload contained a stray CJK token and an injected
```json fence in the middle of the structured output: