File size: 3,135 Bytes
f4cbf67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import html

import streamlit as st

from src.file_loader import load_file
from src.rag_pipeline import build_rag_pipeline, get_relevant_docs
from src.model_utils import load_hf_model, generate_answer
from src.utils import get_font_css

# Page-wide settings: browser tab title, tab icon, and wide layout.
st.set_page_config(
    page_title="AI Chatbot",
    page_icon=":robot_face:",
    layout="wide",
)

# Inject the project's font CSS (whatever get_font_css() returns) as raw HTML.
font_css = get_font_css()
st.markdown(font_css, unsafe_allow_html=True)

# Sidebar: branding, the document uploader, and the two model-name inputs.
# The names uploaded_file, model_name and embedding_model are read later in
# the script, so they stay module-level.
with st.sidebar:
    st.image("assets/logo.png", width=180)
    st.title("AI Chatbot")
    st.markdown("Upload a file to get started:")

    # Only these three extensions are accepted by the uploader.
    uploaded_file = st.file_uploader(
        "Upload PDF, CSV, or XLSX",
        type=["pdf", "csv", "xlsx"],
    )

    # Free-text model identifiers, pre-filled with default values.
    model_name = st.text_input(
        "HuggingFace Model (text-generation)",
        value="amiguel/GM_Qwen1.8B_Finetune",
    )
    embedding_model = st.text_input(
        "Embedding Model",
        value="sentence-transformers/all-MiniLM-L6-v2",
    )

    st.markdown("---")
    st.markdown("Powered by [Your Company]")

# Main-page banner: logo and title side by side in a flex row, written as
# raw HTML.
# NOTE(review): the <img> URL "app/assets/logo.png" is resolved by the
# browser rather than read from disk here; confirm that path is actually
# served by the deployment.
page_banner = """
    <div style="display: flex; align-items: center; margin-bottom: 1rem;">
        <img src="app/assets/logo.png" width="60" style="margin-right: 1rem;">
        <h1 style="font-family: 'Tw Cen MT', sans-serif; margin: 0;">AI Chatbot</h1>
    </div>
    """
st.markdown(page_banner, unsafe_allow_html=True)

# Main flow: once a file is uploaded, index it, load the generation model,
# accept a question, and render the running conversation. Without a file,
# only a prompt to upload one is shown.
if uploaded_file:
    # NOTE(review): unless load_file / build_rag_pipeline / load_hf_model
    # cache internally, indexing and model loading repeat every time this
    # script body runs. Confirm, and add caching if they do not.
    with st.spinner("Processing file..."):
        text = load_file(uploaded_file)
        # Each item yielded by iterating `text` becomes one document.
        # presumably load_file returns a sequence of text chunks; verify.
        docs = [{"page_content": chunk, "metadata": {}} for chunk in text]
        retriever = build_rag_pipeline(docs, embedding_model)
        st.success("File processed and indexed!")

    with st.spinner("Loading model..."):
        text_gen = load_hf_model(model_name)
        st.success("Model loaded!")

    # The conversation is kept in session state as (sender, message) tuples.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    user_input = st.text_input("Ask a question about your document:", key="user_input")
    # Answer only when the button was clicked and the question is non-empty.
    if st.button("Send", use_container_width=True) and user_input:
        with st.spinner("Generating answer..."):
            context_docs = get_relevant_docs(retriever, user_input)
            # Retrieved passages are joined into one context string.
            context = " ".join([doc["page_content"] for doc in context_docs])
            answer = generate_answer(text_gen, user_input, context)
            st.session_state.chat_history.append(("user", user_input))
            st.session_state.chat_history.append(("bot", answer))

    for sender, msg in st.session_state.chat_history:
        # The bubbles are rendered with unsafe_allow_html=True, so both the
        # user's question and the model's answer are untrusted text here.
        # Escape them so "<", ">", "&" and quotes display literally instead
        # of being interpreted as markup. str() keeps the original
        # f-string's tolerance for non-string messages.
        safe_msg = html.escape(str(msg))
        if sender == "user":
            st.markdown(
                f"""
                <div style="background: #e6f0fa; border-radius: 10px; padding: 10px; margin-bottom: 5px; text-align: right; font-family: 'Tw Cen MT', sans-serif;">
                    <b>You:</b> {safe_msg}
                </div>
                """,
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                f"""
                <div style="background: #f4f4f4; border-radius: 10px; padding: 10px; margin-bottom: 10px; text-align: left; font-family: 'Tw Cen MT', sans-serif;">
                    <b>AI:</b> {safe_msg}
                </div>
                """,
                unsafe_allow_html=True,
            )
else:
    st.info("Please upload a PDF, CSV, or XLSX file to begin.")