# codingo/chatbot/chatbot.py
"""Interactive chatbot using Flan-T5 for dynamic responses."""

import os
import shutil

import torch
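
# Hugging Face cache locations are redirected to /tmp, which is typically
# writable on hosted runtimes (e.g. Hugging Face Spaces) where the app
# directory may be read-only.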
os.environ.setdefault("HF_HOME", "/tmp/huggingface") | |
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers") | |
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub") | |

# Lazily initialized singletons shared across requests.
_model = None
_tokenizer = None
_chatbot_embedder = None
_chatbot_collection = None

_current_dir = os.path.dirname(os.path.abspath(__file__))
_knowledge_base_path = os.path.join(_current_dir, "chatbot.txt")
_chroma_db_dir = "/tmp/chroma_db"

# Using Flan-T5: it's small, fast, and well suited to Q&A.
MODEL_NAME = "google/flan-t5-small"
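# flan-t5-small is on the order of 80M parameters (approximate figure), so it
# downloads quickly and runs acceptably on CPU-only hosts.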


def _init_model():
    """Lazily load the Flan-T5 model and tokenizer (once per process)."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return
    print("Loading Flan-T5 model...")
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # fp16 halves memory on GPU; CPU inference stays in fp32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)
    model.eval()
    _model = model
    _tokenizer = tokenizer
    print("Model loaded successfully!")


def _init_vector_store():
    """Lazily build the in-memory vector store from the knowledge base."""
    global _chatbot_embedder, _chatbot_collection
    if _chatbot_embedder is not None and _chatbot_collection is not None:
        return
    print("Initializing vector store...")
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from sentence_transformers import SentenceTransformer
    import chromadb
    from chromadb.config import Settings

    # Clear any stale on-disk index, then recreate the directory.
    shutil.rmtree(_chroma_db_dir, ignore_errors=True)
    os.makedirs(_chroma_db_dir, exist_ok=True)

    # Load the knowledge base, falling back to a one-line description.
    try:
        with open(_knowledge_base_path, encoding="utf-8") as f:
            raw_text = f.read()
        print(f"Loaded knowledge base: {len(raw_text)} characters")
    except FileNotFoundError:
        print("Knowledge base not found!")
        raw_text = "Codingo is an AI recruitment platform."

    # Split into overlapping chunks so retrieval can return focused passages.
    splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    docs = [doc.strip() for doc in splitter.split_text(raw_text) if doc.strip()]
    print(f"Created {len(docs)} document chunks")

    # Embed every chunk once up front.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedder.encode(docs, show_progress_bar=False)

    # Create an in-memory ChromaDB collection.
    client = chromadb.Client(Settings(anonymized_telemetry=False, is_persistent=False))
    try:
        client.delete_collection("chatbot")
    except Exception:
        pass
    collection = client.create_collection("chatbot")
    ids = [f"doc_{i}" for i in range(len(docs))]
    collection.add(documents=docs, embeddings=embeddings.tolist(), ids=ids)

    _chatbot_embedder = embedder
    _chatbot_collection = collection
    print("Vector store ready!")


def get_chatbot_response(query: str) -> str:
    """Answer a question about Codingo using retrieved context + Flan-T5."""
    try:
        if not query or not query.strip():
            return "Hi! I'm LUNA AI. Ask me anything about Codingo!"
        print(f"\nProcessing: '{query}'")

        # Free cached GPU memory before generating.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Lazy initialization on the first request.
        _init_vector_store()
        _init_model()

        # Embed the query and retrieve the three closest knowledge-base chunks.
        query_embedding = _chatbot_embedder.encode([query])[0]
        results = _chatbot_collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=3,
        )
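        # ChromaDB returns one list of matches per query embedding, so
        # results["documents"] is a list of lists.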
        retrieved_docs = results.get("documents", [[]])[0] if results else []
        print(f"Found {len(retrieved_docs)} relevant chunks")

        # Combine the two most relevant chunks into the prompt context.
        context = " ".join(retrieved_docs[:2]) if retrieved_docs else "Codingo is an AI recruitment platform."

        # Build an instruction-style prompt for Flan-T5.
        prompt = f"""Answer the question based on the context about Codingo.
Context: {context}
Question: {query}
Answer:"""

        # Tokenize, truncating anything beyond the 512-token input window.
        inputs = _tokenizer(
            prompt,
            max_length=512,
            truncation=True,
            return_tensors="pt",
        ).to(_model.device)

        # Generate a response (beam search combined with sampling).
        with torch.no_grad():
            outputs = _model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=4,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
            )

        # Decode the generated tokens back into text.
        response = _tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated: '{response}'")

        # If the answer is empty or trivially short, retry with a simpler prompt.
        if not response or len(response) < 5:
            simple_prompt = f"Question about Codingo: {query}\nAnswer:"
            inputs = _tokenizer(simple_prompt, max_length=256, truncation=True, return_tensors="pt").to(_model.device)
            with torch.no_grad():
                # do_sample=True is needed here; generate() ignores
                # temperature under greedy decoding.
                outputs = _model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.8)
            response = _tokenizer.decode(outputs[0], skip_special_tokens=True)

        response = response.strip()

        # If the model still produced nothing useful, fall back to canned replies.
        if len(response) < 10:
            # Match greetings on whole words so that e.g. "this" does not trigger "hi".
            query_words = query.lower().split()
            if "hello" in query_words or "hi" in query_words:
                return "Hello! I'm LUNA AI, your Codingo assistant. I can help you with questions about our AI recruitment platform, job matching, CV tips, and more!"
            return f"I can help you with that! Based on what I know about Codingo: {retrieved_docs[0][:200] if retrieved_docs else 'Codingo is an AI-powered recruitment platform that helps match candidates with jobs.'}"

        return response
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return "I'm having a technical issue. Please try asking your question again!"


# Manual smoke test: run this module directly to exercise the full pipeline.
if __name__ == "__main__":
    test_queries = [
        "What is Codingo?",
        "How does it work?",
        "What makes Codingo special?",
        "How can I improve my profile?",
        "Is it free?",
    ]
    print("Testing chatbot...")
    for q in test_queries:
        response = get_chatbot_response(q)
        print(f"\nQ: {q}")
        print(f"A: {response}")
        print("-" * 50)