File size: 1,711 Bytes
3e78ada
 
 
00d1644
3e78ada
5bb1986
8a3374d
5bb1986
 
 
8a3374d
ab2c1b8
 
 
 
 
 
 
5bb1986
 
 
 
 
ab2c1b8
 
 
5bb1986
 
8a3374d
5bb1986
 
 
3e78ada
 
 
00d1644
3e78ada
ab2c1b8
3e78ada
 
 
8d10d04
3e78ada
5543da4
5bb1986
00d1644
8a3374d
 
5bb1986
3e78ada
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import logging

import streamlit as st
from google import genai

from app import agent, config


def format_output(response: list[dict]) -> tuple[str, str]:
    """Render an agent response as Markdown text plus Markdown footnotes.

    Each element of ``response`` is a statement dict with a ``"text"`` string
    and an optional ``"sources"`` list; each source is a dict with ``"title"``
    and ``"url"``.

    Statements whose text starts with ``*`` or ``-`` begin on a new line so
    they render as list items. All other statements are appended to the
    running paragraph after a space.

    Each distinct source gets a footnote number in order of first appearance,
    shared across statements. The numbers cited by a statement are appended
    to it as ``[^n]`` markers, sorted and without duplicates.

    Args:
        response: The list of statement dicts produced by the agent.

    Returns:
        ``(answer, footnotes)``, where ``footnotes`` holds one
        ``[^n]: [title](url)`` line per distinct source. If a statement or
        source is missing a required key, the error is logged and
        ``(str(response), "")`` is returned so the raw output is still shown.
    """
    parts: list[str] = []
    citations: dict[str, int] = {}  # rendered source link -> footnote number
    try:
        for statement in response:
            text = statement["text"].strip()
            # Bullet items must start on their own line to render as a list.
            separator = "\n" if text.startswith(("*", "-")) else " "
            parts.append(f"{separator}{text}")

            # A set so a source cited twice in one statement is marked once.
            cited: set[int] = set()
            for source in statement.get("sources", []):
                source_str = f"[{source['title']}]({source['url']})"
                # New sources get the next number; known ones keep theirs.
                cited.add(citations.setdefault(source_str, len(citations) + 1))
            if cited:
                parts.append(" ".join(f"[^{num}]" for num in sorted(cited)))
    except KeyError:
        logging.getLogger(__name__).exception(
            "Agent response is missing an expected key"
        )
        return str(response), ""

    footnotes = "\n".join(
        f"[^{num}]: {citation}" for citation, num in citations.items()
    )
    return "".join(parts), footnotes


def main():
    """Render the Elna question-answering page.

    Shows a form with a question box and an "Ask" button. On submit, the
    question is sent to ``agent.respond`` with a Gemini client built from
    ``config.settings.google_api_key``. The result is formatted by
    ``format_output`` and rendered as Markdown (answer, blank line,
    footnotes) in a placeholder below the button.
    """
    gemini_client = genai.Client(api_key=config.settings.google_api_key)

    st.title("Elna")
    with st.form("search", border=False):
        query = st.text_input("Your medical question")
        submit = st.form_submit_button("Ask")
        # Placeholder that receives either the answer or a validation message.
        response = st.empty()

        if submit:
            # A blank question would still trigger a full agent/API call.
            if not (query and query.strip()):
                response.warning("Please enter a question.")
                return

            with st.spinner("Thinking...", show_time=True):
                output = agent.respond(gemini_client, query)

            answer, footnotes = format_output(output)
            response.markdown(f"{answer}\n\n{footnotes}")


# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()