# Source page header: khalednabawi11 — commit e16ea6c "Update app/main.py" (verified).
# import torch
# from fastapi import FastAPI, Request, HTTPException, status
# import uvicorn
# from pydantic import BaseModel, Field
# from langchain.chains import RetrievalQA
# from langchain_huggingface import HuggingFacePipeline
# from langchain.vectorstores import Qdrant
# from langchain.embeddings import HuggingFaceEmbeddings
# from transformers import pipeline
# from qdrant_client import QdrantClient
# from llama_cpp import Llama
# from langchain_huggingface import HuggingFacePipeline
# from langdetect import detect
# from contextlib import asynccontextmanager
# import logging
# from langchain.callbacks.manager import CallbackManager
# from langchain.callbacks.base import BaseCallbackHandler
# import asyncio
# from contextlib import asynccontextmanager
# import logging
# from huggingface_hub import hf_hub_download
# from langchain.llms import LlamaCpp
# # === CONFIGURATION === #
# MODEL_NAME = "FreedomIntelligence/Apollo-7B"
# EMBEDDING_MODEL = "Omartificial-Intelligence-Space/GATE-AraBert-v1"
# QDRANT_URL = "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333"
# QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4"
# COLLECTION_NAME = "arabic_rag_collection"
# # === INIT APP === #
# # Add this line to enable debug logging
# logging.basicConfig(level=logging.DEBUG)
# app = FastAPI()
# # === LOAD MODEL === #
# # model, tokenizer = FastLanguageModel.from_pretrained(
# # model_name=MODEL_NAME,
# # max_seq_length=2048,
# # dtype=torch.float16,
# # load_in_4bit=True
# # )
# # from transformers import AutoTokenizer, AutoModelForCausalLM
# # tokenizer = AutoTokenizer.from_pretrained("FreedomIntelligence/Apollo-7B")
# # model = AutoModelForCausalLM.from_pretrained("FreedomIntelligence/Apollo-7B")
# # llm = Llama.from_pretrained(
# # repo_id="FreedomIntelligence/Apollo-7B-GGUF",
# # filename="Apollo-7B-q8_0.gguf",
# # )
# # model = Llama.from_pretrained(
# # repo_id="FreedomIntelligence/Apollo-7B-GGUF",
# # filename="Apollo-7B.Q4_K_S.gguf", # Choose the correct quantization level
# # n_ctx=1024, # Adjust context length as per your use case
# # n_threads=4, # Adjust the number of threads based on your environment
# # chat_format="llama-2" # Or None depending on the model
# # )
# # # Define the HuggingFacePipeline to work with the model
# # llm_pipeline = pipeline(
# # model=model,
# # task="text-generation",
# # max_new_tokens=1024,
# # temperature=0.3
# # )
# model_path = hf_hub_download(
# repo_id="FreedomIntelligence/Apollo-7B-GGUF",
# filename="Apollo-7B.Q4_K_S.gguf",
# local_dir="./models",
# local_dir_use_symlinks=False
# )
# # https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/blob/main/Apollo-7B.Q4_K_S.gguf
# llm = LlamaCpp(
# model_path=model_path,
# temperature=0.3,
# max_tokens=200,
# n_ctx=1024,
# top_p=0.9,
# top_k=40,
# n_threads=1,
# n_batch=1,
# low_vram=True,
# f16_kv=True,
# verbose=True
# )
# # Wrap it in HuggingFacePipeline
# # hf_llm = HuggingFacePipeline(pipeline=llm)
# # === EMBEDDINGS AND VECTOR STORE === #
# embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
# qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
# qdrant_vectorstore = Qdrant(
# client=qdrant_client,
# collection_name=COLLECTION_NAME,
# embeddings=embedding,
# )
# retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})
# qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
# # llm_pipeline = pipeline(
# # model=model,
# # tokenizer=tokenizer,
# # task="text-generation",
# # max_new_tokens=1024,
# # temperature=0.3,
# # )
# # llm = HuggingFacePipeline(pipeline=llm_pipeline)
# # # === EMBEDDINGS + VECTORSTORE === #
# # embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
# # qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
# # qdrant_vectorstore = Qdrant(
# # client=qdrant_client,
# # collection_name=COLLECTION_NAME,
# # embeddings=embedding,
# # )
# # retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})
# # qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
# def generate_prompt(question):
# lang = detect(question)
# if lang == "ar":
# return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
# وتأكد من ان:
# - عدم تكرار أي نقطة أو عبارة أو كلمة
# - وضوح وسلاسة كل نقطة
# - تجنب الحشو والعبارات الزائدة-
# السؤال: {question}
# الإجابة:
# """
# else:
# return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
# Question: {question}
# Answer:"""
# # === API INPUT/OUTPUT === #
# class Query(BaseModel):
# question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
# # Setup logging
# logging.basicConfig(level=logging.DEBUG)
# logger = logging.getLogger(__name__)
# # Create startup and shutdown events
# @asynccontextmanager
# async def lifespan(app: FastAPI):
# # Startup: Initialize QA chain and other resources
# global qa_chain
# try:
# # ...existing qa_chain initialization code...
# logger.info("Successfully initialized QA chain")
# yield
# except Exception as e:
# logger.error(f"Failed to initialize QA chain: {e}")
# raise
# finally:
# # Cleanup
# if 'qa_chain' in globals():
# del qa_chain
# if 'qdrant_client' in globals():
# await qdrant_client.close()
# logger.info("Cleanup completed")
# # Update FastAPI initialization
# app = FastAPI(lifespan=lifespan)
# @app.get("/")
# async def root():
# return {"message": "API is running!"}
# # the ask endpoint
# @app.post("/ask")
# async def ask(query: Query):
# try:
# logger.debug(f"Processing question: {query.question}")
# prompt = generate_prompt(query.question)
# # Create callback with longer timeout
# timeout_callback = TimeoutCallback(timeout_seconds=60)
# # Add timeout to prevent hanging
# import asyncio
# try:
# answer = await asyncio.wait_for(
# qa_chain.run(prompt, callbacks=[timeout_callback]),
# timeout=60 # seconds
# )
# except asyncio.TimeoutError:
# raise TimeoutError("LLM chain processing timed out")
# logger.debug(f"Raw answer from qa_chain: {answer} ({type(answer)})")
# if not answer:
# raise ValueError("Empty answer returned from qa_chain")
# if not isinstance(answer, str):
# answer = str(answer) # Fallback to string for serialization
# return {
# "status": "success",
# "response": answer,
# "language": detect(query.question)
# }
# except TimeoutError as te:
# logger.error("Request timed out", exc_info=True)
# raise HTTPException(
# status_code=status.HTTP_504_GATEWAY_TIMEOUT,
# detail={
# "status": "error",
# "message": "Request timed out",
# "error": str(te)
# }
# )
# except Exception as e:
# logger.error(f"Error processing request: {str(e)}", exc_info=True)
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# detail={
# "status": "error",
# "message": "Failed to process question",
# "error": str(e)
# }
# )
# # Add TimeoutCallback
# class TimeoutCallback(BaseCallbackHandler):
# def __init__(self, timeout_seconds: int = 60): # Increased default timeout
# super().__init__()
# self.timeout_seconds = timeout_seconds
# self.start_time = None
# async def on_llm_start(self, *args, **kwargs):
# self.start_time = asyncio.get_event_loop().time()
# async def on_llm_new_token(self, *args, **kwargs):
# if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
# raise TimeoutError("LLM processing timeout")
# if __name__ == "__main__":
# import signal
# def handle_exit(signum, frame):
# print("Shutting down gracefully...")
# exit(0)
# signal.signal(signal.SIGINT, handle_exit)
# uvicorn.run(app, host="0.0.0.0", port=8000)
import asyncio
import logging
import os
import signal
import time
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI, Request, HTTPException, status
from huggingface_hub import hf_hub_download
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.vectorstores import Qdrant
from langdetect import DetectorFactory
from langdetect import detect
from llama_cpp import Llama
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient

# === CONFIGURATION === #
# Alternative, larger model (Apollo-7B); swap these in for the 2B settings below.
# REPO_ID = "FreedomIntelligence/Apollo-7B-GGUF"
# MODEL_NAME = "FreedomIntelligence/Apollo-7B"
# MODEL_FILE = "Apollo-7B.Q4_K_S.gguf"

# Hugging Face Hub repository and file of the quantized GGUF model loaded at startup.
REPO_ID = "RichardErkhov/FreedomIntelligence_-_Apollo-2B-gguf"
MODEL_NAME = "FreedomIntelligence/Apollo-2B"  # not referenced elsewhere in this file
MODEL_FILE = "Apollo-2B.IQ4_XS.gguf"
EMBEDDING_MODEL = "Omartificial-Intelligence-Space/GATE-AraBert-v1"
COLLECTION_NAME = "arabic_rag_collection"

# Qdrant connection. The API key is a secret: it must be supplied through the
# environment and deliberately has no hard-coded fallback.
# NOTE(review): a key was previously committed to this file; it should be rotated.
QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# === LOGGING === #
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

if not QDRANT_API_KEY:
    logger.warning("QDRANT_API_KEY is not set; connecting to Qdrant without an API key.")

# langdetect is non-deterministic unless its factory is seeded; seeding makes the
# language detected for a given question reproducible across calls.
DetectorFactory.seed = 0

# === INITIALIZATION === #
class Query(BaseModel):
    """Request body accepted by ``POST /ask``: a single free-text question."""

    # Required field (``...``); must be at least three characters long.
    question: str = Field(
        ...,
        min_length=3,
        example="ما هي اسباب تساقط الشعر ؟",
    )
class TimeoutCallback(BaseCallbackHandler):
    """LangChain callback that aborts an LLM run after a wall-clock budget.

    The clock starts in ``on_llm_start``. Each ``on_llm_new_token`` event checks
    the elapsed time and raises ``TimeoutError`` once ``timeout_seconds`` have
    passed.

    Args:
        timeout_seconds: Maximum allowed run time, in seconds.
    """

    # By default LangChain logs and swallows exceptions raised inside callback
    # handlers; ``raise_error = True`` lets the TimeoutError abort the run.
    raise_error = True

    def __init__(self, timeout_seconds: int = 60):
        super().__init__()
        self.timeout_seconds = timeout_seconds
        self.start_time = None

    # These hooks are synchronous to match BaseCallbackHandler's interface
    # (coroutine hooks belong on AsyncCallbackHandler and are not awaited here).
    # time.monotonic() replaces asyncio.get_event_loop().time(), which raises in
    # worker threads that have no event loop (e.g. a chain run via run_in_executor).
    def on_llm_start(self, *args, **kwargs):
        self.start_time = time.monotonic()

    def on_llm_new_token(self, *args, **kwargs):
        if self.start_time is None:
            # No start event was observed; start timing at the first token
            # instead of failing on ``None`` arithmetic.
            self.start_time = time.monotonic()
            return
        if time.monotonic() - self.start_time > self.timeout_seconds:
            raise TimeoutError("LLM processing timeout")
# === LIFESPAN STARTUP/SHUTDOWN === #
def _close_qdrant_client() -> None:
    """Close the module-level ``qdrant_client`` if startup got far enough to create it.

    This is best-effort: a failure to close is logged as a warning, not raised,
    so it cannot mask the startup or shutdown outcome.
    """
    client = globals().get("qdrant_client")
    if client is None:
        return
    try:
        # QdrantClient.close() is a plain synchronous method that returns None.
        # Awaiting its result raises TypeError, so it is called directly.
        client.close()
    except Exception:
        logger.warning("Failed to close Qdrant client", exc_info=True)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG pipeline on startup and release the Qdrant client on shutdown.

    Startup steps:
    - Download the GGUF model file into ``./models``.
    - Wrap it in a LangChain ``LlamaCpp`` LLM.
    - Connect to Qdrant.
    - Build a "stuff" ``RetrievalQA`` chain over the top-3 retrieved documents.

    The chain and the client are published as the module-level globals
    ``qa_chain`` (read by the ``/ask`` handler) and ``qdrant_client``.

    Raises:
        Exception: any initialization failure is logged with its traceback and
            re-raised, so application startup fails.
    """
    global qa_chain, qdrant_client
    try:
        try:
            logger.info("Initializing model and vector store...")
            # Load LLM model
            model_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=MODEL_FILE,
                local_dir="./models",
                local_dir_use_symlinks=False
            )
            llm = LlamaCpp(
                model_path=model_path,
                temperature=0.3,
                max_tokens=200,
                n_ctx=1024,
                top_p=0.9,
                top_k=40,
                n_threads=1,
                n_batch=1,
                low_vram=True,
                f16_kv=True,
                verbose=True
            )
            # Setup embeddings and Qdrant
            embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
            qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
            qdrant_vectorstore = Qdrant(
                client=qdrant_client,
                collection_name=COLLECTION_NAME,
                embeddings=embedding,
            )
            retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})
            # combine_docs_chain = RefineDocumentsChain.from_llm(llm=llm)
            # qa_chain = RetrievalQA(combine_documents_chain=combine_docs_chain, retriever=retriever)
            qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
        except Exception as e:
            # Only startup failures land here. Errors raised while the app is
            # serving (at the ``yield``) are not mislabelled as initialization errors.
            logger.exception(f"Initialization error: {e}")
            raise
        logger.info("Model and vector store initialized successfully.")
        # The application serves requests while suspended here.
        yield
    finally:
        _close_qdrant_client()
        logger.info("Shutdown complete.")


app = FastAPI(lifespan=lifespan)
# === PROMPT GENERATOR === #
def generate_prompt(question: str, lang: "str | None" = None) -> str:
    """Build the instruction prompt passed to the QA chain for ``question``.

    Arabic questions get an Arabic instruction block; every other language
    gets the English one. Both prompts end with an answer cue
    ("الإجابة:" / "Answer:").

    Args:
        question: The user's medical question, embedded verbatim.
        lang: Optional language code already detected by the caller. When
            omitted, it is detected with ``langdetect``. If detection fails
            (e.g. the text has no detectable features), English is assumed.

    Returns:
        The full prompt string.
    """
    if lang is None:
        # Imported lazily so callers that supply ``lang`` never touch langdetect.
        from langdetect.lang_detect_exception import LangDetectException

        try:
            lang = detect(question)
        except LangDetectException:
            # Featureless input (digits, punctuation) previously surfaced as a
            # server error; fall back to the English prompt instead.
            lang = "en"

    if lang == "ar":
        return (
            "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
            # Lead-in ("and make sure that:") so the bullets below read as instructions.
            "وتأكد من ان:\n"
            "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
            "- وضوح وسلاسة كل نقطة\n"
            "- تجنب الحشو والعبارات الزائدة\n"
            f"\nالسؤال: {question}\nالإجابة:"
        )
    return (
        "Answer the following medical question in clear English with a detailed, non-redundant response. "
        "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
        "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
        "in concise and distinct bullet points:\n"
        f"Question: {question}\nAnswer:"
    )
# === ROUTES === #
@app.get("/")
async def root():
    """Liveness probe: report that the API process is up and serving."""
    banner = "Medical QA API is running!"
    return {"message": banner}
def _extract_answer(raw: str) -> str:
    """Keep only the text after the prompt's answer cue, if the model echoed it.

    The English cue ``Answer:`` is checked first, then the Arabic ``الإجابة:``.
    When a cue is present, the text after its last occurrence is returned.
    Otherwise the whole output is returned. Surrounding whitespace is stripped
    in both cases.
    """
    for marker in ("Answer:", "الإجابة:"):
        if marker in raw:
            return raw.rpartition(marker)[2].strip()
    return raw.strip()


def _detect_language(text: str) -> str:
    """Return langdetect's language code for ``text``, or ``"en"`` if detection fails."""
    from langdetect.lang_detect_exception import LangDetectException

    try:
        return detect(text)
    except LangDetectException:
        return "en"


@app.post("/ask")
async def ask(query: Query):
    """Answer a medical question with the retrieval-augmented QA chain.

    Returns:
        ``{"status": "success", "response": <answer text>, "language": <code>}``.

    Raises:
        HTTPException: 504 if the chain does not finish within 360 seconds;
            500 for any other failure, including an empty model answer.
    """
    try:
        logger.debug(f"Received question: {query.question}")
        # Detect once, before the slow LLM call, so a detection problem cannot
        # discard an already generated answer.
        language = _detect_language(query.question)
        prompt = generate_prompt(query.question)
        # qa_chain.run blocks, so it runs in the default thread pool to keep the
        # event loop responsive. NOTE: a timeout does not stop the worker
        # thread; generation continues in the background.
        loop = asyncio.get_running_loop()
        answer = await asyncio.wait_for(
            loop.run_in_executor(None, qa_chain.run, prompt),
            timeout=360,
        )
        if not answer:
            raise ValueError("Empty answer returned from model")
        if not isinstance(answer, str):
            answer = str(answer)  # the marker handling below needs a string
        return {
            "status": "success",
            "response": _extract_answer(answer),
            "language": language,
        }
    except (asyncio.TimeoutError, TimeoutError) as te:
        # Before Python 3.11, asyncio.TimeoutError (raised by wait_for) is not
        # the builtin TimeoutError, so both are listed to map timeouts to 504.
        logger.error("Request timed out", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail={"status": "error", "message": "Request timed out", "error": str(te)}
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"status": "error", "message": "Internal server error", "error": str(e)}
        )
# === ENTRYPOINT === #
if __name__ == "__main__":

    def handle_exit(signum, frame):
        """SIGINT handler: print a shutdown notice and exit with status 0."""
        print("Shutting down gracefully...")
        # SystemExit is what sys.exit() raises. Unlike the site-provided
        # ``exit`` helper, it is always available (e.g. under ``python -S``).
        raise SystemExit(0)

    signal.signal(signal.SIGINT, handle_exit)
    # NOTE(review): uvicorn installs its own SIGINT/SIGTERM handlers while
    # serving, so this handler may only apply outside uvicorn.run. Confirm
    # against the installed uvicorn version.
    # uvicorn is already imported at module level.
    uvicorn.run(app, host="0.0.0.0", port=8000)