David Chu
committed on
feat: mitigate citation hallucination
Browse files
Using PydanticAI's output function and raising a ModelRetry will
make the LLM regenerate the output (by default 1 retry)
if there are nonexistent source IDs.
- app/agent.py +26 -26
- app/main.py +2 -2
- main.py +6 -8
app/agent.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
from pathlib import Path
|
2 |
|
3 |
-
import logfire
|
4 |
from pydantic import BaseModel, Field
|
5 |
-
from pydantic_ai import Agent
|
6 |
from pydantic_ai.messages import (
|
7 |
ModelMessage,
|
8 |
ModelRequest,
|
@@ -28,24 +27,6 @@ class Statement(BaseModel):
|
|
28 |
)
|
29 |
|
30 |
|
31 |
-
model = GoogleModel("gemini-2.5-flash-preview-05-20")
|
32 |
-
settings = GoogleModelSettings(
|
33 |
-
google_thinking_config={"thinking_budget": 2048, "include_thoughts": True},
|
34 |
-
)
|
35 |
-
agent = Agent(
|
36 |
-
model=model,
|
37 |
-
name="elna",
|
38 |
-
model_settings=settings,
|
39 |
-
output_type=list[Statement],
|
40 |
-
system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(),
|
41 |
-
tools=[
|
42 |
-
dailymed.find_drug_set_ids,
|
43 |
-
dailymed.find_drug_instruction,
|
44 |
-
literature.search_medical_literature,
|
45 |
-
],
|
46 |
-
)
|
47 |
-
|
48 |
-
|
49 |
def get_context(messages: list[ModelMessage]) -> Context:
|
50 |
thoughts: list[str] = []
|
51 |
sources: dict[str, dict] = {}
|
@@ -67,18 +48,19 @@ def get_context(messages: list[ModelMessage]) -> Context:
|
|
67 |
return Context(thoughts=thoughts, sources=sources)
|
68 |
|
69 |
|
70 |
-
def
|
71 |
-
|
72 |
-
context = get_context(result.all_messages())
|
73 |
|
74 |
statements = []
|
75 |
-
for statement in
|
76 |
sources = []
|
77 |
for source_id in statement.sources or []:
|
78 |
try:
|
79 |
sources.append(context.sources[source_id])
|
80 |
-
except KeyError:
|
81 |
-
|
|
|
|
|
82 |
statements.append({"text": statement.text, "sources": sources})
|
83 |
|
84 |
return models.Statements.model_validate(
|
@@ -87,3 +69,21 @@ def respond(query: str) -> models.Statements:
|
|
87 |
"thoughts": "\n\n".join(context.thoughts),
|
88 |
}
|
89 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
|
|
|
3 |
from pydantic import BaseModel, Field
|
4 |
+
from pydantic_ai import Agent, ModelRetry, RunContext
|
5 |
from pydantic_ai.messages import (
|
6 |
ModelMessage,
|
7 |
ModelRequest,
|
|
|
27 |
)
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def get_context(messages: list[ModelMessage]) -> Context:
|
31 |
thoughts: list[str] = []
|
32 |
sources: dict[str, dict] = {}
|
|
|
48 |
return Context(thoughts=thoughts, sources=sources)
|
49 |
|
50 |
|
51 |
+
def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
|
52 |
+
context = get_context(ctx.messages)
|
|
|
53 |
|
54 |
statements = []
|
55 |
+
for statement in output:
|
56 |
sources = []
|
57 |
for source_id in statement.sources or []:
|
58 |
try:
|
59 |
sources.append(context.sources[source_id])
|
60 |
+
except KeyError as err:
|
61 |
+
raise ModelRetry(
|
62 |
+
f"Source ID '{source_id}' not found in literature."
|
63 |
+
) from err
|
64 |
statements.append({"text": statement.text, "sources": sources})
|
65 |
|
66 |
return models.Statements.model_validate(
|
|
|
69 |
"thoughts": "\n\n".join(context.thoughts),
|
70 |
}
|
71 |
)
|
72 |
+
|
73 |
+
|
74 |
+
model = GoogleModel("gemini-2.5-flash-preview-05-20")
|
75 |
+
settings = GoogleModelSettings(
|
76 |
+
google_thinking_config={"thinking_budget": 2048, "include_thoughts": True},
|
77 |
+
)
|
78 |
+
agent = Agent(
|
79 |
+
model=model,
|
80 |
+
name="elna",
|
81 |
+
model_settings=settings,
|
82 |
+
output_type=create_response,
|
83 |
+
system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(),
|
84 |
+
tools=[
|
85 |
+
dailymed.find_drug_set_ids,
|
86 |
+
dailymed.find_drug_instruction,
|
87 |
+
literature.search_medical_literature,
|
88 |
+
],
|
89 |
+
)
|
app/main.py
CHANGED
@@ -19,5 +19,5 @@ def health_check():
|
|
19 |
|
20 |
|
21 |
@app.get("/ask", response_model=models.Statements)
|
22 |
-
def ask(query: str):
|
23 |
-
return agent.
|
|
|
19 |
|
20 |
|
21 |
@app.get("/ask", response_model=models.Statements)
|
22 |
+
async def ask(query: str):
|
23 |
+
return await agent.agent.run(query)
|
main.py
CHANGED
@@ -9,11 +9,11 @@ logfire.configure(
|
|
9 |
logfire.instrument_pydantic_ai()
|
10 |
|
11 |
|
12 |
-
def format_output(statements: models.
|
13 |
sentences = []
|
14 |
citations = {}
|
15 |
|
16 |
-
for statement in statements
|
17 |
sentence = statement.text
|
18 |
|
19 |
if sentence.startswith(("*", "-")):
|
@@ -33,9 +33,7 @@ def format_output(statements: models.Statements) -> tuple[str, str]:
|
|
33 |
|
34 |
answer = " ".join(sentences)
|
35 |
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
|
36 |
-
|
37 |
-
|
38 |
-
return f"{answer}\n\n{footnotes}", thought
|
39 |
|
40 |
|
41 |
def main():
|
@@ -46,10 +44,10 @@ def main():
|
|
46 |
|
47 |
if submit:
|
48 |
with st.spinner("Thinking...", show_time=True):
|
49 |
-
output = agent.
|
50 |
-
|
51 |
with st.expander("Thinking Process"):
|
52 |
-
st.markdown(thoughts)
|
53 |
st.markdown(answer)
|
54 |
|
55 |
|
|
|
9 |
logfire.instrument_pydantic_ai()
|
10 |
|
11 |
|
12 |
+
def format_output(statements: list[models.Statement]) -> str:
|
13 |
sentences = []
|
14 |
citations = {}
|
15 |
|
16 |
+
for statement in statements:
|
17 |
sentence = statement.text
|
18 |
|
19 |
if sentence.startswith(("*", "-")):
|
|
|
33 |
|
34 |
answer = " ".join(sentences)
|
35 |
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
|
36 |
+
return f"{answer}\n\n{footnotes}"
|
|
|
|
|
37 |
|
38 |
|
39 |
def main():
|
|
|
44 |
|
45 |
if submit:
|
46 |
with st.spinner("Thinking...", show_time=True):
|
47 |
+
output = agent.agent.run_sync(query).output
|
48 |
+
answer = format_output(output.statements)
|
49 |
with st.expander("Thinking Process"):
|
50 |
+
st.markdown(output.thoughts)
|
51 |
st.markdown(answer)
|
52 |
|
53 |
|