File size: 3,650 Bytes
236b637
d8f78dc
236b637
d8f78dc
 
236b637
 
1e33a04
236b637
 
d8f78dc
 
 
 
 
 
236b637
 
d8f78dc
236b637
d8f78dc
236b637
 
 
d8f78dc
236b637
d8f78dc
236b637
 
d8f78dc
 
 
 
 
 
 
 
 
 
 
 
236b637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8765003
236b637
8765003
 
236b637
 
 
8765003
6f01060
6a4f05e
236b637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import logging
import streamlit as st
from huggingface_hub import hf_hub_download

# ✅ Nouveau moteur RAG (sans ollama_opts)
from rag_model_ollama_v1 import RAGEngine

# --- Configuration & logging ---
# Point NLTK at the app user's data directory unless the environment already
# provides NLTK_DATA.
os.environ.setdefault("NLTK_DATA", "/home/appuser/nltk_data")

_LOG_FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"

logger = logging.getLogger("Streamlit")
logger.setLevel(logging.INFO)

formatter = logging.Formatter(_LOG_FORMAT)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
# logging.getLogger returns the same logger object on every script run, so
# only attach the stream handler the first time to avoid duplicated output.
if not logger.handlers:
    logger.addHandler(handler)

st.set_page_config(page_icon="🤖", page_title="Chatbot RAG (Ollama)")

# --- Environment ---
# ENV == "local" reads the FAISS index and chunk store from the working
# directory; any other value downloads them from the Hugging Face Hub.
ENV = os.getenv("ENV", "local")
logger.info("ENV: %s", ENV)

# --- FAISS index & chunk paths ---
# The same relative paths serve as local paths and as file names inside the
# Hub dataset repository.
_VECTORDB_REPO_ID = "rkonan/chatbot-models"
_INDEX_FILE = "chatbot-models/vectordb_docling/index.faiss"
_CHUNKS_FILE = "chatbot-models/vectordb_docling/chunks.pkl"


@st.cache_data(show_spinner=False)
def _download_vectordb_file(filename: str) -> str:
    """Download *filename* from the Hub dataset repo and return its local path.

    Cached because Streamlit re-runs this whole script on every widget
    interaction; without the cache each rerun would contact the Hub again.
    """
    return hf_hub_download(
        repo_id=_VECTORDB_REPO_ID,
        filename=filename,
        repo_type="dataset",
    )


if ENV == "local":
    faiss_index_path = _INDEX_FILE
    vectors_path = _CHUNKS_FILE
else:
    faiss_index_path = _download_vectordb_file(_INDEX_FILE)
    vectors_path = _download_vectordb_file(_CHUNKS_FILE)

# --- Sidebar settings ---
# NOTE(review): Ollama's stock port is 11434; 11435 is presumably a deliberate
# local mapping — confirm. OLLAMA_HOST overrides it either way.
default_host = os.getenv("OLLAMA_HOST", "http://localhost:11435")

# Models offered in the selector; the first entry is the default choice.
suggested_models = [
    "qwen2.5:3b-instruct-q4_K_M",
    "noushermes_rag",
    "mistral",
    "gemma3",
    "deepseek-r1",
    "granite3.3",
    "llama3.1:8b-instruct-q4_K_M",
    "nous-hermes2:Q4_K_M",
]

with st.sidebar:
    st.header("⚙️ Paramètres")
    ollama_host = st.text_input("Ollama host", value=default_host)
    model_name = st.selectbox("Modèle Ollama", options=suggested_models, index=0)
    num_threads = st.slider(
        "Threads (hint)", min_value=2, max_value=16, value=6, step=1
    )
    temperature = st.slider(
        "Température", min_value=0.0, max_value=1.5, value=0.1, step=0.1
    )

st.title("🤖 Chatbot RAG Local (Ollama)")

# --- Engine cache ---
@st.cache_resource(show_spinner=True)
def _build_rag_engine(model_name: str, host: str, threads: int):
    """Construct a RAGEngine, cached once per (model_name, host, threads).

    st.cache_resource does not hash parameters whose names start with an
    underscore, so these names must stay un-prefixed. Otherwise a sidebar
    change would keep returning the first engine ever built.
    """
    os.environ["OLLAMA_KEEP_ALIVE"] = "15m"
    return RAGEngine(
        model_name=model_name,
        vector_path=vectors_path,
        index_path=faiss_index_path,
        model_threads=threads,
        ollama_host=host,
        # No ollama_opts: Ollama falls back to its own defaults.
    )


def load_rag_engine(_model_name: str, _host: str, _threads: int, _temp: float):
    """Return the (cached) RAG engine for the current sidebar settings.

    Args:
        _model_name: Ollama model tag to use.
        _host: Base URL of the Ollama server.
        _threads: Thread-count hint forwarded as ``model_threads``.
        _temp: Currently unused. The engine is built without ollama_opts, so
            the temperature is not forwarded. NOTE(review): wire it through
            once RAGEngine accepts a sampling temperature.

    Returns:
        The RAGEngine instance cached for (_model_name, _host, _threads).
    """
    return _build_rag_engine(_model_name, _host, _threads)

rag = load_rag_engine(model_name, ollama_host, num_threads, temperature)

# --- Simple chat ---
user_input = st.text_area(
    "Posez votre question :",
    height=120,
    placeholder="Ex: Quels sont les traitements appliqués aux images ?",
)
# Two equal-width columns; only the second currently hosts a button.
col1, col2 = st.columns([1, 1])

if col2.button("Envoyer (stream)"):
    if user_input.strip():
        with st.spinner("Génération en cours (stream)..."):
            try:
                placeholder = st.empty()
                answer = ""
                for token in rag.ask_stream(user_input):
                    answer += token
                    # Re-render the growing answer in place as tokens arrive.
                    placeholder.markdown(answer)
                st.balloons()
            except Exception as e:
                # UI boundary: keep the traceback in the server logs and show
                # a readable error to the user instead of crashing the page.
                logger.exception("Streaming generation failed")
                st.error(f"Erreur pendant la génération (stream): {e}")
    else:
        st.info("Saisissez une question.")