File size: 2,654 Bytes
8a3374d
 
7be08b4
21d76a7
7be08b4
 
 
 
 
 
 
 
8a3374d
d9a5339
8a3374d
 
7be08b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21d76a7
 
7be08b4
 
21d76a7
7be08b4
 
 
 
21d76a7
 
 
 
7be08b4
 
 
 
 
 
 
8a3374d
21d76a7
 
 
 
5d93afc
 
21d76a7
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from pathlib import Path

from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    TextPart,
    ToolReturnPart,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

from app import models
from app.tools import dailymed, literature


class Context(BaseModel):
    """Material gathered from an agent run's message history.

    Built by ``get_context`` and used to resolve the source IDs that
    statements cite.
    """

    # Text parts of every model response, in message order.
    thoughts: list[str]
    # Source items keyed by their "id" value, taken from the results of the
    # search_medical_literature and find_drug_set_ids tools.
    sources: dict[str, dict]


# A single claim produced by the model, optionally citing source IDs.
# NOTE(review): deliberately no class docstring — pydantic copies a model's
# docstring into its JSON schema, and this model is part of the agent's output
# schema (via create_response), so a docstring would change what the LLM sees.
class Statement(BaseModel):
    # Free-text content of the statement.
    text: str
    # IDs of the supporting sources, looked up against the sources collected
    # from the run's tool calls. None is treated the same as "no citations".
    sources: list[str] | None = Field(
        default=None, description="ID of the sources that support this statement."
    )


# Tools whose results are citable sources (expected: a list of dicts with "id").
_SOURCE_TOOLS = frozenset({"search_medical_literature", "find_drug_set_ids"})


def get_context(messages: list[ModelMessage]) -> Context:
    """Collect the model's text output and the citable sources from a history.

    Args:
        messages: The run's message history (e.g. ``RunContext.messages``).

    Returns:
        A ``Context`` whose ``thoughts`` are the text parts of every model
        response, in order, and whose ``sources`` map each item's ``"id"`` to
        the item returned by one of the ``_SOURCE_TOOLS``. If two items share
        an ID, the later one wins.

    Tool results that are not a list/tuple, and items that are not dicts with
    an ``"id"`` key (e.g. an error string returned by a tool), are skipped
    rather than raising inside the agent's output function.
    """
    thoughts: list[str] = []
    sources: dict[str, dict] = {}

    for message in messages:
        if isinstance(message, ModelResponse):
            for part in message.parts:
                if isinstance(part, TextPart):
                    thoughts.append(part.content)
        elif isinstance(message, ModelRequest):
            for part in message.parts:
                if not isinstance(part, ToolReturnPart):
                    continue
                if part.tool_name not in _SOURCE_TOOLS:
                    continue
                content = part.content
                # Iterating a string would yield characters and crash on
                # item["id"]; only sequence results can carry sources.
                if not isinstance(content, (list, tuple)):
                    continue
                for item in content:
                    if isinstance(item, dict) and "id" in item:
                        sources[item["id"]] = item

    return Context(thoughts=thoughts, sources=sources)


def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
    # Output function of the agent: resolve every cited source ID against the
    # sources gathered from this run's tool calls and assemble the final
    # `models.Statements` payload (statements plus the model's text parts
    # joined as "thoughts").
    #
    # Raises ModelRetry when any statement cites an ID that no source tool
    # returned. All unknown IDs are reported in one message so the model can
    # fix every citation in a single retry instead of one retry per bad ID.
    #
    # NOTE(review): documented with comments rather than a docstring on
    # purpose — pydantic-ai builds the output tool schema from this function
    # and presumably uses its docstring as the tool description, so adding
    # one would change what the model sees. Confirm before converting.
    context = get_context(ctx.messages)

    statements = []
    unknown_ids: list[str] = []
    for statement in output:
        sources = []
        for source_id in statement.sources or []:
            # Context.sources values are validated dicts, never None.
            source = context.sources.get(source_id)
            if source is None:
                unknown_ids.append(source_id)
            else:
                sources.append(source)
        statements.append({"text": statement.text, "sources": sources})

    if unknown_ids:
        # dict.fromkeys de-duplicates while keeping first-seen order.
        listed = ", ".join(
            f"'{source_id}'" for source_id in dict.fromkeys(unknown_ids)
        )
        raise ModelRetry(
            f"Source ID(s) {listed} not found among the retrieved sources. "
            "Only cite IDs returned by search_medical_literature or "
            "find_drug_set_ids."
        )

    return models.Statements.model_validate(
        {
            "statements": statements,
            "thoughts": "\n\n".join(context.thoughts),
        }
    )


# Gemini model backing the agent.
model = GoogleModel("gemini-2.5-flash-preview-05-20")
settings = GoogleModelSettings(
    # temperature=0.1,
    # Thinking is enabled with a 1024-token budget and thoughts are requested
    # in the response; presumably they surface as text parts that
    # get_context() gathers into "thoughts" — verify for the installed
    # pydantic-ai version.
    google_thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)

# System prompt stored next to this module. Read it explicitly as UTF-8
# (assumes the file is saved as UTF-8): Path.read_text() otherwise uses the
# platform's locale encoding, which can mis-decode non-ASCII prompt text.
_SYSTEM_INSTRUCTION_PATH = Path(__file__).parent / "system_instruction.txt"

agent = Agent(
    model=model,
    name="elna",
    model_settings=settings,
    # Final answers pass through create_response(), which resolves cited
    # source IDs and asks the model to retry on unknown ones.
    output_type=create_response,
    system_prompt=_SYSTEM_INSTRUCTION_PATH.read_text(encoding="utf-8"),
    tools=[
        dailymed.find_drug_set_ids,
        dailymed.find_drug_instruction,
        literature.search_medical_literature,
    ],
)