elna / main.py
David Chu
feat: mitigate citation hallucination
21d76a7 unverified
import logfire
import streamlit as st
from app import agent, models
# Telemetry setup, executed once at import time.
# NOTE(review): scrubbing is switched off, so whatever Logfire would normally
# redact is recorded as-is -- confirm this is intended for user-entered
# medical questions.
logfire.configure(
    send_to_logfire="if-token-present",
    scrubbing=False,
    service_name="streamlit",
)
# Hook Logfire into Pydantic AI so agent activity is captured
# (presumably the runs started from main(); verify against app.agent).
logfire.instrument_pydantic_ai()
def format_output(statements: list[models.Statement]) -> str:
    """Render agent statements as Markdown text with footnote citations.

    Every statement's text is followed by footnote markers (``[^n]``) for its
    sources. Footnote numbers are assigned in order of first appearance, and
    the same citation reuses its number wherever it occurs again. A citation
    listed more than once on a single statement produces only one marker.

    Args:
        statements: Objects exposing ``text`` and ``sources``; each source
            exposes ``citation``. A falsy ``sources`` means "no citations".

    Returns:
        The statement texts joined by single spaces, then a blank line, then
        one ``[^n]: <citation>`` footnote definition per line.
    """
    sentences = []
    # citation -> 1-based footnote number; insertion order is first-seen order.
    citations = {}
    for statement in statements:
        sentence = statement.text
        if sentence.startswith(("*", "-")):
            # Bullet points should be on a newline.
            sentence = f"\n{sentence}"
        if statement.sources:
            # A set, so a source repeated within one statement does not
            # produce the same marker twice.
            citation_ids = set()
            for source in statement.sources:
                # The default is computed before insertion, so a new citation
                # gets the next free number and a known one keeps its own.
                citation_ids.add(
                    citations.setdefault(source.citation, len(citations) + 1)
                )
            sentence += " ".join(f"[^{i}]" for i in sorted(citation_ids))
        sentences.append(sentence)
    answer = " ".join(sentences)
    footnotes = "\n".join(
        f"[^{number}]: {citation}" for citation, number in citations.items()
    )
    return f"{answer}\n\n{footnotes}"
def main():
    """Render the Elna page: a question form, then the agent's answer.

    When the form is submitted with a non-blank question, the agent is run
    synchronously, its statements are formatted with ``format_output``, and
    the page shows the agent's thoughts in an expander followed by the
    formatted answer.
    """
    st.title("Elna")
    with st.form("search", border=False):
        query = st.text_input("Your medical question")
        submit = st.form_submit_button("Ask")

    if not submit:
        return
    if not query.strip():
        # Avoid spending an agent run on an empty or whitespace-only question.
        st.warning("Please enter a question.")
        return

    with st.spinner("Thinking...", show_time=True):
        output = agent.agent.run_sync(query).output
        answer = format_output(output.statements)
    with st.expander("Thinking Process"):
        st.markdown(output.thoughts)
    st.markdown(answer)


if __name__ == "__main__":
    main()