
How to use Pydantic with Gemma 3N on OpenRouter

We use the Gemma 3N E4B model for the task of extracting the NUMDAYS field when both VAX_DATE and ONSET_DATE are specified but the NUMDAYS field is empty in the VAERS CSV file.

import os
import time
from datetime import date
from enum import Enum
from typing import Optional, List

import polars as pl
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel, Field, ValidationError

import json
import re


class CustomDate(BaseModel):
    # A calendar date split into integer parts, with -1 as the sentinel for
    # "not provided in the report".
    # NOTE: the field descriptions below are part of the runtime contract —
    # they end up in the JSON schema embedded in the LLM prompt — so do not
    # edit them casually. (No class docstring on purpose: pydantic would add
    # it to model_json_schema() and thereby change the prompt.)
    year: int = Field(default=-1, description="Year of event, use -1 if not provided")
    month: int = Field(default=-1, description="Month number of event, use -1 if not provided")
    date_of_month: int = Field(default=-1, description="Date of month number between 1 and 31, use -1 if not provided")


class VaccineInfo(BaseModel):
    # Structured-output schema for the extraction task. model_json_schema()
    # of this class is dumped verbatim into the prompt, so field names,
    # order, and description strings are all part of the runtime contract —
    # keep them byte-identical. (No class docstring on purpose: pydantic
    # would include it in the generated schema and change the prompt.)
    vaccination_date: CustomDate = Field(description="Date of vaccination")
    symptom_onset_date: CustomDate = Field(description="Date of onset of earliest symptom, including mild ones")
    number_of_days_upper_limit: int = Field(
        description="Upper limit of number of days to earliest symptom onset, which is the difference in number of days between vaccination_date and symptom_onset_date. If the number of days to symptom onset is not mentioned in the report, and date of month for symptom onset is not provided, use the last possible date of the symptom onset month and calculate the upper limit based on the maximum possible number of days to symptom onset. For example, the last date of May2021 is 31May2021, so the upper limit will be 10 (31-21) if the vaccination date is 21May2021.")
    # Verbatim grounding text lets a reviewer audit the extracted value
    # against the original report.
    number_of_days_upper_limit_explanation: str = Field(
        description="Verbatim sentence fragment from the input text which explains the extracted value")


class Message(BaseModel):
    # One chat message in OpenAI chat-completions format: a role string
    # (e.g. "user") plus the message text.
    role: str
    content: str


class ChatRequest(BaseModel):
    """Body of a chat-completions request: messages plus decoding options.

    Serialized with ``model_dump(exclude_none=True)`` and splatted into
    ``client.chat.completions.create`` below, so field names must match the
    OpenAI API parameter names exactly.
    """

    messages: List[Message]
    # default_factory is pydantic's documented way to declare a mutable
    # default: each instance gets its own dict rather than a literal shared
    # at class-definition time.
    response_format: dict = Field(default_factory=lambda: {"type": "json_object"})
    # Temperature 0 for (near-)deterministic extraction output.
    temperature: float = 0


def contains_json(s, pydantic_model=None):
    """Find a JSON object or array embedded in free-form text.

    Scans ``s`` for candidate JSON substrings (the regex handles one level
    of nesting, which is enough for typical LLM replies), parses each one,
    and returns a triple ``(contains_valid_json, is_valid_schema, json_str)``:

    * with ``pydantic_model`` given: the first candidate that validates
      against the model wins (``True, True, json_str``); if none validates,
      falls back to the largest syntactically valid candidate
      (``True, False, largest``);
    * without a model: the first syntactically valid candidate wins
      (``True, False, json_str``);
    * ``(False, False, None)`` when no valid JSON is found at all.
    """
    # Regex pattern to match JSON objects or arrays, handling one level of
    # nesting (sufficient in practice; a full parser is overkill here).
    json_pattern = r'\{(?:[^{}]|\{[^{}]*\})*\}|\[(?:[^[\]]|\[[^[\]]*])*]'

    # Find all potential JSON substrings
    matches = list(re.finditer(json_pattern, s, re.DOTALL))

    # Track the largest valid JSON as a fallback when schema validation fails.
    largest_json = None
    largest_json_length = 0

    for match in matches:
        json_str = match.group(0)
        # Keep the try body minimal: only the parse can raise here.
        try:
            parsed_json = json.loads(json_str)
        except json.JSONDecodeError:
            continue  # Not valid JSON, skip

        if len(json_str) > largest_json_length:
            largest_json = json_str
            largest_json_length = len(json_str)

        if pydantic_model:
            # Only a dict can be splatted into the model. Previously a
            # top-level JSON array here raised an uncaught TypeError from
            # `pydantic_model(**parsed_json)`; now arrays are simply kept
            # as fallback candidates.
            if isinstance(parsed_json, dict):
                try:
                    pydantic_model(**parsed_json)
                    return True, True, json_str  # Found Pydantic-compatible JSON
                except ValidationError:
                    continue  # Valid JSON but wrong schema, keep checking
        else:
            # If no Pydantic model, return first valid JSON
            return True, False, json_str

    # If Pydantic model was provided but no match found, return largest valid JSON
    if largest_json:
        return True, False, largest_json
    # If no valid JSON found at all
    return False, False, None


# Load environment variables (expects OPENROUTER_API_KEY in .env or the shell).
load_dotenv()
api_key = os.getenv("OPENROUTER_API_KEY")  # None if unset; client calls will then fail
MODEL_NAME = "google/gemma-3n-e4b-it"  # OpenRouter model slug
NUM_ITEMS = 100  # number of CSV rows to process this run
EXPERIMENT_NAME = 'null_numdays'  # tag used in cache directory/file names

# OpenRouter exposes an OpenAI-compatible API, so the stock OpenAI client
# works once pointed at the OpenRouter base URL.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=api_key
)

# Input rows; the loop below reads 'vaers_id' and 'symptom_text' columns —
# assumes the CSV provides both (TODO confirm against the file).
df = pl.read_csv('csv/null_numdays/null_numdays_100.csv')

df_subset = df.head(NUM_ITEMS)

experiment = EXPERIMENT_NAME
# The model slug contains '/' (and may contain ':'), neither filename-safe.
model_name_str = MODEL_NAME.replace('/', '-').replace(':', '-')
cache_file = f'json/output_cache/{experiment}/{experiment}_{model_name_str}.json'

# Load existing cache if it exists, keyed by VAERS ID so already-processed
# reports are skipped on re-runs (API calls are slow and cost money).
existing_data = {}
if os.path.exists(cache_file):
    with open(cache_file, 'r') as f:
        cached_items = json.load(f)
        for item in cached_items:
            existing_data[item['vaers_id']] = item
    print(f"Loaded {len(existing_data)} cached results")
all_data = []
total_rows = len(df_subset)
# 1-based counter purely for the progress display.
for i, row in enumerate(df_subset.iter_rows(named=True), 1):
    vaers_id = row['vaers_id']
    symptom_text = row['symptom_text']

    print(f"Progress: {i}/{total_rows} - VAERS ID: {vaers_id}")

    # Check if we already have this VAERS ID in cache; reuse it and skip the
    # API round trip if so.
    if vaers_id in existing_data:
        print(f"Using cached data for VAERS ID: {vaers_id}")
        all_data.append(existing_data[vaers_id])
        continue

    print(f"Processing VAERS ID: {vaers_id}")
    # Initialize variables (both are unconditionally reassigned below;
    # these initializers are redundant but harmless).
    response_text = ""
    prompt = ""

    # The Pydantic-derived JSON schema is embedded verbatim in the prompt so
    # the model knows the exact structure to emit.
    prompt = f"""Extract the requested information from this VAERS report and return it as JSON in the specified schema. 

    VAERS ID: {vaers_id}

    Report: {symptom_text}

    JSON Schema:
    {json.dumps(VaccineInfo.model_json_schema(), indent=2)}
    """
    before = time.time()
    # Create request
    request = ChatRequest(
        messages=[
            Message(
                role="user",
                content=prompt
            )
        ]
    )

    # Make API call - exclude reasoning parameter as it's not supported by OpenAI client
    # NOTE(review): ChatRequest never defines a 'reasoning' field, so this
    # branch looks like dead defensive code — confirm before removing.
    request_data = request.model_dump(exclude_none=True)
    if 'reasoning' in request_data:
        del request_data['reasoning']

    response = client.chat.completions.create(
        model=MODEL_NAME,
        **request_data
    )
    after = time.time()
    elapsed = after - before  # wall-clock seconds for the API call
    inner_response_text = response.choices[0].message.content

    full_response_json = response.model_dump_json()  # full API response, serialized for the cache
    inner_response_json = {}
    # Three quality flags for the model's reply:
    #   is_pure_json        - the entire reply parsed as JSON
    #   contains_valid_json - some substring of the reply parsed as JSON
    #   is_valid_schema     - the parsed JSON validated against VaccineInfo
    is_pure_json = False
    contains_valid_json = False
    is_valid_schema = False
    try:
        inner_response_json = json.loads(inner_response_text)
        is_pure_json = True
        contains_valid_json = True
        try:
            schema_json = VaccineInfo(**inner_response_json)
            is_valid_schema = True
        except ValidationError as ve:
            print(ve)
    except Exception as e:
        # Reply was not pure JSON (or was a non-dict): fall back to scanning
        # the text for an embedded JSON fragment matching the schema.
        print('Error')
        contains_valid_json, is_valid_schema, json_str = contains_json(inner_response_text, VaccineInfo)
        if is_valid_schema:
            inner_response_json = json.loads(json_str)

    # One cache record per report; keys mirror the cache file read at startup.
    new_item = {
        "vaers_id": vaers_id,
        "inner_response_text": inner_response_text,
        "inner_response_json": inner_response_json,
        "full_response_json": full_response_json,
        "is_pure_json": is_pure_json,
        "contains_valid_json": contains_valid_json,
        "is_valid_schema": is_valid_schema,
        "time_elapsed": elapsed
    }
    all_data.append(new_item)

# Save updated cache: all_data holds cached and freshly processed items, so
# writing it wholesale overwrites the old cache file with a superset.
os.makedirs(f'json/output_cache/{experiment}', exist_ok=True)
with open(cache_file, 'w') as f:
    json.dump(all_data, f, indent=2)

print(f"Saved {len(all_data)} results to cache")

The code includes extra steps for caching previous results and validating the response schema, but you should be able to paste it into an LLM and have it modified to suit your own requirements.

Leave a Reply