Skip to content

Schema adherence in recursive JSON with union types #2481

@bertrandkerres

Description

@bertrandkerres

Environment details

  • Programming language: python
  • OS: ubuntu-22.04
  • Language runtime version: 3.12.7
  • Package version: 1.75

Steps to reproduce

I'm using LLMs to build math expressions as trees. The schema is recursive, and each tree node is a union of several concrete types (e.g. operators such as +, -, *, ..., numeric constants, and variable references).

For recursive types I filed an issue before (#2181). As far as I can see, the fix added a cycle guard that replaces the schema of already-visited nodes with {} (i.e. any). That worked for simple recursive schemas without union types. With more complex schemas like the one below, however, Gemini no longer seems to know the actual schema and goes astray.

MWE

Install deps:

pip install google-genai pydantic

Run script:

GOOGLE_API_KEY=... python mwe_gemini_recursive_schema.py            # schema only
GOOGLE_API_KEY=... python mwe_gemini_recursive_schema.py --live     # + live calls

Script:
Note: the script writes the processed schema to /tmp/mwe_gemini_schema.json.

from __future__ import annotations

import argparse
import copy
import json
import os
from typing import Annotated, List, Literal, Union

import google.genai as genai
import google.genai._transformers as _t
from pydantic import BaseModel, Field, RootModel, ValidationError


# --- Recursive expression-tree node union ---------------------------------

class UnaryOp(BaseModel):
    """A unary function: exp(x), log(x), ... — exactly one argument."""
    # The docstring above is emitted as this model's schema `description`, so it
    # is left unchanged.
    type: Literal["exp", "log", "sqrt", "abs"]
    # The contract is *exactly* one operand: bound both ends so the generated
    # schema carries minItems as well as maxItems (max_length alone also
    # accepted an empty argument list).
    args: Annotated[List["Node"], Field(min_length=1, max_length=1)]


class BinaryOp(BaseModel):
    """A binary function: x / y, x ^ y, ... — exactly two arguments."""
    # The docstring above is emitted as this model's schema `description`, so it
    # is left unchanged.
    type: Literal["/", "^", "atan"]
    # The contract is *exactly* two operands: bound both ends so the generated
    # schema carries minItems as well as maxItems (max_length alone also
    # accepted zero or one argument).
    args: Annotated[List["Node"], Field(min_length=2, max_length=2)]


class NAryOp(BaseModel):
    """An n-ary function: x + y + ..., x * y * ..., min/max — n arguments."""
    type: Literal["+", "-", "*", "min", "max"]
    # No length constraint: any number of operands (including none) validates.
    args: List["Node"]


class Const(BaseModel):
    """A numeric constant."""
    # Leaf node: there is no `args` field, so recursion stops here.
    type: Literal["real"]
    value: float


class VarRef(BaseModel):
    """A reference to a variable by name."""
    # Leaf node (no `args`). The tag is lowercase "variable", distinct from the
    # outer `Variable` model's "Variable" tag.
    type: Literal["variable"]
    name: str


class SharedRef(BaseModel):
    """A 1-indexed pointer into the tree's `shared` subexpression list (a DAG)."""
    type: Literal["node"]
    # ge=1 rejects 0 and negative indices; the upper bound (the length of the
    # `shared` list) is not checked by this model.
    index: Annotated[int, Field(ge=1)]


# Recursive operand type: any one of the six node models. No discriminator is
# declared, so this is a plain Union. The members' `type` Literal values do not
# overlap, so a given tag value can only satisfy one member.
class Node(RootModel[Union[UnaryOp, BinaryOp, NAryOp, Const, VarRef, SharedRef]]):
    pass


# --- Outer union: a function may also be given in simpler closed forms -----

class Term(BaseModel):
    # One `coefficient * variable` entry of Polynomial.terms.
    coefficient: float
    variable: str


class Variable(BaseModel):
    """The function is a single bare variable."""
    # Outer-union tag "Variable" (capitalised); not the same as VarRef's "variable".
    type: Literal["Variable"]
    name: str


class Polynomial(BaseModel):
    """The function is a linear combination of variables plus a constant."""
    type: Literal["Polynomial"]
    constant: float
    # Non-recursive: each term is a flat coefficient/variable pair.
    terms: List[Term]


class Tree(BaseModel):
    """A general expression DAG: `root` plus a `shared` list of subexpressions."""
    type: Literal["Tree"]
    # Entry point into the recursive Node union.
    root: Node
    # Subexpressions addressed by SharedRef.index, counting from 1.
    shared: List[Node]


class FunctionExpr(BaseModel):
    """A named mathematical function in one of three representations."""
    name: str
    # Plain Union whose members carry distinct `type` tags ("Variable",
    # "Polynomial", "Tree"); only the Tree member reaches the recursive Node union.
    expr: Union[Variable, Polynomial, Tree]


class Functions(BaseModel):
    """A batch of mathematical functions."""
    # Top-level response model: build_schema() processes its JSON schema, and
    # live() validates each Gemini reply against it.
    items: List[FunctionExpr]


# Resolve the string forward references to "Node" now that every model is
# defined (the module uses `from __future__ import annotations`).
UnaryOp.model_rebuild()
BinaryOp.model_rebuild()
NAryOp.model_rebuild()
Node.model_rebuild()


# User instruction sent as `contents` on every live call: two small functions
# that should be returned in the "Tree" representation.
PROMPT = "\n".join(
    [
        "Express these two mathematical functions as expression trees "
        '(use the "Tree" representation, with an operator tree in `root`):',
        "  1. f = exp(x) * y",
        "  2. g = exp(x)",
        'Use type="exp" (one argument) and type="*" (two arguments). '
        'Reference variables with {"type": "variable", "name": ...}.',
    ]
)


def build_schema() -> dict:
    """Return the `Functions` JSON schema after the SDK's schema processing.

    The Pydantic-generated schema is deep-copied and handed to the SDK's
    private ``process_schema`` helper, which is called only for its in-place
    effect (its return value is ignored). The mutated copy is returned.
    """
    pydantic_schema = Functions.model_json_schema()
    processed = copy.deepcopy(pydantic_schema)
    _t.process_schema(processed, client=None)
    return processed


def show_schema() -> None:
    """Print a summary of the processed schema and dump it to /tmp.

    Reports whether the text of the processed schema still mentions `$ref` /
    `$defs`, prints each distinct `items` schema found under an array titled
    "Args" (flagging ones that were reduced to `{}`), and writes the full
    processed schema to /tmp/mwe_gemini_schema.json.
    """
    print(f"google-genai version: {genai.__version__}\n")
    processed = build_schema()
    dumped = json.dumps(processed)
    has_ref = "$ref" in dumped
    has_defs = "$defs" in dumped
    print(f"schema contains $ref: {has_ref}   contains $defs: {has_defs}")

    def walk_args_items(obj):
        # Depth-first: yield a matching dict's `items` before descending into
        # its values. Only operator `args` arrays (titled "Args") are recursive.
        if isinstance(obj, list):
            for child in obj:
                yield from walk_args_items(child)
            return
        if not isinstance(obj, dict):
            return
        if obj.get("title") == "Args" and obj.get("type") == "array":
            yield obj.get("items")
        for child in obj.values():
            yield from walk_args_items(child)

    print("\nRecursive `args` array `items` as sent to Gemini:")
    reported = set()
    for items in walk_args_items(processed):
        canonical = json.dumps(items, sort_keys=True)
        if canonical not in reported:
            reported.add(canonical)
            if items == {}:
                suffix = "  <-- collapsed to {} — NO structural guidance for operands"
            else:
                suffix = ""
            print(f"  items = {canonical}{suffix}")

    with open("/tmp/mwe_gemini_schema.json", "w") as out:
        out.write(json.dumps(processed, indent=2))
    print("\n(full schema written to /tmp/mwe_gemini_schema.json)")


def live(trials: int, *, model: str = "gemini-3.1-pro-preview") -> None:
    """Call Gemini `trials` times with PROMPT and validate every reply.

    Each reply is validated against the ``Functions`` Pydantic model, not
    against the processed schema that is sent. The first invalid payload is
    written to /tmp/mwe_gemini_invalid.json and its first 600 characters are
    printed. No calls are made when GOOGLE_API_KEY is unset or ``trials < 1``.

    Args:
        trials: Number of generate_content calls to make.
        model: Gemini model name; defaults to the model used in the report.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("\nGOOGLE_API_KEY not set — skipping live calls.")
        return
    if trials < 1:
        # Guards the success-rate division at the end (ZeroDivisionError for 0).
        print(f"\ntrials must be >= 1 (got {trials}) — skipping live calls.")
        return

    client = genai.Client(api_key=api_key)
    # NOTE(review): the Gemini API documents `responseJsonSchema` as accepting
    # JSON Schema including $defs/$ref. Sending the post-process_schema output
    # (where recursive `args` items may be collapsed to {}) could itself be what
    # strips the operand structure; consider comparing against the unprocessed
    # Functions.model_json_schema(). TODO confirm.
    schema = build_schema()

    print(f"\nLive: {trials} calls to {model} (google-genai {genai.__version__})\n")
    ok = 0
    saved = False
    for i in range(1, trials + 1):
        resp = client.models.generate_content(
            model=model,
            contents=PROMPT,
            config={
                "response_mime_type": "application/json",
                "response_json_schema": schema,
            },
        )
        raw = resp.text
        if raw is None:
            # No text in the response (presumably no candidate/text parts);
            # count it as a failure instead of crashing on write/slice below.
            print(f"  trial {i:2d}: INVALID — response contained no text")
            continue
        try:
            Functions.model_validate_json(raw)
        except ValidationError as e:
            print(f"  trial {i:2d}: INVALID — {len(e.errors())} errors")
            if not saved:
                # Explicit UTF-8: invalid payloads have been reported to
                # contain non-ASCII (CJK) tokens, which a non-UTF-8 locale
                # default encoding could fail to write.
                with open("/tmp/mwe_gemini_invalid.json", "w", encoding="utf-8") as f:
                    f.write(raw)
                print("            raw payload written to /tmp/mwe_gemini_invalid.json")
                print(f"            raw: {raw[:600]}")
                saved = True
        else:
            ok += 1
            print(f"  trial {i:2d}: OK")

    print(f"\nSuccess rate: {ok}/{trials} ({ok / trials * 100:.0f}%) "
          f"on google-genai {genai.__version__}")


def main() -> None:
    """Parse CLI flags, print the schema report, and optionally run live trials.

    Flags:
        --live: make real Gemini API calls after printing the schema report.
        --trials: number of live calls (default 10); must be >= 1 with --live.
    """
    ap = argparse.ArgumentParser(
        # __doc__ is the module docstring; this script defines none, so it may
        # be None. Fall back to a short description so --help is informative.
        description=__doc__ or "Gemini recursive-union schema MWE.",
    )
    ap.add_argument("--live", action="store_true", help="make real Gemini API calls")
    ap.add_argument(
        "--trials",
        type=int,
        default=10,
        help="number of live calls to make (default: 10)",
    )
    args = ap.parse_args()
    if args.live and args.trials < 1:
        # Reject up front: with no trials there is no success rate to report.
        ap.error(f"--trials must be >= 1 (got {args.trials})")

    show_schema()
    if args.live:
        live(args.trials)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

Results

  • Success rate: 0/10. Every response failed validation against the
    Pydantic model from which the sent schema was generated:

    Live: 10 calls to gemini-3.1-pro-preview (google-genai 1.75.0)
      trial  1: INVALID — 46 errors
      trial  2: INVALID — 36 errors
      trial  3: INVALID — 43 errors
      trial  4: INVALID — 46 errors
      trial  5: INVALID — 43 errors
      trial  6: INVALID — 46 errors
      trial  7: INVALID — 46 errors
      trial  8: INVALID — 1 errors
      trial  9: INVALID — 26 errors
      trial 10: INVALID — 36 errors
    
    Success rate: 0/10 (0%) on google-genai 1.75.0
    
  • Failures include decoding corruption, not just wrong structure. One raw
    payload contained a stray CJK token and an injected ```json fence
    in the middle of the structured output:

    {"items":[{"name":"f","expr":{"type":"Tree","root":{"type":"*",
     "args":[{"type":"exp","args敛```json{"   :"x"}]},"shared":[]}}]}
    

Metadata

Labels

priority: p2 — Moderately-important priority. Fix may not be included in next release.
type: bug — Error or flaw in code with unintended results or allowing sub-optimal usage patterns.

Type

No type
No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions