David Chu
commited on
feat: improve citation instruction
Browse files
main.py
CHANGED
@@ -12,15 +12,14 @@ class Article(BaseModel):
|
|
12 |
title: str
|
13 |
summary: str | None
|
14 |
abstract: str | None
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
source_id: str
|
19 |
|
20 |
|
21 |
class Statement(BaseModel):
|
22 |
text: str
|
23 |
-
|
24 |
|
25 |
|
26 |
def improve_prompt(client: genai.Client, prompt: str) -> str:
|
@@ -44,12 +43,33 @@ def format_sources(articles: list[Article]) -> str:
|
|
44 |
return "\n".join(sources)
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def generate_answer(
|
48 |
client: genai.Client, query: str, articles: list[Article]
|
49 |
) -> list[Statement]:
|
50 |
response = client.models.generate_content(
|
51 |
model="gemini-2.5-flash-preview-04-17",
|
52 |
-
contents=
|
|
|
|
|
53 |
config={
|
54 |
"response_mime_type": "application/json",
|
55 |
"response_schema": list[Statement],
|
@@ -93,6 +113,9 @@ def semantic_scholar(
|
|
93 |
title=article["title"],
|
94 |
summary=article["tldr"]["text"] if article["tldr"] else "",
|
95 |
abstract=article["abstract"],
|
|
|
|
|
|
|
96 |
)
|
97 |
articles.append(article)
|
98 |
return articles
|
@@ -122,7 +145,7 @@ def pubmed(query: str, top_k: int = 10, db: str = "pubmed"):
|
|
122 |
|
123 |
|
124 |
def main():
|
125 |
-
semantic_scholar_client = httpx.Client(transport=httpx.HTTPTransport(retries=
|
126 |
gemini_client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
|
127 |
|
128 |
st.title("Ask ~~Jeeves~~ Elna")
|
@@ -142,22 +165,18 @@ def main():
|
|
142 |
citations = {}
|
143 |
statements = generate_answer(gemini_client, query, papers)
|
144 |
for statement in statements:
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
statement.citation.source_id
|
149 |
-
)
|
150 |
-
):
|
151 |
citation_id = len(citations) + 1
|
152 |
-
citations[
|
153 |
-
|
154 |
-
|
155 |
-
sentences.append(statement.text)
|
156 |
answer = " ".join(sentences)
|
157 |
footnotes = ""
|
158 |
if citations:
|
159 |
footnotes = "\n".join(
|
160 |
-
f"[^{v}]: [{paper_map[k].title}](https://doi.org/{paper_map[k].id})"
|
161 |
for k, v in citations.items()
|
162 |
)
|
163 |
|
|
|
12 |
title: str
|
13 |
summary: str | None
|
14 |
abstract: str | None
|
15 |
+
venue: str
|
16 |
+
year: int
|
17 |
+
citations: int | None
|
|
|
18 |
|
19 |
|
20 |
class Statement(BaseModel):
|
21 |
text: str
|
22 |
+
source_ids: list[str] | None
|
23 |
|
24 |
|
25 |
def improve_prompt(client: genai.Client, prompt: str) -> str:
|
|
|
43 |
return "\n".join(sources)
|
44 |
|
45 |
|
46 |
+
ANSWER_INSTRUCTION = """\
|
47 |
+
You are a medical research expert.
|
48 |
+
|
49 |
+
Please answer the user's query clearly and concisely, using no more than 100 words.
|
50 |
+
|
51 |
+
Base every claim or statement strictly on the provided sources. For each claim, include a citation referencing the source's ID (do not include the citation in the `text` field). A claim may be supported by one or multiple sources, but only cite sources that directly support the claim. Do not add unnecessary citations.
|
52 |
+
|
53 |
+
If none of the sources contain relevant information to answer the query, politely inform the user that an answer cannot be provided, and do not use any citations.
|
54 |
+
|
55 |
+
If the query is not related to medicine, politely decline to answer.
|
56 |
+
|
57 |
+
<query>{query}</query>
|
58 |
+
|
59 |
+
<sources>
|
60 |
+
{sources}
|
61 |
+
</sources>
|
62 |
+
"""
|
63 |
+
|
64 |
+
|
65 |
def generate_answer(
|
66 |
client: genai.Client, query: str, articles: list[Article]
|
67 |
) -> list[Statement]:
|
68 |
response = client.models.generate_content(
|
69 |
model="gemini-2.5-flash-preview-04-17",
|
70 |
+
contents=ANSWER_INSTRUCTION.format(
|
71 |
+
query=query, sources=format_sources(articles)
|
72 |
+
),
|
73 |
config={
|
74 |
"response_mime_type": "application/json",
|
75 |
"response_schema": list[Statement],
|
|
|
113 |
title=article["title"],
|
114 |
summary=article["tldr"]["text"] if article["tldr"] else "",
|
115 |
abstract=article["abstract"],
|
116 |
+
venue=article["venue"],
|
117 |
+
year=article["year"],
|
118 |
+
citations=article["citationCount"],
|
119 |
)
|
120 |
articles.append(article)
|
121 |
return articles
|
|
|
145 |
|
146 |
|
147 |
def main():
|
148 |
+
semantic_scholar_client = httpx.Client(transport=httpx.HTTPTransport(retries=3))
|
149 |
gemini_client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
|
150 |
|
151 |
st.title("Ask ~~Jeeves~~ Elna")
|
|
|
165 |
citations = {}
|
166 |
statements = generate_answer(gemini_client, query, papers)
|
167 |
for statement in statements:
|
168 |
+
sentence = statement.text
|
169 |
+
for source_id in statement.source_ids or []:
|
170 |
+
if not (citation_id := citations.get(source_id)):
|
|
|
|
|
|
|
171 |
citation_id = len(citations) + 1
|
172 |
+
citations[source_id] = citation_id
|
173 |
+
sentence += f"[^{citation_id}] "
|
174 |
+
sentences.append(sentence.strip())
|
|
|
175 |
answer = " ".join(sentences)
|
176 |
footnotes = ""
|
177 |
if citations:
|
178 |
footnotes = "\n".join(
|
179 |
+
f"[^{v}]: :grey-badge[:material/attribution: {paper_map[k].citations}] [{paper_map[k].title}](https://doi.org/{paper_map[k].id}). _{paper_map[k].venue}_ {paper_map[k].year}."
|
180 |
for k, v in citations.items()
|
181 |
)
|
182 |
|