import os
import re
import requests
import sqlite3
import threading
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from tenacity import retry, stop_after_attempt, wait_exponential
import mwparserfromhell
import logging
import chromadb
from collections import deque
from huggingface_hub import login
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Use /tmp for all persistent storage (always writable)
CACHE_DIR = "/tmp/one_piece_cache"
DB_PATH = os.path.join(CACHE_DIR, "one_piece_data.db")
CHROMA_DB_PATH = os.path.join(CACHE_DIR, "chroma_db")

LLM_MODEL = "google/gemma-2-2b-it"
EMBED_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")

WIKI_CATEGORIES = {
    "Characters": ["Straw_Hat_Pirates", "Marines", "Yonko", "Seven_Warlords", "Worst_Generation"],
    "Devil_Fruits": ["Paramecia", "Zoan", "Logia"],
    "Locations": ["Islands", "Seas", "Grand_Line", "New_World"],
    "Story": ["Story_Arcs", "Sagas", "Events"],
    "Organizations": ["Pirates", "Crews", "Marines", "World_Government"],
    "Concepts": ["Haki", "Void_Century", "Ancient_Weapons", "Will_of_D"]
}

CRUCIAL_PAGES = [
    "Monkey_D._Luffy", "Straw_Hat_Pirates", "One_Piece_(Manga)", "Eiichiro_Oda",
    "Devil_Fruit", "Haki", "Void_Century", "Gol_D._Roger", "Marines", "Yonko",
    "World_Government", "Grand_Line", "New_World", "One_Piece", "Will_of_D",
    "Poneglyphs", "Ancient_Weapons", "Roger_Pirates", "God_Valley_Incident",
    "Joy_Boy", "Sun_God_Nika", "Laugh_Tale", "Rocks_Pirates", "Revolutionary_Army",
    "Hito_Hito_no_Mi,_Model:_Nika", "Gomu_Gomu_no_Mi", "Five_Elders", "Im",
    "Marshall_D._Teach", "Blackbeard_Pirates", "Gura_Gura_no_Mi", "Yami_Yami_no_Mi"
]

CHUNK_SIZE_TOKENS = 300           # max tokens per chunk
CHUNK_OVERLAP = 2                 # sentences of overlap between consecutive chunks
MAX_CONTEXT_CHUNKS = 10           # chunks retrieved per query
SIMILARITY_THRESHOLD = 0.35       # minimum cosine similarity for a chunk to be used
REFRESH_INTERVAL = 7 * 24 * 3600  # re-fetch pages older than one week
CONVERSATION_HISTORY_LENGTH = 6   # number of (question, answer) turns kept
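
# Pipeline overview: a background thread fetches pages from the One Piece Fandom wiki
# (crucial pages first, then category members), strips wiki markup, splits the text into
# sentence-based chunks, embeds them with the SentenceTransformer model, and stores them
# in ChromaDB. answer_question() then embeds the user's question, retrieves the most
# similar chunks above SIMILARITY_THRESHOLD, and passes them as context to the Gemma model.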

class OnePieceChatbot:
    """Retrieval-augmented chatbot over the One Piece Fandom wiki."""

    def __init__(self):
        os.makedirs(CACHE_DIR, exist_ok=True)

        if HF_TOKEN:
            try:
                login(token=HF_TOKEN)
                logging.info("Successfully logged into Hugging Face.")
            except Exception as e:
                logging.warning(f"Hugging Face login failed: {e}. Proceeding without explicit login.")

        self.db_conn = self._init_db()
        self.chroma_client, self.chroma_collection = self._init_chroma()
        self.data_lock = threading.Lock()
        self.processing_pages = set()
        self.initial_processing_done = threading.Event()

        try:
            self.embedder = SentenceTransformer(EMBED_MODEL)
            logging.info(f"Loaded SentenceTransformer model: {EMBED_MODEL}")

            self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
            logging.info(f"Loaded Tokenizer: {LLM_MODEL}")

            self.model = AutoModelForCausalLM.from_pretrained(
                LLM_MODEL,
                device_map="auto",
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
            )
            logging.info(f"Loaded LLM Model: {LLM_MODEL}")

            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=500,
                temperature=0.2,
                do_sample=True,
                repetition_penalty=1.2
            )
            logging.info("Initialized text generation pipeline.")
        except Exception as e:
            logging.error(f"Failed to load models: {e}")
            raise SystemExit("Failed to load models, exiting.") from e

        self.conversation_history = deque(maxlen=CONVERSATION_HISTORY_LENGTH)

        logging.info("Starting background data processing thread...")
        thread = threading.Thread(target=self._process_wiki_data, daemon=True)
        thread.start()
        logging.info("Background data processing thread started.")

    def _init_db(self):
        conn = sqlite3.connect(DB_PATH, check_same_thread=False)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS wiki_data (
                title TEXT PRIMARY KEY,
                content TEXT,
                category TEXT,
                last_fetched REAL,
                page_links TEXT
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_category ON wiki_data (category)")
        conn.commit()
        logging.info(f"SQLite database initialized at {DB_PATH}")
        return conn

    def _init_chroma(self):
        client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        collection_name = "one_piece_knowledge"
        try:
            collection = client.get_collection(name=collection_name)
            logging.info(f"Connected to existing ChromaDB collection: {collection_name}")
        except Exception:
            collection = client.create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )
            logging.info(f"Created new ChromaDB collection: {collection_name}")
        logging.info(f"ChromaDB initialized at {CHROMA_DB_PATH}")
        return client, collection

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def _fetch_wiki_page(self, title):
        # Use a params dict so titles containing commas/colons and the "|" in prop
        # are URL-encoded correctly.
        params = {
            "action": "parse",
            "page": title,
            "format": "json",
            "prop": "wikitext|categories"
        }
        try:
            response = requests.get("https://onepiece.fandom.com/api.php", params=params, timeout=15)
            response.raise_for_status()
            data = response.json()
            if "parse" not in data:
                logging.warning(f"Could not parse wiki data for {title}")
                return None, [], None

            wikitext = data["parse"]["wikitext"]["*"]
            parsed = mwparserfromhell.parse(wikitext)

            # Remove infobox/sidebar templates; filter_templates() returns a list,
            # so removing nodes while looping over it is safe.
            for node in parsed.filter_templates():
                template_name = str(node.name).strip().lower()
                if template_name.startswith('infobox') or 'sidebar' in template_name:
                    try:
                        parsed.remove(node)
                    except ValueError:
                        pass

            # Collect outgoing wikilinks, skipping namespaced links (File:, Category:, Template:, ...).
            links = []
            for link in parsed.ifilter_wikilinks():
                link_title = str(link.title).split("#")[0].strip()
                if ":" not in link_title and len(link_title) > 1:
                    links.append(link_title)

            category = "Other"
            if "categories" in data["parse"]:
                categories = [cat["*"] for cat in data["parse"]["categories"]]
                for cat_type, cat_list in WIKI_CATEGORIES.items():
                    if any(cat.replace(' ', '_') in [c.replace(' ', '_') for c in categories] for cat in cat_list):
                        category = cat_type
                        break

            text = parsed.strip_code().strip()
            text = re.sub(r'https?://\S+', '', text)
            text = re.sub(r'\[\[[^\]]+\]\]', '', text)
            text = re.sub(r'[ \t]+', ' ', text)             # collapse horizontal whitespace, keep newlines
            text = re.sub(r'\n{2,}', '\n\n', text).strip()  # keep at most one blank line between paragraphs
            return text, links, category
        except requests.exceptions.RequestException as e:
            logging.error(f"Request error fetching {title}: {e}")
            raise
        except Exception as e:
            logging.error(f"Error processing wiki data for {title}: {e}")
            return None, [], None

    def _fetch_category_pages(self, category):
        params = {
            "action": "query",
            "list": "categorymembers",
            "cmtitle": f"Category:{category}",
            "cmlimit": 500,
            "format": "json"
        }
        try:
            response = requests.get("https://onepiece.fandom.com/api.php", params=params, timeout=20)
            response.raise_for_status()
            data = response.json()
            pages = []
            if "query" in data and "categorymembers" in data["query"]:
                for member in data["query"]["categorymembers"]:
                    # ns == 0 restricts results to main-namespace article pages.
                    if member["ns"] == 0 and "title" in member:
                        pages.append(member["title"])
            return pages
        except requests.exceptions.RequestException as e:
            logging.error(f"Request error fetching category {category}: {e}")
            return []
        except Exception as e:
            logging.error(f"Error processing category {category} members: {e}")
            return []

    def _process_wiki_data(self):
        logging.info("Background processing: Starting data collection...")
        processed_count = 0

        logging.info("Processing crucial pages...")
        for page in CRUCIAL_PAGES:
            cur = self.db_conn.execute("SELECT last_fetched FROM wiki_data WHERE title = ?", (page,))
            result = cur.fetchone()
            if result and time.time() - result[0] < REFRESH_INTERVAL:
                processed_count += 1
                continue
            if self._process_page(page):
                processed_count += 1

        logging.info("Processing category pages...")
        crawled_pages_from_categories = set()
        for category_type, categories in WIKI_CATEGORIES.items():
            for category in categories:
                try:
                    pages = self._fetch_category_pages(category)
                    for page in pages:
                        if page in CRUCIAL_PAGES:
                            continue
                        if page in crawled_pages_from_categories:
                            continue
                        crawled_pages_from_categories.add(page)
                        cur = self.db_conn.execute("SELECT last_fetched FROM wiki_data WHERE title = ?", (page,))
                        result = cur.fetchone()
                        if result and time.time() - result[0] < REFRESH_INTERVAL:
                            processed_count += 1
                            continue
                        if self._process_page(page):
                            processed_count += 1
                except Exception as e:
                    logging.error(f"Error processing category {category}: {e}")

        self.initial_processing_done.set()
        logging.info(f"Initial data processing finished. Processed {processed_count} pages.")
        logging.info(f"Vector collection count after initial processing: {self.chroma_collection.count()}")

        while True:
            time.sleep(REFRESH_INTERVAL)
            logging.info("Starting periodic refresh cycle...")
            cur = self.db_conn.execute("SELECT title FROM wiki_data ORDER BY last_fetched ASC LIMIT 200")
            pages_to_refresh = [row[0] for row in cur.fetchall()]
            logging.info(f"Refreshing {len(pages_to_refresh)} pages.")
            for page in pages_to_refresh:
                self._process_page(page)
            logging.info("Periodic refresh cycle finished.")

    def _process_page(self, title):
        with self.data_lock:
            if title in self.processing_pages:
                return False
            cur = self.db_conn.execute("SELECT last_fetched FROM wiki_data WHERE title = ?", (title,))
            result = cur.fetchone()
            if result and time.time() - result[0] < REFRESH_INTERVAL:
                return False
            self.processing_pages.add(title)
        try:
            content, links, category = self._fetch_wiki_page(title)
            if not content:
                return False

            with self.data_lock:
                self.db_conn.execute(
                    "INSERT OR REPLACE INTO wiki_data VALUES (?, ?, ?, ?, ?)",
                    (title, content, category, time.time(), ','.join(links))
                )
                self.db_conn.commit()

            chunks = self._chunk_text(content, title)
            if chunks:
                try:
                    embeddings = self.embedder.encode(chunks, convert_to_tensor=False).tolist()
                    ids = [f"{title}::{i}" for i in range(len(chunks))]
                    metadatas = [{"source": title, "category": category} for _ in chunks]
                    # Drop any previously stored chunks for this page before upserting,
                    # so stale chunk IDs do not linger when the page shrinks.
                    try:
                        self.chroma_collection.delete(where={"source": title})
                    except Exception as delete_e:
                        logging.warning(f"Could not delete old chunks for {title} from ChromaDB: {delete_e}")
                    self.chroma_collection.upsert(
                        ids=ids,
                        embeddings=embeddings,
                        documents=chunks,
                        metadatas=metadatas
                    )
                except Exception as chroma_e:
                    logging.error(f"Error adding/updating chunks for {title} in ChromaDB: {chroma_e}")
                    return False

            if links:
                # Opportunistically crawl a few linked pages in the background.
                for link in links[:10]:
                    threading.Thread(target=self._process_page, args=(link,), daemon=True).start()
            return True
        except Exception as e:
            logging.error(f"Caught unexpected error during processing of {title}: {e}")
            return False
        finally:
            with self.data_lock:
                self.processing_pages.discard(title)

    def _chunk_text(self, text, title):
        # Split into sentences and pack them into chunks of roughly CHUNK_SIZE_TOKENS
        # tokens, carrying over the last CHUNK_OVERLAP sentences between chunks.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks, current_sentences = [], []
        current_chunk_tokens = 0
        if len(sentences) < 2 and len(text.split()) < 50:
            return []
        for sentence in sentences:
            sentence_tokens = len(self.tokenizer.encode(sentence, add_special_tokens=False))
            new_token_count = current_chunk_tokens + sentence_tokens + (1 if current_chunk_tokens > 0 else 0)
            if new_token_count > CHUNK_SIZE_TOKENS and current_sentences:
                chunk_text = " ".join(current_sentences).strip()
                if chunk_text:
                    chunks.append(chunk_text)
                overlap_sentences = current_sentences[-CHUNK_OVERLAP:] if len(current_sentences) > CHUNK_OVERLAP else current_sentences
                current_sentences = overlap_sentences
                current_chunk_tokens = sum(len(self.tokenizer.encode(s, add_special_tokens=False)) for s in current_sentences)
            current_sentences.append(sentence)
            current_chunk_tokens += sentence_tokens + (1 if current_chunk_tokens > 0 else 0)
        if current_sentences:
            chunk_text = " ".join(current_sentences).strip()
            if chunk_text:
                chunks.append(chunk_text)
        # Discard very short chunks; they tend to be navigation or list fragments.
        return [chunk for chunk in chunks if len(chunk.split()) > 20]

    def _interpret_query(self, query):
        # Only reformulate short or follow-up style queries once the knowledge base is ready;
        # longer standalone questions are passed through unchanged.
        is_followup = query.lower().startswith("and ") or query.lower().startswith("what about ")
        if not self.initial_processing_done.is_set() or (len(query.split()) > 3 and not is_followup):
            return query
        try:
            prompt = f"""
Based on this conversation history:
{self._format_history()}
The user asked a question: "{query}"
Please interpret this as a complete, standalone question about One Piece, incorporating context from the history if necessary. Ensure the reformulated question is clear and specific, even if the original query was vague or a follow-up.
Only provide the complete reformulated question and nothing else:
"""
            interpretation_response = self.generator(
                prompt,
                max_new_tokens=50,
                temperature=0.5,
                do_sample=True,
                repetition_penalty=1.1
            )[0]["generated_text"]
            # generated_text echoes the prompt, so take everything after the final instruction line.
            marker = "Only provide the complete reformulated question and nothing else:"
            if marker in interpretation_response:
                interpreted_query = interpretation_response.split(marker)[-1].strip()
            else:
                lines = interpretation_response.split('\n')
                interpreted_query = lines[-1].strip() if lines else interpretation_response.strip()
            interpreted_query = interpreted_query.replace('"', '').strip()
            return interpreted_query
        except Exception as e:
            logging.error(f"Error interpreting query '{query}': {e}. Using original query.")
            return query

    def _find_relevant_chunks(self, query):
        if not self.initial_processing_done.is_set():
            self.initial_processing_done.wait(timeout=5)
        interpreted_query = self._interpret_query(query)

        # Expand the query with related lore terms to improve recall for common topics.
        keywords_to_add = []
        lower_query = interpreted_query.lower()
        if "joy boy" in lower_query or "nika" in lower_query:
            keywords_to_add.extend(["Hito Hito no Mi Model Nika", "Sun God Nika"])
        if "blackbeard" in lower_query and "devil fruit" in lower_query:
            keywords_to_add.extend(["multiple devil fruits", "Yami Yami no Mi", "Gura Gura no Mi"])
        # Match "im" as a whole word only, so words like "him" or "time" do not trigger it.
        if "gorosei" in lower_query or re.search(r'\bim\b', lower_query):
            keywords_to_add.extend(["Five Elders", "Empty Throne", "World Government"])
        if "void century" in lower_query:
            keywords_to_add.extend(["Poneglyphs", "Ancient Kingdom", "Ohara"])
        if keywords_to_add:
            interpreted_query_with_keywords = interpreted_query + " " + " ".join(keywords_to_add)
        else:
            interpreted_query_with_keywords = interpreted_query

        try:
            query_embedding = self.embedder.encode(interpreted_query_with_keywords).tolist()
            results = self.chroma_collection.query(
                query_embeddings=[query_embedding],
                n_results=MAX_CONTEXT_CHUNKS,
                include=["documents", "metadatas", "distances"]
            )
        except Exception as e:
            logging.error(f"Error querying ChromaDB: {e}")
            return [], []

        chunks = []
        sources = set()
        if results and results["documents"]:
            for i, doc in enumerate(results["documents"][0]):
                # ChromaDB returns cosine distances; convert to similarity before thresholding.
                similarity = 1 - results["distances"][0][i]
                if similarity >= SIMILARITY_THRESHOLD:
                    chunks.append(doc)
                    sources.add(results["metadatas"][0][i]["source"])
        return chunks, list(sources)

    def _format_history(self):
        if not self.conversation_history:
            return "No recent conversation history."
        history = "Recent conversation history:\n"
        for i, (q, a) in enumerate(self.conversation_history):
            history += f"Turn {i+1}:\nUser: {q}\nAssistant: {a}\n"
        return history

    def answer_question(self, question: str):
        logging.info(f"Received question: '{question}'")
        if not self.initial_processing_done.is_set():
            logging.warning("Initial data processing not yet complete. Waiting up to 60s...")
            if not self.initial_processing_done.wait(timeout=60):
                logging.error("Initial data processing timed out. Cannot answer reliably.")
                return "The knowledge base is still loading. Please try again in a few minutes."
            logging.info("Initial data processing finished while waiting.")

        chunks, sources = self._find_relevant_chunks(question)

        if not chunks:
            # No relevant chunks retrieved: fall back to the model's general knowledge.
            fallback_prompt = f"""You are an expert on the One Piece manga and anime. The user asked: "{question}". However, no relevant specific information was found in your knowledge base. Provide a general, helpful answer based on your broad understanding of One Piece, or state that you don't have specific information on this topic. Do not invent facts.
IMPORTANT: Start immediately with your answer."""
            try:
                response = self.generator(fallback_prompt, max_new_tokens=200, temperature=0.7, do_sample=True)[0]["generated_text"].strip()
                # Strip the echoed prompt so only the model's continuation remains.
                response = re.sub(r'^.*?IMPORTANT: Start immediately with your answer\.', '', response, flags=re.DOTALL).strip()
                answer = response
                self.conversation_history.append((question, answer))
                return answer
            except Exception as e:
                logging.error(f"Error generating fallback response: {e}")
                answer = "I couldn't find specific information about that in my knowledge base."
                self.conversation_history.append((question, answer))
                return answer

        prompt = f"""You are an expert on the One Piece manga and anime. Answer the following question based *only* on the provided context and your knowledge of One Piece lore.
{self._format_history()}
Context information:
{chr(10).join(chunks)}
Question: {question}
Provide a detailed, accurate answer based on the context above. If the context doesn't contain enough information to fully answer, use your general One Piece knowledge but prioritize information from the context. Explain connections between characters and events clearly. Structure your answer logically.
IMPORTANT: Your answer must be directly useful and not include phrases like "based on the context" or "answer the question". Start immediately with your answer. Ensure your answer is cohesive and well-formatted.
"""
        try:
            response = self.generator(prompt)[0]["generated_text"]
            # The pipeline echoes the prompt; anchor on the final sentence of the prompt
            # so the extracted answer does not start with leftover instruction text.
            answer_match = re.search(r'Ensure your answer is cohesive and well-formatted\.\s*(.*)', response, re.DOTALL)
            if answer_match:
                answer = answer_match.group(1).strip()
            else:
                answer_parts = response.split("Question: " + question)
                answer = answer_parts[-1].strip() if len(answer_parts) > 1 else response.strip()
            answer = re.sub(r'^(.*?)(?:IMPORTANT:|Based on this conversation history:|Context information:|Question:)', '', answer, flags=re.DOTALL | re.IGNORECASE).strip()
            answer = re.sub(r'\s*Sources:\s*.*$', '', answer, flags=re.DOTALL)
        except Exception as e:
            logging.error(f"Error generating response from LLM: {e}")
            answer = "Sorry, I encountered an error while generating the response."

        self.conversation_history.append((question, answer))

        if sources:
            clean_sources = [s.replace('_', ' ') for s in sources[:5]]
            return f"{answer}\n\nSources: {', '.join(clean_sources)}"
        return answer

# --- FastAPI Application Setup ---
app = FastAPI()

class QuestionRequest(BaseModel):
    question: str

# --- Global chatbot instance with error handling ---
try:
    chatbot = OnePieceChatbot()
    logging.info("OnePieceChatbot instance created.")
except Exception as e:
    chatbot = None
    logging.critical(f"Failed to initialize OnePieceChatbot: {e}")

@app.get("/")
async def read_root():
    return {"message": "One Piece Chatbot API is running. Send a POST request to /ask with your question."}

# Health-check route (the path name is a conventional choice; the original code did not specify one).
@app.get("/health")
async def health_check():
    # Returning a bare (dict, status_code) tuple does not set the HTTP status in FastAPI,
    # so JSONResponse is used explicitly for non-default status codes.
    if chatbot is None or not hasattr(chatbot, 'generator') or chatbot.generator is None:
        return JSONResponse(status_code=500, content={"status": "error", "message": "LLM model not loaded"})
    if not chatbot.initial_processing_done.is_set():
        return {"status": "warning", "message": "Initial data processing still in progress. Some answers may be limited."}
    try:
        count = chatbot.chroma_collection.count()
        if count == 0:
            return {"status": "warning", "message": "Knowledge base is empty after initialization. Data fetching might have failed."}
        return {"status": "ok", "message": "Chatbot is ready.", "knowledge_base_size": count}
    except Exception as e:
        logging.error(f"Health check failed during ChromaDB count: {e}")
        return {"status": "warning", "message": f"Health check encountered an issue: {e}"}

@app.post("/ask")
async def ask_question_endpoint(request: QuestionRequest):
    if chatbot is None:
        return JSONResponse(status_code=500, content={"answer": "Sorry, the chatbot failed to initialize."})
    question = request.question
    if not question or not question.strip():
        return {"answer": "Please provide a question."}
    try:
        # answer_question is blocking (LLM inference), so run it in a worker thread.
        answer = await run_in_threadpool(chatbot.answer_question, question)
        return {"answer": answer}
    except Exception as e:
        logging.error(f"Error processing question '{question}': {e}")
        return JSONResponse(status_code=500, content={"answer": "Sorry, an internal error occurred while processing your question."})

# To run locally:
#   uvicorn app:app --host 0.0.0.0 --port 7860
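#
# Example request once the server is up (uses the /ask route defined above):
#   curl -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who are the Five Elders?"}'

# Optional convenience entry point so the app can also be started with `python app.py`.
# This is a minimal sketch; it assumes uvicorn is installed (it is needed to serve FastAPI anyway).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)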