import os
import logging

import streamlit as st
from huggingface_hub import hf_hub_download

# ✅ New RAG engine (no ollama_opts)
from rag_model_ollama_v1 import RAGEngine

# --- Config & logs ---
os.environ.setdefault("NLTK_DATA", "/home/appuser/nltk_data")
logger = logging.getLogger("Streamlit")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler.setFormatter(formatter)
if not logger.handlers:
    logger.addHandler(handler)

st.set_page_config(page_title="RAG Chatbot (Ollama)", page_icon="🤖")

# --- ENV ---
ENV = os.getenv("ENV", "local")
logger.info(f"ENV: {ENV}")

# --- FAISS index & chunks paths ---
if ENV == "local":
    faiss_index_path = "chatbot-models/vectordb_docling/index.faiss"
    vectors_path = "chatbot-models/vectordb_docling/chunks.pkl"
else:
    faiss_index_path = hf_hub_download(
        repo_id="rkonan/chatbot-models",
        filename="chatbot-models/vectordb_docling/index.faiss",
        repo_type="dataset",
    )
    vectors_path = hf_hub_download(
        repo_id="rkonan/chatbot-models",
        filename="chatbot-models/vectordb_docling/chunks.pkl",
        repo_type="dataset",
    )
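
# Illustrative sketch (assumption, not part of the app flow): RAGEngine is expected
# to load these two artifacts itself. Loading them directly would look roughly like
# this (faiss-cpu + pickle); the helper below is never called by the app.
def _load_vector_store_sketch(index_path: str, chunks_path: str):
    import pickle
    import faiss
    index = faiss.read_index(index_path)   # dense FAISS index
    with open(chunks_path, "rb") as f:
        chunks = pickle.load(f)            # text chunks aligned with the index rows
    return index, chunks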

# --- Sidebar UI ---
st.sidebar.header("⚙️ Settings")

default_host = os.getenv("OLLAMA_HOST", "http://localhost:11435")
ollama_host = st.sidebar.text_input("Ollama host", value=default_host)

suggested_models = [
    "qwen2.5:3b-instruct-q4_K_M",
    "noushermes_rag",
    "mistral",
    "gemma3",
    "deepseek-r1",
    "granite3.3",
    "llama3.1:8b-instruct-q4_K_M",
    "nous-hermes2:Q4_K_M",
]
model_name = st.sidebar.selectbox("Ollama model", options=suggested_models, index=0)

num_threads = st.sidebar.slider("Threads (hint)", min_value=2, max_value=16, value=6, step=1)
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.5, value=0.1, step=0.1)

st.title("🤖 Local RAG Chatbot (Ollama)")

# --- Engine cache ---
def load_rag_engine(_model_name: str, _host: str, _threads: int, _temp: float):
    os.environ["OLLAMA_KEEP_ALIVE"] = "15m"
    rag = RAGEngine(
        model_name=_model_name,
        vector_path=vectors_path,
        index_path=faiss_index_path,
        model_threads=_threads,
        ollama_host=_host,
        # ❌ no ollama_opts → Ollama picks its own defaults (_temp is not forwarded)
    )
    return rag

rag = load_rag_engine(model_name, ollama_host, num_threads, temperature)
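
# Minimal caching sketch (assumption; not wired in above): Streamlit reruns this
# script on every interaction, so the engine is currently rebuilt each time.
# Wrapping the loader with st.cache_resource would reuse one engine per argument
# combination across reruns. The hypothetical wrapper below is defined but unused.
@st.cache_resource
def get_cached_rag_engine(model: str, host: str, threads: int, temp: float):
    return load_rag_engine(model, host, threads, temp)

# Example usage, replacing the direct call above:
# rag = get_cached_rag_engine(model_name, ollama_host, num_threads, temperature)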

# --- Simple chat ---
user_input = st.text_area(
    "Ask your question:",
    height=120,
    placeholder="E.g.: What processing is applied to the images?",
)

col1, col2 = st.columns([1, 1])

# if col1.button("Send"):
#     if user_input.strip():
#         with st.spinner("Generating..."):
#             try:
#                 response = rag.ask(user_input)
#                 st.markdown("**Answer:**")
#                 st.success(response)
#             except Exception as e:
#                 st.error(f"Error during generation: {e}")
#     else:
#         st.info("Enter a question.")

if col2.button("Send (stream)"):
    if user_input.strip():
        with st.spinner("Generating (stream)..."):
            try:
                ph = st.empty()
                acc = ""
                for token in rag.ask_stream(user_input):
                    acc += token
                    ph.markdown(acc)
                st.balloons()
            except Exception as e:
                st.error(f"Error during generation (stream): {e}")
    else:
        st.info("Enter a question.")
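
# Illustrative sketch (assumption, not the actual RAGEngine implementation):
# ask_stream() is expected to yield answer fragments incrementally. With the
# official `ollama` Python client, a streaming generation looks roughly like this:
def _ollama_stream_sketch(prompt: str, model: str, host: str):
    import ollama
    client = ollama.Client(host=host)
    stream = client.chat(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    for chunk in stream:
        # each chunk carries the next piece of the assistant's reply
        yield chunk["message"]["content"]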