David Chu
committed on
feat: mitigate citation hallucination
Browse files
Using PydanticAI's output function and raising a ModelRetry will
make the LLM regenerate the output (by default 1 retry)
if there are nonexistent source IDs.
- app/agent.py +26 -26
- app/main.py +2 -2
- main.py +6 -8
app/agent.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
from pathlib import Path
|
| 2 |
|
| 3 |
-
import logfire
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
-
from pydantic_ai import Agent
|
| 6 |
from pydantic_ai.messages import (
|
| 7 |
ModelMessage,
|
| 8 |
ModelRequest,
|
|
@@ -28,24 +27,6 @@ class Statement(BaseModel):
|
|
| 28 |
)
|
| 29 |
|
| 30 |
|
| 31 |
-
model = GoogleModel("gemini-2.5-flash-preview-05-20")
|
| 32 |
-
settings = GoogleModelSettings(
|
| 33 |
-
google_thinking_config={"thinking_budget": 2048, "include_thoughts": True},
|
| 34 |
-
)
|
| 35 |
-
agent = Agent(
|
| 36 |
-
model=model,
|
| 37 |
-
name="elna",
|
| 38 |
-
model_settings=settings,
|
| 39 |
-
output_type=list[Statement],
|
| 40 |
-
system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(),
|
| 41 |
-
tools=[
|
| 42 |
-
dailymed.find_drug_set_ids,
|
| 43 |
-
dailymed.find_drug_instruction,
|
| 44 |
-
literature.search_medical_literature,
|
| 45 |
-
],
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
def get_context(messages: list[ModelMessage]) -> Context:
|
| 50 |
thoughts: list[str] = []
|
| 51 |
sources: dict[str, dict] = {}
|
|
@@ -67,18 +48,19 @@ def get_context(messages: list[ModelMessage]) -> Context:
|
|
| 67 |
return Context(thoughts=thoughts, sources=sources)
|
| 68 |
|
| 69 |
|
| 70 |
-
def respond(query: str) -> models.Statements:
|
| 71 |
-
|
| 72 |
-
context = get_context(result.all_messages())
|
| 73 |
|
| 74 |
statements = []
|
| 75 |
-
for statement in result.output:
|
| 76 |
sources = []
|
| 77 |
for source_id in statement.sources or []:
|
| 78 |
try:
|
| 79 |
sources.append(context.sources[source_id])
|
| 80 |
-
except KeyError:
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
statements.append({"text": statement.text, "sources": sources})
|
| 83 |
|
| 84 |
return models.Statements.model_validate(
|
|
@@ -87,3 +69,21 @@ def respond(query: str) -> models.Statements:
|
|
| 87 |
"thoughts": "\n\n".join(context.thoughts),
|
| 88 |
}
|
| 89 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
|
|
|
|
| 3 |
from pydantic import BaseModel, Field
|
| 4 |
+
from pydantic_ai import Agent, ModelRetry, RunContext
|
| 5 |
from pydantic_ai.messages import (
|
| 6 |
ModelMessage,
|
| 7 |
ModelRequest,
|
|
|
|
| 27 |
)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def get_context(messages: list[ModelMessage]) -> Context:
|
| 31 |
thoughts: list[str] = []
|
| 32 |
sources: dict[str, dict] = {}
|
|
|
|
| 48 |
return Context(thoughts=thoughts, sources=sources)
|
| 49 |
|
| 50 |
|
| 51 |
+
def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
|
| 52 |
+
context = get_context(ctx.messages)
|
|
|
|
| 53 |
|
| 54 |
statements = []
|
| 55 |
+
for statement in output:
|
| 56 |
sources = []
|
| 57 |
for source_id in statement.sources or []:
|
| 58 |
try:
|
| 59 |
sources.append(context.sources[source_id])
|
| 60 |
+
except KeyError as err:
|
| 61 |
+
raise ModelRetry(
|
| 62 |
+
f"Source ID '{source_id}' not found in literature."
|
| 63 |
+
) from err
|
| 64 |
statements.append({"text": statement.text, "sources": sources})
|
| 65 |
|
| 66 |
return models.Statements.model_validate(
|
|
|
|
| 69 |
"thoughts": "\n\n".join(context.thoughts),
|
| 70 |
}
|
| 71 |
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
model = GoogleModel("gemini-2.5-flash-preview-05-20")
|
| 75 |
+
settings = GoogleModelSettings(
|
| 76 |
+
google_thinking_config={"thinking_budget": 2048, "include_thoughts": True},
|
| 77 |
+
)
|
| 78 |
+
agent = Agent(
|
| 79 |
+
model=model,
|
| 80 |
+
name="elna",
|
| 81 |
+
model_settings=settings,
|
| 82 |
+
output_type=create_response,
|
| 83 |
+
system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(),
|
| 84 |
+
tools=[
|
| 85 |
+
dailymed.find_drug_set_ids,
|
| 86 |
+
dailymed.find_drug_instruction,
|
| 87 |
+
literature.search_medical_literature,
|
| 88 |
+
],
|
| 89 |
+
)
|
app/main.py
CHANGED
|
@@ -19,5 +19,5 @@ def health_check():
|
|
| 19 |
|
| 20 |
|
| 21 |
@app.get("/ask", response_model=models.Statements)
|
| 22 |
-
def ask(query: str):
|
| 23 |
-
return agent.respond(query)
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
@app.get("/ask", response_model=models.Statements)
|
| 22 |
+
async def ask(query: str):
|
| 23 |
+
return await agent.agent.run(query)
|
main.py
CHANGED
|
@@ -9,11 +9,11 @@ logfire.configure(
|
|
| 9 |
logfire.instrument_pydantic_ai()
|
| 10 |
|
| 11 |
|
| 12 |
-
def format_output(statements: models.Statements) -> tuple[str, str]:
|
| 13 |
sentences = []
|
| 14 |
citations = {}
|
| 15 |
|
| 16 |
-
for statement in statements.statements:
|
| 17 |
sentence = statement.text
|
| 18 |
|
| 19 |
if sentence.startswith(("*", "-")):
|
|
@@ -33,9 +33,7 @@ def format_output(statements: models.Statements) -> tuple[str, str]:
|
|
| 33 |
|
| 34 |
answer = " ".join(sentences)
|
| 35 |
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
return f"{answer}\n\n{footnotes}", thought
|
| 39 |
|
| 40 |
|
| 41 |
def main():
|
|
@@ -46,10 +44,10 @@ def main():
|
|
| 46 |
|
| 47 |
if submit:
|
| 48 |
with st.spinner("Thinking...", show_time=True):
|
| 49 |
-
output = agent.respond(query)
|
| 50 |
-
|
| 51 |
with st.expander("Thinking Process"):
|
| 52 |
-
st.markdown(thoughts)
|
| 53 |
st.markdown(answer)
|
| 54 |
|
| 55 |
|
|
|
|
| 9 |
logfire.instrument_pydantic_ai()
|
| 10 |
|
| 11 |
|
| 12 |
+
def format_output(statements: list[models.Statement]) -> str:
|
| 13 |
sentences = []
|
| 14 |
citations = {}
|
| 15 |
|
| 16 |
+
for statement in statements:
|
| 17 |
sentence = statement.text
|
| 18 |
|
| 19 |
if sentence.startswith(("*", "-")):
|
|
|
|
| 33 |
|
| 34 |
answer = " ".join(sentences)
|
| 35 |
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
|
| 36 |
+
return f"{answer}\n\n{footnotes}"
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def main():
|
|
|
|
| 44 |
|
| 45 |
if submit:
|
| 46 |
with st.spinner("Thinking...", show_time=True):
|
| 47 |
+
output = agent.agent.run_sync(query).output
|
| 48 |
+
answer = format_output(output.statements)
|
| 49 |
with st.expander("Thinking Process"):
|
| 50 |
+
st.markdown(output.thoughts)
|
| 51 |
st.markdown(answer)
|
| 52 |
|
| 53 |
|