David Chu committed on
Commit
d9a5339
·
unverified ·
1 Parent(s): d38794d

feat: expand source metadata

Browse files
Files changed (5) hide show
  1. app/agent.py +51 -23
  2. app/main.py +1 -2
  3. app/tools/dailymed.py +3 -1
  4. app/tools/literature.py +15 -25
  5. main.py +18 -23
app/agent.py CHANGED
@@ -1,10 +1,11 @@
1
- import json
2
  import re
3
  from pathlib import Path
4
 
5
  from google import genai
6
  from google.genai import types
 
7
 
 
8
  from app.tools import dailymed, literature
9
 
10
  CONFIG = types.GenerateContentConfig(
@@ -16,32 +17,59 @@ CONFIG = types.GenerateContentConfig(
16
  system_instruction=(Path(__file__).parent / "system_instruction.txt").read_text(),
17
  )
18
 
 
 
 
 
19
 
20
- def respond(client: genai.Client, query: str) -> list[dict]:
21
- config = types.GenerateContentConfig(
22
- tools=[
23
- dailymed.find_drug_set_ids,
24
- dailymed.find_drug_instruction,
25
- literature.search_medical_literature,
26
- ],
27
- system_instruction=SYSTEM_INSTRUCTION,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
 
 
 
 
29
  resp = client.models.generate_content(
30
  model="gemini-2.5-flash-preview-04-17",
31
  contents=query,
32
  config=CONFIG,
33
  )
34
-
35
- output = ((resp.text) or "").strip()
36
-
37
- if output.startswith("```"):
38
- # Extract content inside the first markdown code block (``` or ```json)
39
- match = re.match(r"^```(?:json)?\s*([\s\S]*?)\s*```", output)
40
- if match:
41
- output = match.group(1).strip()
42
-
43
- try:
44
- return json.loads(output)
45
- except json.decoder.JSONDecodeError as err:
46
- print(err)
47
- return [{"text": output}]
 
 
1
  import re
2
  from pathlib import Path
3
 
4
  from google import genai
5
  from google.genai import types
6
+ from pydantic import ValidationError
7
 
8
+ from app import models
9
  from app.tools import dailymed, literature
10
 
11
  CONFIG = types.GenerateContentConfig(
 
17
  system_instruction=(Path(__file__).parent / "system_instruction.txt").read_text(),
18
  )
19
 
20
+ SOURCE_TOOL_NAMES = {
21
+ literature.search_medical_literature.__name__,
22
+ dailymed.find_drug_set_ids.__name__,
23
+ }
24
 
25
+
26
+ def hydrate_sources(
27
+ statements: models.Statements, calling_history: list[types.Content]
28
+ ) -> models.Statements:
29
+ sources = {}
30
+ for call in calling_history:
31
+ for part in call.parts or []:
32
+ if (
33
+ (func := part.function_response)
34
+ and func.name in SOURCE_TOOL_NAMES
35
+ and func.response
36
+ ):
37
+ for source in func.response["result"]:
38
+ sources[source["url"]] = source
39
+
40
+ for statement in statements.statements:
41
+ if statement.sources:
42
+ statement.sources = [
43
+ models.Source.model_validate(sources[source.url])
44
+ for source in statement.sources
45
+ ]
46
+
47
+ return statements
48
+
49
+
50
+ def validate_response(response: types.GenerateContentResponse) -> models.Statements:
51
+ text = (response.text or "").strip()
52
+
53
+ # Extract content inside the first markdown code block (``` or ```json)
54
+ match = re.match(r"^```(?:json)?\s*([\s\S]*?)\s*```", text)
55
+ if match:
56
+ text = match.group(1).strip()
57
+
58
+ try:
59
+ statements = models.Statements.model_validate_json(f'{{"statements":{text}}}')
60
+ except ValidationError:
61
+ statements = models.Statements(statements=[models.Statement(text=text)])
62
+
63
+ statements = hydrate_sources(
64
+ statements, response.automatic_function_calling_history or []
65
  )
66
+ return statements
67
+
68
+
69
+ def respond(client: genai.Client, query: str) -> models.Statements:
70
  resp = client.models.generate_content(
71
  model="gemini-2.5-flash-preview-04-17",
72
  contents=query,
73
  config=CONFIG,
74
  )
75
+ return validate_response(resp)
 
 
 
 
 
 
 
 
 
 
 
 
 
app/main.py CHANGED
@@ -14,5 +14,4 @@ def health_check():
14
 
15
  @app.get("/ask", response_model=models.Statements)
16
  def ask(query: str):
17
- output = agent.respond(gemini, query)
18
- return {"statements": output}
 
14
 
15
  @app.get("/ask", response_model=models.Statements)
16
  def ask(query: str):
17
+ return agent.respond(gemini, query)
 
app/tools/dailymed.py CHANGED
@@ -18,8 +18,10 @@ def find_drug_set_ids(name: str) -> list[dict]:
18
  )
19
  return [
20
  {
21
- "name": row["title"],
22
  "set_id": row["setid"],
 
 
23
  "url": f"https://dailymed.nlm.nih.gov/dailymed/drugInfo.cfm?setid={row['setid']}",
24
  }
25
  for row in resp.json()["data"]
 
18
  )
19
  return [
20
  {
21
+ "title": row["title"],
22
  "set_id": row["setid"],
23
+ "venue": "DailyMed",
24
+ "year": row["published_date"][-4:], # Original format: "May 05, 2025"
25
  "url": f"https://dailymed.nlm.nih.gov/dailymed/drugInfo.cfm?setid={row['setid']}",
26
  }
27
  for row in resp.json()["data"]
app/tools/literature.py CHANGED
@@ -49,26 +49,20 @@ def get_pubmed_abstracts(pmids: list[int]) -> dict[str, dict]:
49
  return abstracts
50
 
51
 
52
- def format_publication(publication: dict) -> str:
53
- title = publication["title"]
54
- summary = (publication["tldr"] or {}).get("text", "")
55
- abstract = publication["abstract"]
56
- venue = publication["venue"]
57
- year = publication["year"]
58
- citations = publication["citationCount"]
59
- influential_citations = publication["influentialCitationCount"]
60
- doi = publication["externalIds"].get("DOI")
61
- url = f"https://doi.org/{doi}" if doi else publication["url"]
62
- return (
63
- f"<publication title={title}>\n<url>{url}</url>\n"
64
- f"<summary>{summary}</summary>\n<abstract>{abstract}</abstract>\n"
65
- f"<venue>{venue}</venue>\n<year>{year}</year>\n"
66
- f"<citationCount>{citations}</citationCount>\n<influentialCitationCount>{influential_citations}</influentialCitationCount>\n"
67
- "</publication>"
68
- )
69
-
70
-
71
- def search_medical_literature(query: str) -> str:
72
  """Get medical literature related to the query.
73
 
74
  Args:
@@ -98,8 +92,4 @@ def search_medical_literature(query: str) -> str:
98
 
99
  outputs.append(format_publication(publication))
100
 
101
- return (
102
- f"<publications>\n{'\n'.join(outputs)}\n</publications>"
103
- if outputs
104
- else "No literature found"
105
- )
 
49
  return abstracts
50
 
51
 
52
+ def format_publication(publication: dict) -> dict:
53
+ tldr = publication.pop("tldr") or {}
54
+ external_ids = publication.pop("externalIds")
55
+ doi = external_ids.get("DOI")
56
+ publication["summary"] = tldr.get("text", "")
57
+ publication["citations"] = publication.pop("citationCount")
58
+ publication["influential_citations"] = publication.pop("influentialCitationCount")
59
+ publication["doi"] = doi
60
+ if doi:
61
+ publication["url"] = f"https://doi.org/{doi}"
62
+ return publication
63
+
64
+
65
+ def search_medical_literature(query: str) -> list[dict]:
 
 
 
 
 
 
66
  """Get medical literature related to the query.
67
 
68
  Args:
 
92
 
93
  outputs.append(format_publication(publication))
94
 
95
+ return outputs
 
 
 
 
main.py CHANGED
@@ -1,32 +1,27 @@
1
  import streamlit as st
2
  from google import genai
3
 
4
- from app import agent, config
5
-
6
-
7
- def format_output(response: list[dict]) -> tuple[str, str]:
8
- try:
9
- answer = ""
10
- citations = {}
11
- for statement in response:
12
- text = statement["text"].strip()
13
- answer = (
14
- answer + f"\n{text}"
15
- if text.startswith("*") or text.startswith("-")
16
- else answer + f" {text}"
17
- )
18
  citation_ids = []
19
- for source in statement.get("sources", []):
20
- source_str = f"[{source['title']}]({source['url']})"
21
- if not (citation_id := citations.get(source_str)):
22
  citation_id = len(citations) + 1
23
- citations[source_str] = citation_id
24
  citation_ids.append(citation_id)
25
- if citation_ids:
26
- answer += " ".join(f"[^{i}]" for i in sorted(citation_ids))
27
- except KeyError as err:
28
- print(err)
29
- return str(response), ""
30
 
31
  footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
32
  return answer, footnotes
 
1
  import streamlit as st
2
  from google import genai
3
 
4
+ from app import agent, config, models
5
+
6
+
7
+ def format_output(statements: models.Statements) -> tuple[str, str]:
8
+ answer = ""
9
+ citations = {}
10
+
11
+ for statement in statements.statements:
12
+ if statement.text.startswith(("*", "-")):
13
+ # Bullet points should be on a newline.
14
+ answer += "\n"
15
+ answer += statement.text
16
+
17
+ if statement.sources:
18
  citation_ids = []
19
+ for source in statement.sources:
20
+ if not (citation_id := citations.get(source.citation)):
 
21
  citation_id = len(citations) + 1
22
+ citations[source.citation] = citation_id
23
  citation_ids.append(citation_id)
24
+ answer += " ".join(f"[^{i}]" for i in sorted(citation_ids))
 
 
 
 
25
 
26
  footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
27
  return answer, footnotes