import os
import logging

import streamlit as st
from huggingface_hub import hf_hub_download

# ✅ New RAG engine (Ollama)
from rag_model_ollama_v1 import RAGEngine

# --- Config & logs ---
os.environ.setdefault("NLTK_DATA", "/home/appuser/nltk_data")

logger = logging.getLogger("Streamlit")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler.setFormatter(formatter)
if not logger.handlers:
    logger.addHandler(handler)

st.set_page_config(page_title="Chatbot RAG (Ollama)", page_icon="🤖")

# --- ENV ---
ENV = os.getenv("ENV", "local")  # "local" or "space"
logger.info(f"ENV: {ENV}")

# --- FAISS & chunks paths ---
if ENV == "local":
    # Adjust these paths to your local filesystem
    faiss_index_path = "chatbot-models/vectordb_docling/index.faiss"
    vectors_path = "chatbot-models/vectordb_docling/chunks.pkl"
else:
    # Download from Hugging Face (private or public dataset, depending on your settings)
    faiss_index_path = hf_hub_download(
        repo_id="rkonan/chatbot-models",
        filename="chatbot-models/vectordb_docling/index.faiss",
        repo_type="dataset"
    )
    vectors_path = hf_hub_download(
        repo_id="rkonan/chatbot-models",
        filename="chatbot-models/vectordb_docling/chunks.pkl",
        repo_type="dataset"
    )

# --- UI Sidebar ---
st.sidebar.header("⚙️ Paramètres")
default_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
ollama_host = st.sidebar.text_input("Ollama host", value=default_host, help="Ex: http://localhost:11434")

# Suggest models that are already installed or commonly used
suggested_models = [
    "qwen2.5:3b-instruct-q4_K_M",
    "noushermes_rag",
    "mistral",        # already installed locally
    "gemma3",         # already installed locally
    "deepseek-r1",    # already installed locally (long-form reasoning, slower)
    "granite3.3",     # already installed locally
    "llama3.1:8b-instruct-q4_K_M",
    "nous-hermes2:Q4_K_M",
]
model_name = st.sidebar.selectbox("Modèle Ollama", options=suggested_models, index=0)
num_threads = st.sidebar.slider("Threads (hint)", min_value=2, max_value=16, value=6, step=1)
temperature = st.sidebar.slider("Température", min_value=0.0, max_value=1.5, value=0.1, step=0.1)

st.title("🤖 Chatbot RAG Local (Ollama)")

# --- Engine cache ---
# def load_rag_engine(_model_name: str, _host: str, _threads: int, _temp: float, _version: int = 1):
#     # Options for Ollama
#     ollama_opts = {
#         "num_thread": int(_threads),
#         "temperature": float(_temp),
#         "num_ctx": 256,
#         "num_batch": 16,
#     }
#     rag = RAGEngine(
#         model_name=_model_name,
#         vector_path=vectors_path,
#         index_path=faiss_index_path,
#         model_threads=_threads,
#         ollama_host=_host,
#         ollama_opts=ollama_opts
#     )
#     # Light warmup (avoids first-token latency)
#     try:
#         gen = rag._complete_stream("Bonjour", max_tokens=1)
#         next(gen, "")
#     except Exception as e:
#         logger.warning(f"Warmup Ollama échoué: {e}")
#     return rag

def load_rag_engine(_model_name: str, _host: str, _threads: int, _temp: float, _version: int = 1):
    # Set KEEP_ALIVE so the model stays in memory after use
    os.environ["OLLAMA_KEEP_ALIVE"] = "15m"
    ollama_opts = {
        "num_thread": int(_threads),   # CPU thread hint passed to Ollama
        "temperature": float(_temp),
        "num_ctx": 512,                # context window, same as the CLI
        "num_batch": 16,               # prompt-processing batch size
    }
    rag = RAGEngine(
        model_name=_model_name,
        vector_path=vectors_path,
        index_path=faiss_index_path,
        model_threads=_threads,
        ollama_host=_host,
        ollama_opts=ollama_opts
    )
    # Warmup close to the CLI (more than one token, to fill the cache)
    try:
        list(rag._complete_stream("Bonjour", max_tokens=8))
    except Exception as e:
        logger.warning(f"Warmup Ollama échoué: {e}")
    return rag
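
# Note: load_rag_engine runs on every Streamlit rerun, so the engine is rebuilt
# whenever a widget changes. The underscore-prefixed parameters and the _version
# argument follow the st.cache_resource convention (Streamlit excludes "_"-prefixed
# parameters from cache hashing). If the function is wrapped in @st.cache_resource,
# parameters that should invalidate the cache (e.g. the model name) must lose the
# leading underscore, otherwise a stale engine is returned.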

rag = load_rag_engine(model_name, ollama_host, num_threads, temperature, _version=2)

# --- Simple chat ---
user_input = st.text_area("Posez votre question :", height=120, placeholder="Ex: Quels sont les traitements appliqués aux images ?")
col1, col2 = st.columns([1, 1])

if col1.button("Envoyer"):
    if user_input.strip():
        with st.spinner("Génération en cours..."):
            try:
                response = rag.ask(user_input)
                st.markdown("**Réponse :**")
                st.success(response)
            except Exception as e:
                st.error(f"Erreur pendant la génération: {e}")
    else:
        st.info("Saisissez une question.")

if col2.button("Envoyer (stream)"):
    if user_input.strip():
        with st.spinner("Génération en cours (stream)..."):
            try:
                # Token-by-token display
                ph = st.empty()
                acc = ""
                for token in rag.ask_stream(user_input):
                    acc += token
                    ph.markdown(acc)
                st.balloons()
            except Exception as e:
                st.error(f"Erreur pendant la génération (stream): {e}")
    else:
        st.info("Saisissez une question.")
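
# To run locally (a sketch, assuming this file is saved as app.py and an Ollama
# server is already reachable at the host configured in the sidebar):
#   ollama pull qwen2.5:3b-instruct-q4_K_M   # or whichever model is selected above
#   streamlit run app.py
# On a Hugging Face Space, set ENV=space so the FAISS index and chunks are pulled
# from the rkonan/chatbot-models dataset instead of the local filesystem.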