import os
import re
import requests
import sqlite3
import threading
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from tenacity import retry, stop_after_attempt, wait_exponential
import mwparserfromhell
import logging
import chromadb
from collections import deque
from huggingface_hub import login
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Use /tmp for all persistent storage (always writable)
CACHE_DIR = "/tmp/one_piece_cache"
DB_PATH = os.path.join(CACHE_DIR, "one_piece_data.db")
CHROMA_DB_PATH = os.path.join(CACHE_DIR, "chroma_db")
LLM_MODEL = "google/gemma-2-2b-it"
EMBED_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")
WIKI_CATEGORIES = {
"Characters": ["Straw_Hat_Pirates", "Marines", "Yonko", "Seven_Warlords", "Worst_Generation"],
"Devil_Fruits": ["Paramecia", "Zoan", "Logia"],
"Locations": ["Islands", "Seas", "Grand_Line", "New_World"],
"Story": ["Story_Arcs", "Sagas", "Events"],
"Organizations": ["Pirates", "Crews", "Marines", "World_Government"],
"Concepts": ["Haki", "Void_Century", "Ancient_Weapons", "Will_of_D"]
}
CRUCIAL_PAGES = [
"Monkey_D._Luffy", "Straw_Hat_Pirates", "One_Piece_(Manga)", "Eiichiro_Oda",
"Devil_Fruit", "Haki", "Void_Century", "Gol_D._Roger", "Marines", "Yonko",
"World_Government", "Grand_Line", "New_World", "One_Piece", "Will_of_D",
"Poneglyphs", "Ancient_Weapons", "Roger_Pirates", "God_Valley_Incident",
"Joy_Boy", "Sun_God_Nika", "Laugh_Tale", "Rocks_Pirates", "Revolutionary_Army",
"Hito_Hito_no_Mi,_Model:_Nika", "Gomu_Gomu_no_Mi", "Five_Elders", "Im",
"Marshall_D._Teach", "Blackbeard_Pirates", "Gura_Gura_no_Mi", "Yami_Yami_no_Mi"
]
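# Chunking / retrieval parameters: CHUNK_SIZE_TOKENS is measured with the LLM tokenizer,
# while CHUNK_OVERLAP counts trailing sentences carried over into the next chunk.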
CHUNK_SIZE_TOKENS = 300
CHUNK_OVERLAP = 2
MAX_CONTEXT_CHUNKS = 10
SIMILARITY_THRESHOLD = 0.35
REFRESH_INTERVAL = 7 * 24 * 3600
CONVERSATION_HISTORY_LENGTH = 6
class OnePieceChatbot:
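    """Retrieval-augmented chatbot for One Piece lore.

    Pages from the One Piece Fandom wiki are cached as raw text in SQLite, their
    sentence chunks are embedded into a persistent ChromaDB collection, and
    questions are answered by google/gemma-2-2b-it with retrieved chunks as context.
    """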
def __init__(self):
os.makedirs(CACHE_DIR, exist_ok=True)
if HF_TOKEN:
try:
login(token=HF_TOKEN)
logging.info("Successfully logged into Hugging Face.")
except Exception as e:
logging.warning(f"Hugging Face login failed: {e}. Proceeding without explicit login.")
self.db_conn = self._init_db()
self.chroma_client, self.chroma_collection = self._init_chroma()
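        # data_lock serializes SQLite writes and guards processing_pages, the set of
        # page titles currently being fetched by crawler threads.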
self.data_lock = threading.Lock()
self.processing_pages = set()
self.initial_processing_done = threading.Event()
try:
self.embedder = SentenceTransformer(EMBED_MODEL)
logging.info(f"Loaded SentenceTransformer model: {EMBED_MODEL}")
self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
logging.info(f"Loaded Tokenizer: {LLM_MODEL}")
self.model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
device_map="auto",
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
logging.info(f"Loaded LLM Model: {LLM_MODEL}")
self.generator = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
max_new_tokens=500,
temperature=0.2,
do_sample=True,
repetition_penalty=1.2
)
logging.info("Initialized text generation pipeline.")
except Exception as e:
logging.error(f"Failed to load models: {e}")
raise SystemExit("Failed to load models, exiting.") from e
self.conversation_history = deque(maxlen=CONVERSATION_HISTORY_LENGTH)
logging.info("Starting background data processing thread...")
thread = threading.Thread(target=self._process_wiki_data, daemon=True)
thread.start()
logging.info("Background data processing thread started.")
def _init_db(self):
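        # check_same_thread=False lets the background crawler thread reuse this
        # connection; writes are serialized with self.data_lock in _process_page.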
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
conn.execute("""
CREATE TABLE IF NOT EXISTS wiki_data (
title TEXT PRIMARY KEY,
content TEXT,
category TEXT,
last_fetched REAL,
page_links TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_category ON wiki_data (category)")
conn.commit()
logging.info(f"SQLite database initialized at {DB_PATH}")
return conn
def _init_chroma(self):
client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
collection_name = "one_piece_knowledge"
try:
collection = client.get_collection(name=collection_name)
logging.info(f"Connected to existing ChromaDB collection: {collection_name}")
except Exception:
collection = client.create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
logging.info(f"Created new ChromaDB collection: {collection_name}")
logging.info(f"ChromaDB initialized at {CHROMA_DB_PATH}")
return client, collection
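    # Fetch and clean one wiki page via the MediaWiki parse API; transient network
    # errors are retried up to three times with exponential backoff.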
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
def _fetch_wiki_page(self, title):
url = f"https://onepiece.fandom.com/api.php?action=parse&page={title}&format=json&prop=wikitext|categories"
try:
response = requests.get(url, timeout=15)
response.raise_for_status()
data = response.json()
if "parse" not in data:
logging.warning(f"Could not parse wiki data for {title}")
return None, [], None
wikitext = data["parse"]["wikitext"]["*"]
parsed = mwparserfromhell.parse(wikitext)
for node in parsed.ifilter_templates():
template_name = str(node.name).strip().lower()
if template_name.startswith('infobox') or 'sidebar' in template_name:
try:
parsed.remove(node)
except ValueError:
pass
links = []
for link in parsed.ifilter_wikilinks():
            link_title = str(link.title).split("#")[0].strip()
            # Skip namespaced links (File:, Category:, Template:, ...) and trivial titles.
            if ":" not in link_title and len(link_title) > 1:
                links.append(link_title)
category = "Other"
if "categories" in data["parse"]:
categories = [cat["*"] for cat in data["parse"]["categories"]]
for cat_type, cat_list in WIKI_CATEGORIES.items():
if any(cat.replace(' ', '_') in [c.replace(' ', '_') for c in categories] for cat in cat_list):
category = cat_type
break
text = parsed.strip_code().strip()
text = re.sub(r'https?://\S+', '', text)
text = re.sub(r'\[\[[^\]]+\]\]', '', text)
        # Normalize whitespace while preserving paragraph breaks: collapse extra blank
        # lines first, then runs of spaces and tabs.
        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r'[ \t]+', ' ', text).strip()
return text, links, category
except requests.exceptions.RequestException as e:
logging.error(f"Request error fetching {title}: {e}")
raise
except Exception as e:
logging.error(f"Error processing wiki data for {title}: {e}")
return None, [], None
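    # List article titles (namespace 0) that belong to the given wiki category.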
def _fetch_category_pages(self, category):
url = f"https://onepiece.fandom.com/api.php?action=query&list=categorymembers&cmtitle=Category:{category}&cmlimit=500&format=json"
try:
response = requests.get(url, timeout=20)
response.raise_for_status()
data = response.json()
pages = []
if "query" in data and "categorymembers" in data["query"]:
for member in data["query"]["categorymembers"]:
if member["ns"] == 0 and "title" in member:
pages.append(member["title"])
return pages
except requests.exceptions.RequestException as e:
logging.error(f"Request error fetching category {category}: {e}")
return []
except Exception as e:
logging.error(f"Error processing category {category} members: {e}")
return []
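    # Background worker: seed the knowledge base from CRUCIAL_PAGES and the configured
    # category listings, then periodically re-fetch the least recently updated pages.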
def _process_wiki_data(self):
logging.info("Background processing: Starting data collection...")
processed_count = 0
logging.info("Processing crucial pages...")
for page in CRUCIAL_PAGES:
cur = self.db_conn.execute("SELECT last_fetched FROM wiki_data WHERE title = ?", (page,))
result = cur.fetchone()
if result and time.time() - result[0] < REFRESH_INTERVAL:
processed_count += 1
continue
if self._process_page(page):
processed_count += 1
logging.info("Processing category pages...")
crawled_pages_from_categories = set()
for category_type, categories in WIKI_CATEGORIES.items():
for category in categories:
try:
pages = self._fetch_category_pages(category)
for page in pages:
if page in CRUCIAL_PAGES: continue
if page in crawled_pages_from_categories: continue
crawled_pages_from_categories.add(page)
cur = self.db_conn.execute("SELECT last_fetched FROM wiki_data WHERE title = ?", (page,))
result = cur.fetchone()
if result and time.time() - result[0] < REFRESH_INTERVAL:
processed_count += 1
continue
if self._process_page(page):
processed_count += 1
except Exception as e:
logging.error(f"Error processing category {category}: {e}")
self.initial_processing_done.set()
logging.info(f"Initial data processing finished. Processed {processed_count} pages.")
logging.info(f"Vector collection count after initial processing: {self.chroma_collection.count()}")
while True:
time.sleep(REFRESH_INTERVAL)
logging.info("Starting periodic refresh cycle...")
cur = self.db_conn.execute("SELECT title FROM wiki_data ORDER BY last_fetched ASC LIMIT 200")
pages_to_refresh = [row[0] for row in cur.fetchall()]
logging.info(f"Refreshing {len(pages_to_refresh)} pages.")
for page in pages_to_refresh:
self._process_page(page)
logging.info("Periodic refresh cycle finished.")
def _process_page(self, title):
with self.data_lock:
if title in self.processing_pages:
return False
cur = self.db_conn.execute("SELECT last_fetched FROM wiki_data WHERE title = ?", (title,))
result = cur.fetchone()
if result and time.time() - result[0] < REFRESH_INTERVAL:
return False
self.processing_pages.add(title)
try:
content, links, category = self._fetch_wiki_page(title)
if not content:
return False
with self.data_lock:
self.db_conn.execute(
"INSERT OR REPLACE INTO wiki_data VALUES (?, ?, ?, ?, ?)",
(title, content, category, time.time(), ','.join(links))
)
self.db_conn.commit()
chunks = self._chunk_text(content, title)
if chunks:
try:
embeddings = self.embedder.encode(chunks, convert_to_tensor=False).tolist()
ids = [f"{title}::{i}" for i in range(len(chunks))]
metadatas = [{"source": title, "category": category} for _ in chunks]
                    try:
                        # Remove any previously indexed chunks for this page before upserting.
                        self.chroma_collection.delete(where={"source": title})
                    except Exception as delete_e:
                        logging.warning(f"Could not delete old chunks for {title} from ChromaDB: {delete_e}")
self.chroma_collection.upsert(
ids=ids,
embeddings=embeddings,
documents=chunks,
metadatas=metadatas
)
except Exception as chroma_e:
logging.error(f"Error adding/updating chunks for {title} in ChromaDB: {chroma_e}")
return False
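            # Opportunistically crawl a few linked pages in background threads;
            # _process_page skips titles that are fresh in the cache or already in flight.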
if links:
for link in links[:10]:
threading.Thread(target=self._process_page, args=(link,), daemon=True).start()
return True
except Exception as e:
logging.error(f"Caught unexpected error during processing of {title}: {e}")
return False
finally:
with self.data_lock:
if title in self.processing_pages:
self.processing_pages.remove(title)
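    # Split page text into roughly CHUNK_SIZE_TOKENS-token chunks on sentence boundaries,
    # carrying the last CHUNK_OVERLAP sentences into the next chunk for continuity.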
def _chunk_text(self, text, title):
sentences = re.split(r'(?<=[.!?])\s+', text)
chunks, current_sentences = [], []
current_chunk_tokens = 0
if len(sentences) < 2 and len(text.split()) < 50:
return []
for i, sentence in enumerate(sentences):
sentence_tokens = len(self.tokenizer.encode(sentence, add_special_tokens=False))
new_token_count = current_chunk_tokens + sentence_tokens + (1 if current_chunk_tokens > 0 else 0)
if new_token_count > CHUNK_SIZE_TOKENS and current_sentences:
chunk_text = " ".join(current_sentences).strip()
if chunk_text:
chunks.append(chunk_text)
overlap_sentences = current_sentences[-CHUNK_OVERLAP:] if len(current_sentences) > CHUNK_OVERLAP else current_sentences
current_sentences = overlap_sentences
current_chunk_tokens = sum(len(self.tokenizer.encode(s, add_special_tokens=False)) for s in current_sentences)
current_sentences.append(sentence)
current_chunk_tokens += sentence_tokens + (1 if current_chunk_tokens > 0 else 0)
if current_sentences:
chunk_text = " ".join(current_sentences).strip()
if chunk_text:
chunks.append(chunk_text)
return [chunk for chunk in chunks if len(chunk.split()) > 20]
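    # Rewrite short follow-up queries (e.g. "and what about his crew?") into standalone
    # questions, using the recent conversation history for context.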
def _interpret_query(self, query):
        # Reformulate only short follow-up style queries, and only once the knowledge base is ready.
        if not self.initial_processing_done.is_set() or (len(query.split()) > 3 and not query.lower().startswith(("and ", "what about "))):
            return query
try:
prompt = f"""
Based on this conversation history:
{self._format_history()}
The user asked a question: "{query}"
Please interpret this as a complete, standalone question about One Piece, incorporating context from the history if necessary. Ensure the reformulated question is clear and specific, even if the original query was vague or a follow-up.
Only provide the complete reformulated question and nothing else.
"""
            # return_full_text=False returns only the newly generated text, so the
            # reformulated question does not need to be split away from an echoed prompt.
            interpretation_response = self.generator(
                prompt,
                max_new_tokens=50,
                temperature=0.5,
                do_sample=True,
                repetition_penalty=1.1,
                return_full_text=False
            )[0]["generated_text"]
            # Keep the first non-empty line and drop stray quotation marks.
            lines = [line.strip() for line in interpretation_response.splitlines() if line.strip()]
            interpreted_query = lines[0].replace('"', '').strip() if lines else query
            return interpreted_query or query
except Exception as e:
logging.error(f"Error interpreting query '{query}': {e}. Using original query.")
return query
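    # Embed the (optionally reformulated) query, expand it with a few lore-specific
    # keywords, and return chunks above SIMILARITY_THRESHOLD together with their source pages.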
def _find_relevant_chunks(self, query):
if not self.initial_processing_done.is_set():
self.initial_processing_done.wait(timeout=5)
interpreted_query = self._interpret_query(query)
keywords_to_add = []
lower_query = interpreted_query.lower()
if "joy boy" in lower_query or "nika" in lower_query:
keywords_to_add.extend(["Hito Hito no Mi Model Nika", "Sun God Nika"])
if "blackbeard" in lower_query and "devil fruit" in lower_query:
keywords_to_add.extend(["multiple devil fruits", "Yami Yami no Mi", "Gura Gura no Mi"])
if "gorosei" in lower_query or "im" in lower_query:
keywords_to_add.extend(["Five Elders", "Empty Throne", "World Government"])
if "void century" in lower_query:
keywords_to_add.extend(["Poneglyphs", "Ancient Kingdom", "Ohara"])
if keywords_to_add:
interpreted_query_with_keywords = interpreted_query + " " + " ".join(keywords_to_add)
else:
interpreted_query_with_keywords = interpreted_query
try:
query_embedding = self.embedder.encode(interpreted_query_with_keywords).tolist()
results = self.chroma_collection.query(
query_embeddings=[query_embedding],
n_results=MAX_CONTEXT_CHUNKS,
include=["documents", "metadatas", "distances"]
)
except Exception as e:
logging.error(f"Error querying ChromaDB: {e}")
return [], []
chunks = []
sources = set()
if results and results["documents"]:
for i, doc in enumerate(results["documents"][0]):
distance = results["distances"][0][i]
similarity = 1 - distance
if similarity >= SIMILARITY_THRESHOLD:
chunks.append(doc)
sources.add(results["metadatas"][0][i]["source"])
return chunks, list(sources)
def _format_history(self):
if not self.conversation_history:
return "No recent conversation history."
history = "Recent conversation history:\n"
for i, (q, a) in enumerate(self.conversation_history):
history += f"Turn {i+1}:\nUser: {q}\nAssistant: {a}\n"
return history
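    # Main QA entry point: wait briefly for the knowledge base, retrieve relevant chunks,
    # build a grounded prompt, and fall back to general knowledge when nothing matches.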
def answer_question(self, question: str):
logging.info(f"Received question: '{question}'")
if not self.initial_processing_done.is_set():
logging.warning("Initial data processing not yet complete. Waiting up to 60s...")
if not self.initial_processing_done.wait(timeout=60):
logging.error("Initial data processing timed out. Cannot answer reliably.")
return "The knowledge base is still loading. Please try again in a few minutes."
else:
logging.info("Initial data processing finished while waiting.")
chunks, sources = self._find_relevant_chunks(question)
if not chunks:
fallback_prompt = f"""You are an expert on the One Piece manga and anime. The user asked: "{question}". However, no relevant specific information was found in your knowledge base. Provide a general, helpful answer based on your broad understanding of One Piece, or state that you don't have specific information on this topic. Do not invent facts.
IMPORTANT: Start immediately with your answer."""
            try:
                # return_full_text=False keeps the fallback prompt out of the generated answer.
                answer = self.generator(fallback_prompt, max_new_tokens=200, temperature=0.7, do_sample=True, return_full_text=False)[0]["generated_text"].strip()
self.conversation_history.append((question, answer))
return answer
except Exception as e:
logging.error(f"Error generating fallback response: {e}")
answer = "I couldn't find specific information about that in my knowledge base."
self.conversation_history.append((question, answer))
return answer
prompt = f"""You are an expert on the One Piece manga and anime. Answer the following question based *only* on the provided context and your knowledge of One Piece lore.
{self._format_history()}
Context information:
{chr(10).join(chunks)}
Question: {question}
Provide a detailed, accurate answer based on the context above. If the context doesn't contain enough information to fully answer, use your general One Piece knowledge but prioritize information from the context. Explain connections between characters and events clearly. Structure your answer logically.
IMPORTANT: Your answer must be directly useful and not include phrases like "based on the context" or "answer the question". Start immediately with your answer. Ensure your answer is cohesive and well-formatted.
"""
        try:
            # return_full_text=False returns only the model's continuation, so the answer
            # does not have to be carved out of an echoed prompt.
            answer = self.generator(prompt, return_full_text=False)[0]["generated_text"].strip()
            # Defensive cleanup: drop an echoed instruction header line and any model-written source list.
            answer = re.sub(r'^(?:IMPORTANT:|Based on this conversation history:|Context information:|Question:).*?\n+', '', answer, flags=re.IGNORECASE).strip()
            answer = re.sub(r'\s*Sources:\s*.*$', '', answer, flags=re.DOTALL)
except Exception as e:
logging.error(f"Error generating response from LLM: {e}")
answer = "Sorry, I encountered an error while generating the response."
self.conversation_history.append((question, answer))
if sources:
        clean_sources = [s.replace('_', ' ') for s in sources][:5]
        sources_str = ", ".join(clean_sources)
return f"{answer}\n\nSources: {sources_str}"
return answer
# --- FastAPI Application Setup ---
app = FastAPI()
class QuestionRequest(BaseModel):
question: str
# --- Global chatbot instance with error handling ---
try:
chatbot = OnePieceChatbot()
logging.info("OnePieceChatbot instance created.")
except Exception as e:
chatbot = None
logging.critical(f"Failed to initialize OnePieceChatbot: {e}")
@app.get("/")
async def read_root():
return {"message": "One Piece Chatbot API is running. Send a POST request to /ask with your question."}
@app.get("/health")
async def health_check():
    if chatbot is None or getattr(chatbot, "generator", None) is None:
        # Use JSONResponse so the HTTP status code is actually set to 500.
        return JSONResponse(status_code=500, content={"status": "error", "message": "LLM model not loaded"})
    if not chatbot.initial_processing_done.is_set():
        return {"status": "warning", "message": "Initial data processing still in progress. Some answers may be limited."}
    try:
        count = chatbot.chroma_collection.count()
        if count == 0:
            return {"status": "warning", "message": "Knowledge base is empty after initialization. Data fetching might have failed."}
        return {"status": "ok", "message": "Chatbot is ready.", "knowledge_base_size": count}
    except Exception as e:
        logging.error(f"Health check failed during ChromaDB count: {e}")
        return {"status": "warning", "message": f"Health check encountered an issue: {e}"}
@app.post("/ask")
async def ask_question_endpoint(request: QuestionRequest):
if chatbot is None:
return {"answer": "Sorry, the chatbot failed to initialize."}, 500
question = request.question
if not question or not question.strip():
return {"answer": "Please provide a question."}
try:
answer = await run_in_threadpool(chatbot.answer_question, question)
return {"answer": answer}
except Exception as e:
logging.error(f"Error processing question '{question}': {e}")
return {"answer": "Sorry, an internal error occurred while processing your question."}, 500
# To run locally:
# uvicorn main:app --host 0.0.0.0 --port 7860   (adjust "main" to this file's module name)
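# Example request once the server is running (default port 7860 assumed):
# curl -X POST http://localhost:7860/ask \
#      -H "Content-Type: application/json" \
#      -d '{"question": "Who is Joy Boy?"}'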