David Chu committed
Commit 21d76a7 · unverified · 1 Parent(s): d2f1b05

feat: mitigate citation hallucination


Using PydanticAI's output function and raising ModelRetry makes the LLM
regenerate its output (1 retry by default) whenever it cites nonexistent
source IDs.
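
A minimal, self-contained sketch of the same pattern (not code from this repo: the model string, the Cited model, and the KNOWN_SOURCE_IDS set are all illustrative):

from pydantic import BaseModel
from pydantic_ai import Agent, ModelRetry

# Illustrative registry of IDs that the retrieval tools actually returned.
KNOWN_SOURCE_IDS = {"PMID:123", "PMID:456"}


class Cited(BaseModel):
    text: str
    sources: list[str]


def check_citations(output: list[Cited]) -> list[Cited]:
    # Output function: PydanticAI validates every final response with it.
    # Raising ModelRetry sends the complaint back to the model and asks it
    # to regenerate the answer (1 retry by default).
    for statement in output:
        for source_id in statement.sources:
            if source_id not in KNOWN_SOURCE_IDS:
                raise ModelRetry(f"Source ID '{source_id}' does not exist.")
    return output


# Model string is illustrative; any pydantic-ai model id works here.
agent = Agent("google-gla:gemini-2.5-flash", output_type=check_citations)

In app/agent.py the real validator, create_response, additionally takes the RunContext so it can map each cited ID back to the sources the tools actually returned.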

Files changed (3)
  1. app/agent.py +26 -26
  2. app/main.py +2 -2
  3. main.py +6 -8
app/agent.py CHANGED
@@ -1,8 +1,7 @@
 from pathlib import Path
 
-import logfire
 from pydantic import BaseModel, Field
-from pydantic_ai import Agent
+from pydantic_ai import Agent, ModelRetry, RunContext
 from pydantic_ai.messages import (
     ModelMessage,
     ModelRequest,
@@ -28,24 +27,6 @@ class Statement(BaseModel):
     )
 
 
-model = GoogleModel("gemini-2.5-flash-preview-05-20")
-settings = GoogleModelSettings(
-    google_thinking_config={"thinking_budget": 2048, "include_thoughts": True},
-)
-agent = Agent(
-    model=model,
-    name="elna",
-    model_settings=settings,
-    output_type=list[Statement],
-    system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(),
-    tools=[
-        dailymed.find_drug_set_ids,
-        dailymed.find_drug_instruction,
-        literature.search_medical_literature,
-    ],
-)
-
-
 def get_context(messages: list[ModelMessage]) -> Context:
     thoughts: list[str] = []
     sources: dict[str, dict] = {}
@@ -67,18 +48,19 @@ def get_context(messages: list[ModelMessage]) -> Context:
     return Context(thoughts=thoughts, sources=sources)
 
 
-def respond(query: str) -> models.Statements:
-    result = agent.run_sync(query)
-    context = get_context(result.all_messages())
+def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
+    context = get_context(ctx.messages)
 
     statements = []
-    for statement in result.output:
+    for statement in output:
         sources = []
         for source_id in statement.sources or []:
            try:
                sources.append(context.sources[source_id])
-            except KeyError:
-                logfire.warning(f"citation hallucination '{source_id}'")
+            except KeyError as err:
+                raise ModelRetry(
+                    f"Source ID '{source_id}' not found in literature."
+                ) from err
         statements.append({"text": statement.text, "sources": sources})
 
     return models.Statements.model_validate(
@@ -87,3 +69,21 @@ def respond(query: str) -> models.Statements:
             "thoughts": "\n\n".join(context.thoughts),
         }
     )
+
+
+model = GoogleModel("gemini-2.5-flash-preview-05-20")
+settings = GoogleModelSettings(
+    google_thinking_config={"thinking_budget": 2048, "include_thoughts": True},
+)
+agent = Agent(
+    model=model,
+    name="elna",
+    model_settings=settings,
+    output_type=create_response,
+    system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(),
+    tools=[
+        dailymed.find_drug_set_ids,
+        dailymed.find_drug_instruction,
+        literature.search_medical_literature,
+    ],
+)
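
With output_type=create_response, the run result's .output is already a citation-checked models.Statements, which is why the old respond() wrapper could be deleted. A rough usage sketch (the import path and query are assumptions based on how main.py calls the agent):

from app import agent

result = agent.agent.run_sync("Does ibuprofen interact with lisinopril?")
statements = result.output  # models.Statements, citations already validated
print(statements.thoughts)
for statement in statements.statements:
    print(statement.text, statement.sources)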
app/main.py CHANGED
@@ -19,5 +19,5 @@ def health_check():
 
 
 @app.get("/ask", response_model=models.Statements)
-def ask(query: str):
-    return agent.respond(query)
+async def ask(query: str):
+    return await agent.agent.run(query)
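
ModelRetry bounds the problem but does not remove it: if the model keeps citing unknown IDs past the retry limit, pydantic-ai raises UnexpectedModelBehavior and the run fails. A hedged sketch of one way the endpoint could surface that (this guard is not part of the commit, and it returns result.output so the body matches response_model):

from fastapi import HTTPException
from pydantic_ai.exceptions import UnexpectedModelBehavior


@app.get("/ask", response_model=models.Statements)
async def ask(query: str):
    try:
        result = await agent.agent.run(query)
    except UnexpectedModelBehavior as err:
        # Raised when the retries are exhausted, e.g. the model keeps
        # citing source IDs that no tool ever returned.
        raise HTTPException(status_code=502, detail=str(err)) from err
    return result.output  # models.Statements produced by create_response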
main.py CHANGED
@@ -9,11 +9,11 @@ logfire.configure(
 logfire.instrument_pydantic_ai()
 
 
-def format_output(statements: models.Statements) -> tuple[str, str]:
+def format_output(statements: list[models.Statement]) -> str:
     sentences = []
     citations = {}
 
-    for statement in statements.statements:
+    for statement in statements:
         sentence = statement.text
 
         if sentence.startswith(("*", "-")):
@@ -33,9 +33,7 @@ def format_output(statements: models.Statements) -> tuple[str, str]:
 
     answer = " ".join(sentences)
     footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
-    thought = statements.thoughts or ""
-
-    return f"{answer}\n\n{footnotes}", thought
+    return f"{answer}\n\n{footnotes}"
 
 
 def main():
@@ -46,10 +44,10 @@ def main():
 
     if submit:
         with st.spinner("Thinking...", show_time=True):
-            output = agent.respond(query)
-            answer, thoughts = format_output(output)
+            output = agent.agent.run_sync(query).output
+            answer = format_output(output.statements)
         with st.expander("Thinking Process"):
-            st.markdown(thoughts)
+            st.markdown(output.thoughts)
         st.markdown(answer)
 
 
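
format_output now returns only the rendered answer, and the thoughts are read straight from output.thoughts. The footnote join unpacks citations.items() as citation, id, i.e. the dict maps citation text to footnote number; a tiny standalone illustration (values made up):

citations = {"Smith et al. 2020, BMJ": 1, "DailyMed label for ibuprofen": 2}
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
print(footnotes)
# [^1]: Smith et al. 2020, BMJ
# [^2]: DailyMed label for ibuprofen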