File size: 1,666 Bytes
7be08b4
3e78ada
 
7be08b4
 
 
 
 
 
d9a5339
 
21d76a7
f045eec
d9a5339
 
21d76a7
f045eec
 
 
d9a5339
f045eec
d9a5339
 
ab2c1b8
d9a5339
 
5bb1986
d9a5339
ab2c1b8
f045eec
 
 
5bb1986
f045eec
5bb1986
21d76a7
3e78ada
 
 
ab2c1b8
3e78ada
 
 
 
5543da4
5bb1986
21d76a7
 
b868906
21d76a7
b868906
3e78ada
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logfire
import streamlit as st

from app import agent, models

# Configure Logfire at import time, reporting under the service name "streamlit".
# NOTE(review): scrubbing is turned off here, so logged data is presumably not
# redacted -- confirm this is intended for medical queries.
# NOTE(review): "if-token-present" presumably means data is only exported when a
# Logfire token is configured -- confirm against the Logfire documentation.
logfire.configure(
    service_name="streamlit", scrubbing=False, send_to_logfire="if-token-present"
)
# Turn on Logfire's Pydantic AI instrumentation.
logfire.instrument_pydantic_ai()


def format_output(statements: list[models.Statement]) -> str:
    """Render statements as Markdown text with numbered footnote citations.

    Each statement's text is followed by one ``[^n]`` marker per distinct
    source it cites. Footnote numbers are assigned in the order citations are
    first encountered, and the same citation always maps to the same number.
    Statements are joined with single spaces; statements that start with a
    bullet character (``*`` or ``-``) are placed on a new line.

    Args:
        statements: Objects exposing ``text`` and ``sources``; every source
            exposes a hashable ``citation`` used as the footnote body.

    Returns:
        The joined answer text. When at least one citation was used, a blank
        line and the ``[^n]: citation`` footnote definitions follow it;
        otherwise only the answer text is returned (an empty string for no
        statements).
    """
    sentences = []
    # Maps citation -> footnote number; insertion order is the numbering order.
    citations = {}

    for statement in statements:
        sentence = statement.text

        if sentence.startswith(("*", "-")):
            # Bullet points should be on a newline.
            sentence = f"\n{sentence}"

        if statement.sources:
            # A set, so a source cited twice by one statement gets one marker.
            citation_ids = set()
            for source in statement.sources:
                # setdefault keeps an existing number or registers the next one.
                citation_ids.add(
                    citations.setdefault(source.citation, len(citations) + 1)
                )
            sentence += " ".join(f"[^{i}]" for i in sorted(citation_ids))

        sentences.append(sentence)

    answer = " ".join(sentences)
    if not citations:
        # Nothing to footnote: avoid a dangling blank footnote section.
        return answer

    footnotes = "\n".join(
        f"[^{number}]: {citation}" for citation, number in citations.items()
    )
    return f"{answer}\n\n{footnotes}"


def main():
    """Render the "Elna" page: a question form plus the agent's cited answer.

    Shows a text input and an "Ask" button inside a borderless form. Once the
    form is submitted, the query is passed to the agent synchronously while a
    spinner is displayed, the returned statements are formatted into Markdown,
    and the page shows the agent's thoughts in an expander followed by the
    formatted answer.
    """
    st.title("Elna")
    with st.form("search", border=False):
        query = st.text_input("Your medical question")

        # Nothing to do until the user presses the button.
        if not st.form_submit_button("Ask"):
            return

        with st.spinner("Thinking...", show_time=True):
            result = agent.agent.run_sync(query).output
            answer = format_output(result.statements)

        with st.expander("Thinking Process"):
            st.markdown(result.thoughts)

        st.markdown(answer)


# Build the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()