David Chu committed on
Commit
a874450
·
unverified ·
1 Parent(s): 176a472

feat: improve citation instruction

Browse files
Files changed (1) hide show
  1. main.py +37 -18
main.py CHANGED
@@ -12,15 +12,14 @@ class Article(BaseModel):
12
  title: str
13
  summary: str | None
14
  abstract: str | None
15
-
16
-
17
- class Citation(BaseModel):
18
- source_id: str
19
 
20
 
21
  class Statement(BaseModel):
22
  text: str
23
- citation: Citation | None
24
 
25
 
26
  def improve_prompt(client: genai.Client, prompt: str) -> str:
@@ -44,12 +43,33 @@ def format_sources(articles: list[Article]) -> str:
44
  return "\n".join(sources)
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def generate_answer(
48
  client: genai.Client, query: str, articles: list[Article]
49
  ) -> list[Statement]:
50
  response = client.models.generate_content(
51
  model="gemini-2.5-flash-preview-04-17",
52
- contents=f"Answer the query based solely on the provided sources. The answer should be less than 100 words. Justify the answer by citing from the sources. Refuse to answer non-medical related query.\n\n<query>{query}</query>\n\n<sources>{format_sources(articles)}</sources>",
 
 
53
  config={
54
  "response_mime_type": "application/json",
55
  "response_schema": list[Statement],
@@ -93,6 +113,9 @@ def semantic_scholar(
93
  title=article["title"],
94
  summary=article["tldr"]["text"] if article["tldr"] else "",
95
  abstract=article["abstract"],
 
 
 
96
  )
97
  articles.append(article)
98
  return articles
@@ -122,7 +145,7 @@ def pubmed(query: str, top_k: int = 10, db: str = "pubmed"):
122
 
123
 
124
  def main():
125
- semantic_scholar_client = httpx.Client(transport=httpx.HTTPTransport(retries=1))
126
  gemini_client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
127
 
128
  st.title("Ask ~~Jeeves~~ Elna")
@@ -142,22 +165,18 @@ def main():
142
  citations = {}
143
  statements = generate_answer(gemini_client, query, papers)
144
  for statement in statements:
145
- if statement.citation:
146
- if not (
147
- citation_id := citations.get(
148
- statement.citation.source_id
149
- )
150
- ):
151
  citation_id = len(citations) + 1
152
- citations[statement.citation.source_id] = citation_id
153
- sentences.append(f"{statement.text}[^{citation_id}]")
154
- else:
155
- sentences.append(statement.text)
156
  answer = " ".join(sentences)
157
  footnotes = ""
158
  if citations:
159
  footnotes = "\n".join(
160
- f"[^{v}]: [{paper_map[k].title}](https://doi.org/{paper_map[k].id})"
161
  for k, v in citations.items()
162
  )
163
 
 
12
  title: str
13
  summary: str | None
14
  abstract: str | None
15
+ venue: str
16
+ year: int
17
+ citations: int | None
 
18
 
19
 
20
  class Statement(BaseModel):
21
  text: str
22
+ source_ids: list[str] | None
23
 
24
 
25
  def improve_prompt(client: genai.Client, prompt: str) -> str:
 
43
  return "\n".join(sources)
44
 
45
 
46
+ ANSWER_INSTRUCTION = """\
47
+ You are a medical research expert.
48
+
49
+ Please answer the user's query clearly and concisely, using no more than 100 words.
50
+
51
+ Base every claim or statement strictly on the provided sources. For each claim, include a citation referencing the source's ID (do not include the citation in the `text` field). A claim may be supported by one or multiple sources, but only cite sources that directly support the claim. Do not add unnecessary citations.
52
+
53
+ If none of the sources contain relevant information to answer the query, politely inform the user that an answer cannot be provided, and do not use any citations.
54
+
55
+ If the query is not related to medicine, politely decline to answer.
56
+
57
+ <query>{query}</query>
58
+
59
+ <sources>
60
+ {sources}
61
+ </sources>
62
+ """
63
+
64
+
65
  def generate_answer(
66
  client: genai.Client, query: str, articles: list[Article]
67
  ) -> list[Statement]:
68
  response = client.models.generate_content(
69
  model="gemini-2.5-flash-preview-04-17",
70
+ contents=ANSWER_INSTRUCTION.format(
71
+ query=query, sources=format_sources(articles)
72
+ ),
73
  config={
74
  "response_mime_type": "application/json",
75
  "response_schema": list[Statement],
 
113
  title=article["title"],
114
  summary=article["tldr"]["text"] if article["tldr"] else "",
115
  abstract=article["abstract"],
116
+ venue=article["venue"],
117
+ year=article["year"],
118
+ citations=article["citationCount"],
119
  )
120
  articles.append(article)
121
  return articles
 
145
 
146
 
147
  def main():
148
+ semantic_scholar_client = httpx.Client(transport=httpx.HTTPTransport(retries=3))
149
  gemini_client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
150
 
151
  st.title("Ask ~~Jeeves~~ Elna")
 
165
  citations = {}
166
  statements = generate_answer(gemini_client, query, papers)
167
  for statement in statements:
168
+ sentence = statement.text
169
+ for source_id in statement.source_ids or []:
170
+ if not (citation_id := citations.get(source_id)):
 
 
 
171
  citation_id = len(citations) + 1
172
+ citations[source_id] = citation_id
173
+ sentence += f"[^{citation_id}] "
174
+ sentences.append(sentence.strip())
 
175
  answer = " ".join(sentences)
176
  footnotes = ""
177
  if citations:
178
  footnotes = "\n".join(
179
+ f"[^{v}]: :grey-badge[:material/attribution: {paper_map[k].citations}] [{paper_map[k].title}](https://doi.org/{paper_map[k].id}). _{paper_map[k].venue}_ {paper_map[k].year}."
180
  for k, v in citations.items()
181
  )
182