import asyncio
import logging
import os
import signal
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from langdetect import detect
from langchain.chains import RetrievalQA
from langchain.vectorstores import Qdrant
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.callbacks.base import BaseCallbackHandler
from qdrant_client import QdrantClient
from huggingface_hub import hf_hub_download
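# Dependencies (sketch; package names inferred from the imports above, pin versions as needed):
#   pip install fastapi uvicorn pydantic langdetect langchain langchain-community \
#       llama-cpp-python sentence-transformers qdrant-client huggingface_hub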

# === CONFIGURATION === #
# Previous (larger) model:
# REPO_ID = "FreedomIntelligence/Apollo-7B-GGUF"
# MODEL_NAME = "FreedomIntelligence/Apollo-7B"
# MODEL_FILE = "Apollo-7B.Q4_K_S.gguf"
REPO_ID = "RichardErkhov/FreedomIntelligence_-_Apollo-2B-gguf"
MODEL_NAME = "FreedomIntelligence/Apollo-2B"
MODEL_FILE = "Apollo-2B.IQ4_XS.gguf"
EMBEDDING_MODEL = "Omartificial-Intelligence-Space/GATE-AraBert-v1"
COLLECTION_NAME = "arabic_rag_collection"
QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
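# The Qdrant settings above fall back to hard-coded defaults; in a deployment they can be
# supplied via environment variables instead (illustrative placeholder values):
#   export QDRANT_URL="https://<your-cluster>.cloud.qdrant.io:6333"
#   export QDRANT_API_KEY="<your-api-key>"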

# === LOGGING === #
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


# === REQUEST SCHEMA === #
class Query(BaseModel):
    question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)

class TimeoutCallback(BaseCallbackHandler):
    def __init__(self, timeout_seconds: int = 60):
        super().__init__()
        self.timeout_seconds = timeout_seconds
        self.start_time = None

    async def on_llm_start(self, *args, **kwargs):
        self.start_time = asyncio.get_event_loop().time()

    async def on_llm_new_token(self, *args, **kwargs):
        if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
            raise TimeoutError("LLM processing timeout")
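# Usage sketch (assumption: the chain invocation accepts LangChain callbacks):
#   qa_chain.run(prompt, callbacks=[TimeoutCallback(timeout_seconds=60)])
# The /ask route below instead bounds the synchronous call with asyncio.wait_for.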

# === LIFESPAN STARTUP/SHUTDOWN === #
@asynccontextmanager
async def lifespan(app: FastAPI):
    global qa_chain, qdrant_client
    try:
        logger.info("Initializing model and vector store...")

        # Download the quantized GGUF model weights
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILE,
            local_dir="./models",
            local_dir_use_symlinks=False
        )

        # Load the model with llama.cpp (CPU-friendly, low-memory settings)
        llm = LlamaCpp(
            model_path=model_path,
            temperature=0.3,
            max_tokens=200,
            n_ctx=1024,
            top_p=0.9,
            top_k=40,
            n_threads=1,
            n_batch=1,
            low_vram=True,
            f16_kv=True,
            verbose=True
        )

        # Set up embeddings and the Qdrant vector store
        embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
        qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
        qdrant_vectorstore = Qdrant(
            client=qdrant_client,
            collection_name=COLLECTION_NAME,
            embeddings=embedding,
        )
        retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})

        # Build the retrieval-augmented QA chain
        qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")

        logger.info("Model and vector store initialized successfully.")
        yield
    except Exception as e:
        logger.error(f"Initialization error: {e}")
        raise
    finally:
        if 'qdrant_client' in globals():
            qdrant_client.close()
        logger.info("Shutdown complete.")

app = FastAPI(lifespan=lifespan)


# === PROMPT GENERATOR === #
def generate_prompt(question: str) -> str:
    lang = detect(question)
    if lang == "ar":
        return (
            "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.\n"
            "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
            "- وضوح وسلاسة كل نقطة\n"
            "- تجنب الحشو والعبارات الزائدة\n"
            f"\nالسؤال: {question}\nالإجابة:"
        )
    else:
        return (
            "Answer the following medical question in clear English with a detailed, non-redundant response. "
            "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
            "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
            "in concise and distinct bullet points:\n"
            f"Question: {question}\nAnswer:"
        )
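# Illustrative behaviour: generate_prompt("What causes hair loss?") detects English and
# returns the English template ending in "Question: ...\nAnswer:", while an Arabic
# question yields the Arabic template ending in "الإجابة:".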

# === ROUTES === #
@app.get("/")
async def root():
    return {"message": "Medical QA API is running!"}


@app.post("/ask")
async def ask(query: Query):
    try:
        logger.debug(f"Received question: {query.question}")
        prompt = generate_prompt(query.question)

        # qa_chain.run is synchronous, so run it in a thread pool and bound it
        # with asyncio.wait_for rather than relying on the TimeoutCallback.
        loop = asyncio.get_running_loop()
        answer = await asyncio.wait_for(
            loop.run_in_executor(None, qa_chain.run, prompt),
            timeout=360
        )

        if not answer:
            raise ValueError("Empty answer returned from model")

        # Strip any prompt echo: keep only the text after the answer marker.
        if 'Answer:' in answer:
            response_text = answer.split('Answer:')[-1].strip()
        elif 'الإجابة:' in answer:
            response_text = answer.split('الإجابة:')[-1].strip()
        else:
            response_text = answer.strip()

        return {
            "status": "success",
            "response": response_text,
            "language": detect(query.question)
        }
    except (TimeoutError, asyncio.TimeoutError) as te:
        logger.error("Request timed out", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail={"status": "error", "message": "Request timed out", "error": str(te)}
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"status": "error", "message": "Internal server error", "error": str(e)}
        )
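# Example request (sketch), with the server running locally on port 8000:
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What are the causes of hair loss?"}'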

# === ENTRYPOINT === #
if __name__ == "__main__":
    def handle_exit(signum, frame):
        print("Shutting down gracefully...")
        exit(0)

    signal.signal(signal.SIGINT, handle_exit)
    uvicorn.run(app, host="0.0.0.0", port=8000)