# codingo/chatbot/chatbot.py
"""Chatbot module for Codingo …

The default model is facebook/blenderbot-400M-distill; generation uses
max_new_tokens, and model loading falls back between causal and seq2seq classes.
"""
import os
import shutil
from typing import List

# Point all Hugging Face caches at /tmp before transformers is imported.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")

# Lazily initialised singletons: the generation model/tokenizer and the retrieval store.
_hf_model = None
_hf_tokenizer = None
_chatbot_embedder = None
_chatbot_collection = None

_current_dir = os.path.dirname(os.path.abspath(__file__))
_knowledge_base_path = os.path.join(_current_dir, "chatbot.txt")
_chroma_db_dir = "/tmp/chroma_db"

DEFAULT_MODEL_NAME = "facebook/blenderbot-400M-distill"
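
# Illustrative only (an assumption about deployment, not exercised by the app itself):
# the model can be swapped without code changes by setting HF_CHATBOT_MODEL before
# the first request, e.g.
#
#   os.environ["HF_CHATBOT_MODEL"] = "facebook/blenderbot_small-90M"
#   print(get_chatbot_response("What is Codingo?"))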


def _init_hf_model() -> None:
    """Load the generation model and tokenizer once, caching them in module globals."""
    from transformers import (
        AutoModelForCausalLM,
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
    )
    import torch

    global _hf_model, _hf_tokenizer
    if _hf_model is not None and _hf_tokenizer is not None:
        return

    model_name = os.getenv("HF_CHATBOT_MODEL", DEFAULT_MODEL_NAME)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Prefer a causal LM head; fall back to a seq2seq head (e.g. BlenderBot) if that fails.
    try:
        model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model = model.to(device)

    # Some causal models (e.g. GPT-2 variants) ship without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    _hf_model = model
    _hf_tokenizer = tokenizer


def _init_vector_store() -> None:
    """Build (or reuse) the Chroma collection backing retrieval over the knowledge base."""
    global _chatbot_embedder, _chatbot_collection
    if _chatbot_embedder is not None and _chatbot_collection is not None:
        return

    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from sentence_transformers import SentenceTransformer
    import chromadb
    from chromadb.config import Settings

    # Remove any leftover Chroma store at the old /app path, then ensure the /tmp
    # store directory exists.
    shutil.rmtree("/app/chatbot/chroma_db", ignore_errors=True)
    os.makedirs(_chroma_db_dir, exist_ok=True)

    try:
        with open(_knowledge_base_path, encoding="utf-8") as f:
            raw_text = f.read()
    except FileNotFoundError:
        # Minimal built-in fallback so the bot can still answer if chatbot.txt is missing.
        raw_text = (
            "Codingo is an AI-powered recruitment platform designed to "
            "streamline job applications, candidate screening, and hiring. "
            "We make hiring smarter, faster, and fairer through automation "
            "and intelligent recommendations."
        )

    # Chunk the knowledge base and embed each chunk for similarity search.
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
    docs: List[str] = [doc.strip() for doc in splitter.split_text(raw_text) if doc.strip()]

    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedder.encode(docs, show_progress_bar=False, batch_size=32)

    client = chromadb.Client(Settings(
        persist_directory=_chroma_db_dir,
        anonymized_telemetry=False,
        is_persistent=True,
    ))
    collection = client.get_or_create_collection("chatbot")

    # Only (re)populate the collection when it is empty or unreadable.
    try:
        existing = collection.get(limit=1)
        if not existing.get("documents"):
            raise ValueError("Empty Chroma DB")
    except Exception:
        ids = [f"doc_{i}" for i in range(len(docs))]
        collection.add(documents=docs, embeddings=embeddings.tolist(), ids=ids)

    _chatbot_embedder = embedder
    _chatbot_collection = collection


def get_chatbot_response(query: str) -> str:
    """Answer a user query using retrieved knowledge-base context and the HF model."""
    if not query or not query.strip():
        return "Please type a question about the Codingo platform."

    _init_vector_store()
    _init_hf_model()
    embedder = _chatbot_embedder
    collection = _chatbot_collection
    model = _hf_model
    tokenizer = _hf_tokenizer

    import torch

    # Retrieve the three most similar knowledge-base chunks for the query.
    query_embedding = embedder.encode([query])[0]
    results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=3)
    retrieved_docs = results.get("documents", [[]])[0] if results else []
    context = "\n".join(retrieved_docs[:3])

    system_instruction = (
        "You are LUNA AI, a helpful assistant for the Codingo recruitment "
        "platform. Use the provided context to answer questions about "
        "Codingo. If the question is not related to Codingo, politely "
        "redirect the conversation. Keep responses concise and friendly."
    )
    prompt = f"{system_instruction}\n\nContext:\n{context}\n\nUser: {query}\nLUNA AI:"
    # Cap the prompt at the tokenizer's own limit (e.g. 128 tokens for BlenderBot) when
    # it is stricter than 512, so the encoder never sees more positions than it supports.
    max_len = min(512, getattr(tokenizer, "model_max_length", 512) or 512)
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_len
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            num_beams=3,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=True,
        )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Causal models echo the prompt in their output; keep only the text after the
    # assistant tag (seq2seq models return just the reply, so this is a no-op there).
    if "LUNA AI:" in response:
        response = response.split("LUNA AI:")[-1].strip()
    elif prompt in response:
        response = response.replace(prompt, "").strip()

    return (
        response
        if response
        else "I'm here to help you with questions about the Codingo platform. What would you like to know?"
    )
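

# Minimal manual smoke test (a sketch, assuming the module is run directly; the web
# app only imports get_chatbot_response). The questions are illustrative, and the
# first run downloads the embedding and chat models.
if __name__ == "__main__":
    for question in ("What is Codingo?", "How does candidate screening work?"):
        print(f"User: {question}")
        print(f"LUNA AI: {get_chatbot_response(question)}")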