File size: 2,654 Bytes
8a3374d 7be08b4 21d76a7 7be08b4 8a3374d d9a5339 8a3374d 7be08b4 21d76a7 7be08b4 21d76a7 7be08b4 21d76a7 7be08b4 8a3374d 21d76a7 5d93afc 21d76a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
ModelResponse,
TextPart,
ToolReturnPart,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from app import models
from app.tools import dailymed, literature
class Context(BaseModel):
    """Information recovered from an agent run's message history.

    Instances are built by ``get_context`` from the messages of a run.
    """

    # Text content of the model's responses, in the order they appeared.
    thoughts: list[str]
    # Source records returned by the source-producing tools, keyed by each
    # record's "id" value.
    sources: dict[str, dict]
class Statement(BaseModel):
    # One statement of the agent's answer, as emitted by the model and passed
    # to create_response.
    # NOTE: documented with comments instead of a class docstring on purpose:
    # this model is part of the output schema sent to the LLM, and pydantic
    # turns a class docstring into a schema "description".
    text: str
    # Optional list of cited source IDs; create_response resolves each one
    # against the sources collected from tool results.
    sources: list[str] | None = Field(
        default=None, description="ID of the sources that support this statement."
    )
def get_context(messages: list[ModelMessage]) -> Context:
    """Extract the model's text output and tool-provided sources from a run.

    Args:
        messages: The run's message history, in order.

    Returns:
        A ``Context`` whose ``thoughts`` are the contents of every ``TextPart``
        in the model responses (in order), and whose ``sources`` map each item's
        ``"id"`` to the item itself. The items are taken from the return values
        of the ``search_medical_literature`` and ``find_drug_set_ids`` tools.
        A later item with the same id replaces an earlier one.

    Raises:
        Whatever indexing raises if a tool's return content is not an iterable
        of mappings with an ``"id"`` key. This assumes those tools return
        ``list[dict]``; TODO confirm against ``app.tools``.
    """
    source_tools = {"search_medical_literature", "find_drug_set_ids"}
    collected_thoughts: list[str] = []
    collected_sources: dict[str, dict] = {}

    for msg in messages:
        if isinstance(msg, ModelResponse):
            collected_thoughts.extend(
                part.content for part in msg.parts if isinstance(part, TextPart)
            )
            continue
        if not isinstance(msg, ModelRequest):
            continue
        # Tool results travel back to the model inside request messages.
        for part in msg.parts:
            if not isinstance(part, ToolReturnPart):
                continue
            if part.tool_name not in source_tools:
                continue
            for entry in part.content:
                collected_sources[entry["id"]] = entry

    return Context(thoughts=collected_thoughts, sources=collected_sources)
def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
    # Output function for the agent: pydantic_ai calls it with the model's
    # structured output. It is documented with comments rather than a
    # docstring on purpose, because pydantic_ai builds the output tool's
    # description from an output function's docstring. That description is
    # sent to the model.
    #
    # Each statement's cited source IDs are resolved against the sources that
    # tools returned earlier in this run (collected by get_context). Unknown
    # IDs raise ModelRetry so the model can fix its citations. All unknown IDs
    # are reported in a single retry, because each retry counts against the
    # agent's output-retry budget. The Agent sets no explicit limit, so the
    # library default applies.
    #
    # Returns the validated models.Statements built from the resolved
    # statements plus the model's text parts joined by blank lines.
    context = get_context(ctx.messages)
    known_sources = context.sources

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # retry message is deterministic.
    unknown_ids = list(
        dict.fromkeys(
            source_id
            for statement in output
            for source_id in statement.sources or []
            if source_id not in known_sources
        )
    )
    if unknown_ids:
        listed = ", ".join(f"'{source_id}'" for source_id in unknown_ids)
        raise ModelRetry(
            f"Source ID(s) {listed} not found in the sources returned by the "
            "tools in this conversation. Cite only IDs from those results."
        )

    statements = [
        {
            "text": statement.text,
            "sources": [
                known_sources[source_id] for source_id in statement.sources or []
            ],
        }
        for statement in output
    ]
    return models.Statements.model_validate(
        {
            "statements": statements,
            "thoughts": "\n\n".join(context.thoughts),
        }
    )
# NOTE(review): this is a dated preview model name; confirm it is still served
# before relying on it.
model = GoogleModel("gemini-2.5-flash-preview-05-20")

settings = GoogleModelSettings(
    # temperature=0.1,
    # Let Gemini spend up to 1024 tokens thinking and include its thoughts in
    # the response (see Gemini ThinkingConfig: thinking_budget /
    # include_thoughts).
    google_thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)

agent = Agent(
    model=model,
    name="elna",
    model_settings=settings,
    # Structured output goes through create_response. It resolves cited source
    # IDs and asks the model to retry when an ID is unknown.
    output_type=create_response,
    # The system prompt lives next to this module. Decode it explicitly as
    # UTF-8 so that non-ASCII text reads the same on every platform instead
    # of depending on the locale's default encoding.
    system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(
        encoding="utf-8"
    ),
    tools=[
        dailymed.find_drug_set_ids,
        dailymed.find_drug_instruction,
        literature.search_medical_literature,
    ],
)
|