# codingo/chatbot/chatbot.py
"""Chatbot module for Codingo …
Default model changed to blenderbot-400M-distill; generation uses max_new_tokens; fallback between causal and seq2seq models."""
import os
import shutil
from typing import List
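
# Redirect Hugging Face caches to a writable location (/tmp); useful when
# the app runs from a read-only filesystem such as a container image.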
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")
_hf_model = None
_hf_tokenizer = None
_chatbot_embedder = None
_chatbot_collection = None
_current_dir = os.path.dirname(os.path.abspath(__file__))
_knowledge_base_path = os.path.join(_current_dir, "chatbot.txt")
_chroma_db_dir = "/tmp/chroma_db"
DEFAULT_MODEL_NAME = "facebook/blenderbot-400M-distill"


def _init_hf_model() -> None:
    """Lazily load the chat model and tokenizer, caching them in globals."""
    global _hf_model, _hf_tokenizer
    if _hf_model is not None and _hf_tokenizer is not None:
        return
    # Defer the heavy imports until the first real initialisation.
    from transformers import (
        AutoModelForCausalLM,
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
    )
    import torch
model_name = os.getenv("HF_CHATBOT_MODEL", DEFAULT_MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
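    # Try a causal-LM head first, then fall back to a seq2seq head
    # (encoder-decoder models such as BlenderBot need the latter).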
try:
model = AutoModelForCausalLM.from_pretrained(model_name)
except Exception:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(device)
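    # Some causal tokenizers (GPT-2 style) ship without a pad token.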
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
_hf_model = model
_hf_tokenizer = tokenizer


def _init_vector_store() -> None:
    """Build (or reuse) the Chroma collection over the knowledge base."""
    global _chatbot_embedder, _chatbot_collection
    if _chatbot_embedder is not None and _chatbot_collection is not None:
        return
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
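    # Clear any stale index left at the legacy in-image path before
    # (re)building under /tmp.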
shutil.rmtree("/app/chatbot/chroma_db", ignore_errors=True)
os.makedirs(_chroma_db_dir, exist_ok=True)
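    # Load the knowledge base; fall back to a short built-in blurb if the
    # file is missing.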
try:
with open(_knowledge_base_path, encoding="utf-8") as f:
raw_text = f.read()
except FileNotFoundError:
raw_text = (
"Codingo is an AI-powered recruitment platform designed to "
"streamline job applications, candidate screening, and hiring. "
"We make hiring smarter, faster, and fairer through automation "
"and intelligent recommendations."
)
splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
docs: List[str] = [doc.strip() for doc in splitter.split_text(raw_text) if doc.strip()]
embedder = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedder.encode(docs, show_progress_bar=False, batch_size=32)
client = chromadb.Client(Settings(
persist_directory=_chroma_db_dir,
anonymized_telemetry=False,
is_persistent=True,
))
collection = client.get_or_create_collection("chatbot")
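    # Seed the collection only if it is empty; the probe below raises when
    # no documents are stored yet, which routes us into the except branch.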
try:
existing = collection.get(limit=1)
if not existing.get("documents"):
raise ValueError("Empty Chroma DB")
except Exception:
ids = [f"doc_{i}" for i in range(len(docs))]
collection.add(documents=docs, embeddings=embeddings.tolist(), ids=ids)
_chatbot_embedder = embedder
_chatbot_collection = collection


def get_chatbot_response(query: str) -> str:
    """Answer ``query`` using retrieved context and the local HF model."""
    if not query or not query.strip():
        return "Please type a question about the Codingo platform."
_init_vector_store()
_init_hf_model()
embedder = _chatbot_embedder
collection = _chatbot_collection
model = _hf_model
tokenizer = _hf_tokenizer
import torch
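    # Embed the query and retrieve the three most similar chunks for context.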
query_embedding = embedder.encode([query])[0]
results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=3)
retrieved_docs = results.get("documents", [[]])[0] if results else []
context = "\n".join(retrieved_docs[:3])
system_instruction = (
"You are LUNA AI, a helpful assistant for the Codingo recruitment "
"platform. Use the provided context to answer questions about "
"Codingo. If the question is not related to Codingo, politely "
"redirect the conversation. Keep responses concise and friendly."
)
prompt = f"{system_instruction}\n\nContext:\n{context}\n\nUser: {query}\nLUNA AI:"
    # Clamp to the model's context window (BlenderBot-400M caps at 128).
    max_len = min(512, getattr(tokenizer, "model_max_length", 512) or 512)
    inputs = tokenizer.encode(
        prompt, return_tensors="pt", truncation=True, max_length=max_len, padding=True
    ).to(model.device)
with torch.no_grad():
output_ids = model.generate(
inputs,
max_new_tokens=150,
num_beams=3,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
early_stopping=True,
)
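    # Causal models echo the prompt in their output, so strip everything up
    # to the assistant marker; seq2seq models return only the reply.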
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
if "LUNA AI:" in response:
response = response.split("LUNA AI:")[-1].strip()
elif prompt in response:
response = response.replace(prompt, "").strip()
return (
response
if response
else "I'm here to help you with questions about the Codingo platform. What would you like to know?"
)
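

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming the model weights and the
    # sentence-transformers embedder can be downloaded in this environment.
    print(get_chatbot_response("What is Codingo?"))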