# codingo/chatbot/chatbot.py
"""Interactive chatbot using Flan-T5 for dynamic responses"""
import os
import shutil
from typing import List
import torch
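# Redirect Hugging Face cache directories to /tmp so model downloads land on a
# writable path (helpful on read-only deployments such as Hugging Face Spaces).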
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")
_model = None
_tokenizer = None
_chatbot_embedder = None
_chatbot_collection = None
_current_dir = os.path.dirname(os.path.abspath(__file__))
_knowledge_base_path = os.path.join(_current_dir, "chatbot.txt")
_chroma_db_dir = "/tmp/chroma_db"
# Using Flan-T5 - it's small, fast, and great for Q&A
MODEL_NAME = "google/flan-t5-small"
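# Larger checkpoints (e.g. "google/flan-t5-base") use the same tokenizer/model
# classes and should work as drop-in replacements, at higher memory and latency cost.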
def _init_model():
    """Lazily load the Flan-T5 tokenizer and model (once per process)."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return
    print("Loading Flan-T5 model...")
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # fp16 halves memory on GPU; fall back to fp32 on CPU
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)
    model.eval()
    _model = model
    _tokenizer = tokenizer
    print("Model loaded successfully!")
def _init_vector_store():
    """Lazily build the in-memory Chroma collection over the knowledge base."""
    global _chatbot_embedder, _chatbot_collection
    if _chatbot_embedder is not None and _chatbot_collection is not None:
        return
    print("Initializing vector store...")
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from sentence_transformers import SentenceTransformer
    import chromadb
    from chromadb.config import Settings
    # Clean and create directory
    shutil.rmtree(_chroma_db_dir, ignore_errors=True)
    os.makedirs(_chroma_db_dir, exist_ok=True)
    # Load knowledge base
    try:
        with open(_knowledge_base_path, encoding="utf-8") as f:
            raw_text = f.read()
        print(f"Loaded knowledge base: {len(raw_text)} characters")
    except FileNotFoundError:
        print("Knowledge base not found!")
        raw_text = "Codingo is an AI recruitment platform."
    # Split into chunks
    splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    docs = [doc.strip() for doc in splitter.split_text(raw_text) if doc.strip()]
    print(f"Created {len(docs)} document chunks")
    # Create embeddings
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedder.encode(docs, show_progress_bar=False)
    # Create ChromaDB collection (in-memory, so nothing is persisted to disk)
    client = chromadb.Client(Settings(anonymized_telemetry=False, is_persistent=False))
    try:
        client.delete_collection("chatbot")
    except Exception:
        # Collection does not exist yet on first run
        pass
    collection = client.create_collection("chatbot")
    ids = [f"doc_{i}" for i in range(len(docs))]
    collection.add(documents=docs, embeddings=embeddings.tolist(), ids=ids)
    _chatbot_embedder = embedder
    _chatbot_collection = collection
    print("Vector store ready!")
def get_chatbot_response(query: str) -> str:
    """Answer a question about Codingo by retrieving relevant knowledge-base
    chunks and prompting Flan-T5 with them (simple retrieval-augmented generation)."""
    try:
        if not query or not query.strip():
            return "Hi! I'm LUNA AI. Ask me anything about Codingo!"
        print(f"\nProcessing: '{query}'")
        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Initialize
        _init_vector_store()
        _init_model()
        # Search for relevant context
        query_embedding = _chatbot_embedder.encode([query])[0]
        results = _chatbot_collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=3,
        )
        retrieved_docs = results.get("documents", [[]])[0] if results else []
        print(f"Found {len(retrieved_docs)} relevant chunks")
        # Combine the most relevant information
        context = " ".join(retrieved_docs[:2]) if retrieved_docs else "Codingo is an AI recruitment platform."
        # Create a prompt for Flan-T5
        prompt = f"""Answer the question based on the context about Codingo.
Context: {context}
Question: {query}
Answer:"""
        # Tokenize
        inputs = _tokenizer(
            prompt,
            max_length=512,
            truncation=True,
            return_tensors="pt",
        ).to(_model.device)
        # Generate response
        with torch.no_grad():
            outputs = _model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=4,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
            )
        # Decode response
        response = _tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated: '{response}'")
        # Make sure we have a good response
        if not response or len(response) < 5:
            # Fallback: try a simpler prompt (do_sample=True so temperature takes effect)
            simple_prompt = f"Question about Codingo: {query}\nAnswer:"
            inputs = _tokenizer(simple_prompt, max_length=256, truncation=True, return_tensors="pt").to(_model.device)
            with torch.no_grad():
                outputs = _model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.8)
            response = _tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Clean up the response
        response = response.strip()
        # If still too short, provide a helpful response
        if len(response) < 10:
            if "hello" in query.lower() or "hi" in query.lower():
                return "Hello! I'm LUNA AI, your Codingo assistant. I can help you with questions about our AI recruitment platform, job matching, CV tips, and more!"
            else:
                return f"I can help you with that! Based on what I know about Codingo: {retrieved_docs[0][:200] if retrieved_docs else 'Codingo is an AI-powered recruitment platform that helps match candidates with jobs.'}"
        return response
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return "I'm having a technical issue. Please try asking your question again!"
# Test function
if __name__ == "__main__":
    # Test the chatbot
    test_queries = [
        "What is Codingo?",
        "How does it work?",
        "What makes Codingo special?",
        "How can I improve my profile?",
        "Is it free?",
    ]
    print("Testing chatbot...")
    for q in test_queries:
        response = get_chatbot_response(q)
        print(f"\nQ: {q}")
        print(f"A: {response}")
        print("-" * 50)