David Chu
commited on
feat: expand source metadata
Browse files- app/agent.py +51 -23
- app/main.py +1 -2
- app/tools/dailymed.py +3 -1
- app/tools/literature.py +15 -25
- main.py +18 -23
app/agent.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
-
import json
|
2 |
import re
|
3 |
from pathlib import Path
|
4 |
|
5 |
from google import genai
|
6 |
from google.genai import types
|
|
|
7 |
|
|
|
8 |
from app.tools import dailymed, literature
|
9 |
|
10 |
CONFIG = types.GenerateContentConfig(
|
@@ -16,32 +17,59 @@ CONFIG = types.GenerateContentConfig(
|
|
16 |
system_instruction=(Path(__file__).parent / "system_instruction.txt").read_text(),
|
17 |
)
|
18 |
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
]
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
)
|
|
|
|
|
|
|
|
|
29 |
resp = client.models.generate_content(
|
30 |
model="gemini-2.5-flash-preview-04-17",
|
31 |
contents=query,
|
32 |
config=CONFIG,
|
33 |
)
|
34 |
-
|
35 |
-
output = ((resp.text) or "").strip()
|
36 |
-
|
37 |
-
if output.startswith("```"):
|
38 |
-
# Extract content inside the first markdown code block (``` or ```json)
|
39 |
-
match = re.match(r"^```(?:json)?\s*([\s\S]*?)\s*```", output)
|
40 |
-
if match:
|
41 |
-
output = match.group(1).strip()
|
42 |
-
|
43 |
-
try:
|
44 |
-
return json.loads(output)
|
45 |
-
except json.decoder.JSONDecodeError as err:
|
46 |
-
print(err)
|
47 |
-
return [{"text": output}]
|
|
|
|
|
1 |
import re
|
2 |
from pathlib import Path
|
3 |
|
4 |
from google import genai
|
5 |
from google.genai import types
|
6 |
+
from pydantic import ValidationError
|
7 |
|
8 |
+
from app import models
|
9 |
from app.tools import dailymed, literature
|
10 |
|
11 |
CONFIG = types.GenerateContentConfig(
|
|
|
17 |
system_instruction=(Path(__file__).parent / "system_instruction.txt").read_text(),
|
18 |
)
|
19 |
|
20 |
+
SOURCE_TOOL_NAMES = {
|
21 |
+
literature.search_medical_literature.__name__,
|
22 |
+
dailymed.find_drug_set_ids.__name__,
|
23 |
+
}
|
24 |
|
25 |
+
|
26 |
+
def hydrate_sources(
|
27 |
+
statements: models.Statements, calling_history: list[types.Content]
|
28 |
+
) -> models.Statements:
|
29 |
+
sources = {}
|
30 |
+
for call in calling_history:
|
31 |
+
for part in call.parts or []:
|
32 |
+
if (
|
33 |
+
(func := part.function_response)
|
34 |
+
and func.name in SOURCE_TOOL_NAMES
|
35 |
+
and func.response
|
36 |
+
):
|
37 |
+
for source in func.response["result"]:
|
38 |
+
sources[source["url"]] = source
|
39 |
+
|
40 |
+
for statement in statements.statements:
|
41 |
+
if statement.sources:
|
42 |
+
statement.sources = [
|
43 |
+
models.Source.model_validate(sources[source.url])
|
44 |
+
for source in statement.sources
|
45 |
+
]
|
46 |
+
|
47 |
+
return statements
|
48 |
+
|
49 |
+
|
50 |
+
def validate_response(response: types.GenerateContentResponse) -> models.Statements:
|
51 |
+
text = (response.text or "").strip()
|
52 |
+
|
53 |
+
# Extract content inside the first markdown code block (``` or ```json)
|
54 |
+
match = re.match(r"^```(?:json)?\s*([\s\S]*?)\s*```", text)
|
55 |
+
if match:
|
56 |
+
text = match.group(1).strip()
|
57 |
+
|
58 |
+
try:
|
59 |
+
statements = models.Statements.model_validate_json(f'{{"statements":{text}}}')
|
60 |
+
except ValidationError:
|
61 |
+
statements = models.Statements(statements=[models.Statement(text=text)])
|
62 |
+
|
63 |
+
statements = hydrate_sources(
|
64 |
+
statements, response.automatic_function_calling_history or []
|
65 |
)
|
66 |
+
return statements
|
67 |
+
|
68 |
+
|
69 |
+
def respond(client: genai.Client, query: str) -> models.Statements:
|
70 |
resp = client.models.generate_content(
|
71 |
model="gemini-2.5-flash-preview-04-17",
|
72 |
contents=query,
|
73 |
config=CONFIG,
|
74 |
)
|
75 |
+
return validate_response(resp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/main.py
CHANGED
@@ -14,5 +14,4 @@ def health_check():
|
|
14 |
|
15 |
@app.get("/ask", response_model=models.Statements)
|
16 |
def ask(query: str):
|
17 |
-
|
18 |
-
return {"statements": output}
|
|
|
14 |
|
15 |
@app.get("/ask", response_model=models.Statements)
|
16 |
def ask(query: str):
|
17 |
+
return agent.respond(gemini, query)
|
|
app/tools/dailymed.py
CHANGED
@@ -18,8 +18,10 @@ def find_drug_set_ids(name: str) -> list[dict]:
|
|
18 |
)
|
19 |
return [
|
20 |
{
|
21 |
-
"
|
22 |
"set_id": row["setid"],
|
|
|
|
|
23 |
"url": f"https://dailymed.nlm.nih.gov/dailymed/drugInfo.cfm?setid={row['setid']}",
|
24 |
}
|
25 |
for row in resp.json()["data"]
|
|
|
18 |
)
|
19 |
return [
|
20 |
{
|
21 |
+
"title": row["title"],
|
22 |
"set_id": row["setid"],
|
23 |
+
"venue": "DailyMed",
|
24 |
+
"year": row["published_date"][-4:], # Original format: "May 05, 2025"
|
25 |
"url": f"https://dailymed.nlm.nih.gov/dailymed/drugInfo.cfm?setid={row['setid']}",
|
26 |
}
|
27 |
for row in resp.json()["data"]
|
app/tools/literature.py
CHANGED
@@ -49,26 +49,20 @@ def get_pubmed_abstracts(pmids: list[int]) -> dict[str, dict]:
|
|
49 |
return abstracts
|
50 |
|
51 |
|
52 |
-
def format_publication(publication: dict) ->
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
doi
|
61 |
-
|
62 |
-
return
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
f"<citationCount>{citations}</citationCount>\n<influentialCitationCount>{influential_citations}</influentialCitationCount>\n"
|
67 |
-
"</publication>"
|
68 |
-
)
|
69 |
-
|
70 |
-
|
71 |
-
def search_medical_literature(query: str) -> str:
|
72 |
"""Get medical literature related to the query.
|
73 |
|
74 |
Args:
|
@@ -98,8 +92,4 @@ def search_medical_literature(query: str) -> str:
|
|
98 |
|
99 |
outputs.append(format_publication(publication))
|
100 |
|
101 |
-
return
|
102 |
-
f"<publications>\n{'\n'.join(outputs)}\n</publications>"
|
103 |
-
if outputs
|
104 |
-
else "No literature found"
|
105 |
-
)
|
|
|
49 |
return abstracts
|
50 |
|
51 |
|
52 |
+
def format_publication(publication: dict) -> dict:
|
53 |
+
tldr = publication.pop("tldr") or {}
|
54 |
+
external_ids = publication.pop("externalIds")
|
55 |
+
doi = external_ids.get("DOI")
|
56 |
+
publication["summary"] = tldr.get("text", "")
|
57 |
+
publication["citations"] = publication.pop("citationCount")
|
58 |
+
publication["influential_citations"] = publication.pop("influentialCitationCount")
|
59 |
+
publication["doi"] = doi
|
60 |
+
if doi:
|
61 |
+
publication["url"] = f"https://doi.org/{doi}"
|
62 |
+
return publication
|
63 |
+
|
64 |
+
|
65 |
+
def search_medical_literature(query: str) -> list[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
"""Get medical literature related to the query.
|
67 |
|
68 |
Args:
|
|
|
92 |
|
93 |
outputs.append(format_publication(publication))
|
94 |
|
95 |
+
return outputs
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -1,32 +1,27 @@
|
|
1 |
import streamlit as st
|
2 |
from google import genai
|
3 |
|
4 |
-
from app import agent, config
|
5 |
-
|
6 |
-
|
7 |
-
def format_output(
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
citation_ids = []
|
19 |
-
for source in statement.
|
20 |
-
|
21 |
-
if not (citation_id := citations.get(source_str)):
|
22 |
citation_id = len(citations) + 1
|
23 |
-
citations[
|
24 |
citation_ids.append(citation_id)
|
25 |
-
|
26 |
-
answer += " ".join(f"[^{i}]" for i in sorted(citation_ids))
|
27 |
-
except KeyError as err:
|
28 |
-
print(err)
|
29 |
-
return str(response), ""
|
30 |
|
31 |
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
|
32 |
return answer, footnotes
|
|
|
1 |
import streamlit as st
|
2 |
from google import genai
|
3 |
|
4 |
+
from app import agent, config, models
|
5 |
+
|
6 |
+
|
7 |
+
def format_output(statements: models.Statements) -> tuple[str, str]:
|
8 |
+
answer = ""
|
9 |
+
citations = {}
|
10 |
+
|
11 |
+
for statement in statements.statements:
|
12 |
+
if statement.text.startswith(("*", "-")):
|
13 |
+
# Bullet points should be on a newline.
|
14 |
+
answer += "\n"
|
15 |
+
answer += statement.text
|
16 |
+
|
17 |
+
if statement.sources:
|
18 |
citation_ids = []
|
19 |
+
for source in statement.sources:
|
20 |
+
if not (citation_id := citations.get(source.citation)):
|
|
|
21 |
citation_id = len(citations) + 1
|
22 |
+
citations[source.citation] = citation_id
|
23 |
citation_ids.append(citation_id)
|
24 |
+
answer += " ".join(f"[^{i}]" for i in sorted(citation_ids))
|
|
|
|
|
|
|
|
|
25 |
|
26 |
footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
|
27 |
return answer, footnotes
|