khalednabawi11 committed
Commit 87179d2 · verified · 1 parent: 6ac1aed

Create app/main.py

Files changed (1)
app/main.py +442 -0
app/main.py ADDED
@@ -0,0 +1,442 @@
+ # import torch
+ # from fastapi import FastAPI, Request, HTTPException, status
+ # import uvicorn
+ # from pydantic import BaseModel, Field
+ # from langchain.chains import RetrievalQA
+ # from langchain_huggingface import HuggingFacePipeline
+ # from langchain.vectorstores import Qdrant
+ # from langchain.embeddings import HuggingFaceEmbeddings
+ # from transformers import pipeline
+ # from qdrant_client import QdrantClient
+ # from llama_cpp import Llama
+ # from langchain_huggingface import HuggingFacePipeline
+ # from langdetect import detect
+ # from contextlib import asynccontextmanager
+ # import logging
+ # from langchain.callbacks.manager import CallbackManager
+ # from langchain.callbacks.base import BaseCallbackHandler
+ # import asyncio
+ # from contextlib import asynccontextmanager
+ # import logging
+ # from huggingface_hub import hf_hub_download
+ # from langchain.llms import LlamaCpp
+
+
+ # # === CONFIGURATION === #
+ # MODEL_NAME = "FreedomIntelligence/Apollo-7B"
+ # EMBEDDING_MODEL = "Omartificial-Intelligence-Space/GATE-AraBert-v1"
+ # QDRANT_URL = "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333"
+ # QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4"
+ # COLLECTION_NAME = "arabic_rag_collection"
+
+ # # === INIT APP === #
+ # # Add this line to enable debug logging
+ # logging.basicConfig(level=logging.DEBUG)
+
+ # app = FastAPI()
+
+ # # === LOAD MODEL === #
+ # # model, tokenizer = FastLanguageModel.from_pretrained(
+ # #     model_name=MODEL_NAME,
+ # #     max_seq_length=2048,
+ # #     dtype=torch.float16,
+ # #     load_in_4bit=True
+ # # )
+
+ # # from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # # tokenizer = AutoTokenizer.from_pretrained("FreedomIntelligence/Apollo-7B")
+ # # model = AutoModelForCausalLM.from_pretrained("FreedomIntelligence/Apollo-7B")
+
+
+ # # llm = Llama.from_pretrained(
+ # #     repo_id="FreedomIntelligence/Apollo-7B-GGUF",
+ # #     filename="Apollo-7B-q8_0.gguf",
+ # # )
+
+ # # model = Llama.from_pretrained(
+ # #     repo_id="FreedomIntelligence/Apollo-7B-GGUF",
+ # #     filename="Apollo-7B.Q4_K_S.gguf",  # Choose the correct quantization level
+ # #     n_ctx=1024,       # Adjust context length as per your use case
+ # #     n_threads=4,      # Adjust the number of threads based on your environment
+ # #     chat_format="llama-2"  # Or None depending on the model
+ # # )
+
+ # # # Define the HuggingFacePipeline to work with the model
+ # # llm_pipeline = pipeline(
+ # #     model=model,
+ # #     task="text-generation",
+ # #     max_new_tokens=1024,
+ # #     temperature=0.3
+ # # )
+
+
+ # model_path = hf_hub_download(
+ #     repo_id="FreedomIntelligence/Apollo-7B-GGUF",
+ #     filename="Apollo-7B.Q4_K_S.gguf",
+ #     local_dir="./models",
+ #     local_dir_use_symlinks=False
+ # )
+ # # https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/blob/main/Apollo-7B.Q4_K_S.gguf
+
+ # llm = LlamaCpp(
+ #     model_path=model_path,
+ #     temperature=0.3,
+ #     max_tokens=200,
+ #     n_ctx=1024,
+ #     top_p=0.9,
+ #     top_k=40,
+ #     n_threads=1,
+ #     n_batch=1,
+ #     low_vram=True,
+ #     f16_kv=True,
+ #     verbose=True
+ # )
+
+ # # Wrap it in HuggingFacePipeline
+ # # hf_llm = HuggingFacePipeline(pipeline=llm)
+
+ # # === EMBEDDINGS AND VECTOR STORE === #
+ # embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+ # qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+ # qdrant_vectorstore = Qdrant(
+ #     client=qdrant_client,
+ #     collection_name=COLLECTION_NAME,
+ #     embeddings=embedding,
+ # )
+
+ # retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})
+ # qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
+
+ # # llm_pipeline = pipeline(
+ # #     model=model,
+ # #     tokenizer=tokenizer,
+ # #     task="text-generation",
+ # #     max_new_tokens=1024,
+ # #     temperature=0.3,
+ # # )
+
+ # # llm = HuggingFacePipeline(pipeline=llm_pipeline)
+
+ # # # === EMBEDDINGS + VECTORSTORE === #
+ # # embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+ # # qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+ # # qdrant_vectorstore = Qdrant(
+ # #     client=qdrant_client,
+ # #     collection_name=COLLECTION_NAME,
+ # #     embeddings=embedding,
+ # # )
+
+ # # retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})
+ # # qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
+
+ # def generate_prompt(question):
+ #     lang = detect(question)
+ #     if lang == "ar":
+ #         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
+ # وتأكد من ان:
+ # - عدم تكرار أي نقطة أو عبارة أو كلمة
+ # - وضوح وسلاسة كل نقطة
+ # - تجنب الحشو والعبارات الزائدة
+
+ # السؤال: {question}
+ # الإجابة:
+ # """
+
+ #     else:
+ #         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
+ # Question: {question}
+ # Answer:"""
+
+
+
+ # # === API INPUT/OUTPUT === #
+ # class Query(BaseModel):
+ #     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
+
+
+ # # Setup logging
+ # logging.basicConfig(level=logging.DEBUG)
+ # logger = logging.getLogger(__name__)
+
+ # # Create startup and shutdown events
+ # @asynccontextmanager
+ # async def lifespan(app: FastAPI):
+ #     # Startup: Initialize QA chain and other resources
+ #     global qa_chain
+ #     try:
+ #         # ...existing qa_chain initialization code...
+ #         logger.info("Successfully initialized QA chain")
+ #         yield
+ #     except Exception as e:
+ #         logger.error(f"Failed to initialize QA chain: {e}")
+ #         raise
+ #     finally:
+ #         # Cleanup
+ #         if 'qa_chain' in globals():
+ #             del qa_chain
+ #         if 'qdrant_client' in globals():
+ #             await qdrant_client.close()
+ #         logger.info("Cleanup completed")
+
+ # # Update FastAPI initialization
+ # app = FastAPI(lifespan=lifespan)
+
+ # @app.get("/")
+ # async def root():
+ #     return {"message": "API is running!"}
+
+ # # the ask endpoint
+ # @app.post("/ask")
+ # async def ask(query: Query):
+ #     try:
+ #         logger.debug(f"Processing question: {query.question}")
+ #         prompt = generate_prompt(query.question)
+
+ #         # Create callback with longer timeout
+ #         timeout_callback = TimeoutCallback(timeout_seconds=60)
+
+ #         # Add timeout to prevent hanging
+ #         import asyncio
+
+ #         try:
+ #             answer = await asyncio.wait_for(
+ #                 qa_chain.run(prompt, callbacks=[timeout_callback]),
+ #                 timeout=60  # seconds
+ #             )
+ #         except asyncio.TimeoutError:
+ #             raise TimeoutError("LLM chain processing timed out")
+
+ #         logger.debug(f"Raw answer from qa_chain: {answer} ({type(answer)})")
+
+ #         if not answer:
+ #             raise ValueError("Empty answer returned from qa_chain")
+
+ #         if not isinstance(answer, str):
+ #             answer = str(answer)  # Fallback to string for serialization
+
+ #         return {
+ #             "status": "success",
+ #             "response": answer,
+ #             "language": detect(query.question)
+ #         }
+
+ #     except TimeoutError as te:
+ #         logger.error("Request timed out", exc_info=True)
+ #         raise HTTPException(
+ #             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+ #             detail={
+ #                 "status": "error",
+ #                 "message": "Request timed out",
+ #                 "error": str(te)
+ #             }
+ #         )
+ #     except Exception as e:
+ #         logger.error(f"Error processing request: {str(e)}", exc_info=True)
+ #         raise HTTPException(
+ #             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ #             detail={
+ #                 "status": "error",
+ #                 "message": "Failed to process question",
+ #                 "error": str(e)
+ #             }
+ #         )
+
+ # # Add TimeoutCallback
+ # class TimeoutCallback(BaseCallbackHandler):
+ #     def __init__(self, timeout_seconds: int = 60):  # Increased default timeout
+ #         super().__init__()
+ #         self.timeout_seconds = timeout_seconds
+ #         self.start_time = None
+
+ #     async def on_llm_start(self, *args, **kwargs):
+ #         self.start_time = asyncio.get_event_loop().time()
+
+ #     async def on_llm_new_token(self, *args, **kwargs):
+ #         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
+ #             raise TimeoutError("LLM processing timeout")
+
+ # if __name__ == "__main__":
+ #     import signal
+
+ #     def handle_exit(signum, frame):
+ #         print("Shutting down gracefully...")
+ #         exit(0)
+
+ #     signal.signal(signal.SIGINT, handle_exit)
+ #     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
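+ # ---------------------------------------------------------------------------
+ # Active implementation. The commented-out block above preserves earlier
+ # drafts (transformers / llama_cpp experiments) and is never executed.
+ # ---------------------------------------------------------------------------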
+ import torch
+ import asyncio
+ import logging
+ import signal
+ import uvicorn
+
+ from fastapi import FastAPI, Request, HTTPException, status
+ from pydantic import BaseModel, Field
+ from langdetect import detect
+ from langchain.chains import RetrievalQA
+ from langchain.vectorstores import Qdrant
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.llms import LlamaCpp
+ from langchain.callbacks.base import BaseCallbackHandler
+ from qdrant_client import QdrantClient
+ from huggingface_hub import hf_hub_download
+ from contextlib import asynccontextmanager
+
+ # === CONFIGURATION === #
+ MODEL_NAME = "FreedomIntelligence/Apollo-7B"
+ MODEL_FILE = "Apollo-7B.Q4_K_S.gguf"
+ EMBEDDING_MODEL = "Omartificial-Intelligence-Space/GATE-AraBert-v1"
+ QDRANT_URL = "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333"
+ QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4"
+ COLLECTION_NAME = "arabic_rag_collection"
+
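+ # NOTE: hardcoded credentials are fine for a quick demo, but in a real deployment the
+ # Qdrant URL and API key would normally come from the environment. A minimal sketch
+ # (assumes QDRANT_URL / QDRANT_API_KEY are set as secrets in the Space or shell):
+ #
+ #     import os
+ #     QDRANT_URL = os.environ.get("QDRANT_URL", QDRANT_URL)
+ #     QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", QDRANT_API_KEY)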
+ # === LOGGING === #
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+ # === SCHEMAS AND CALLBACKS === #
+ # The FastAPI app itself is created further down, after the lifespan handler is
+ # defined, so the routes and the startup/shutdown logic share one instance.
+ class Query(BaseModel):
+     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
+
+ class TimeoutCallback(BaseCallbackHandler):
+     def __init__(self, timeout_seconds: int = 60):
+         super().__init__()
+         self.timeout_seconds = timeout_seconds
+         self.start_time = None
+
+     async def on_llm_start(self, *args, **kwargs):
+         self.start_time = asyncio.get_event_loop().time()
+
+     async def on_llm_new_token(self, *args, **kwargs):
+         # Guard against the start callback not having fired yet.
+         if self.start_time is None:
+             self.start_time = asyncio.get_event_loop().time()
+         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
+             raise TimeoutError("LLM processing timeout")
+
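+ # Note: on_llm_new_token only fires when the LLM streams tokens through LangChain's
+ # callback system (e.g. LlamaCpp(..., streaming=True)). With streaming disabled, the
+ # asyncio.wait_for() timeout around the chain call in /ask is the effective guard.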
+
+ # === LIFESPAN STARTUP/SHUTDOWN === #
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     global qa_chain, qdrant_client
+
+     try:
+         logger.info("Initializing model and vector store...")
+
+         # Download the quantized GGUF model and load it with llama.cpp
+         model_path = hf_hub_download(
+             repo_id="FreedomIntelligence/Apollo-7B-GGUF",
+             filename=MODEL_FILE,
+             local_dir="./models",
+             local_dir_use_symlinks=False
+         )
+         llm = LlamaCpp(
+             model_path=model_path,
+             temperature=0.3,
+             max_tokens=200,
+             n_ctx=1024,
+             top_p=0.9,
+             top_k=40,
+             n_threads=1,
+             n_batch=1,
+             low_vram=True,
+             f16_kv=True,
+             verbose=True
+         )
+
+         # Set up embeddings and the Qdrant-backed retriever
+         embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+         qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+         qdrant_vectorstore = Qdrant(
+             client=qdrant_client,
+             collection_name=COLLECTION_NAME,
+             embeddings=embedding,
+         )
+
+         retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 3})
+         qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
+
+         logger.info("Model and vector store initialized successfully.")
+         yield
+
+     except Exception as e:
+         logger.error(f"Initialization error: {e}")
+         raise
+
+     finally:
+         if 'qdrant_client' in globals():
+             # QdrantClient.close() is synchronous, so it is not awaited.
+             qdrant_client.close()
+         logger.info("Shutdown complete.")
+
+ app = FastAPI(lifespan=lifespan)
+
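+ # FastAPI runs the code before `yield` once at startup, so the model and retriever are
+ # ready before any request is served, and runs the `finally` block once at shutdown.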
+ # === PROMPT GENERATOR === #
+ def generate_prompt(question: str) -> str:
+     lang = detect(question)
+     if lang == "ar":
+         return (
+             "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.\n"
+             "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
+             "- وضوح وسلاسة كل نقطة\n"
+             "- تجنب الحشو والعبارات الزائدة\n"
+             f"\nالسؤال: {question}\nالإجابة:"
+         )
+     else:
+         return (
+             "Answer the following medical question in clear English with a detailed, non-redundant response. "
+             "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
+             "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
+             "in concise and distinct bullet points:\n"
+             f"Question: {question}\nAnswer:"
+         )
+
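+ # For illustration: a question such as "What causes hair loss?" would typically be
+ # detected as "en" and get the English template ending in
+ #     "Question: What causes hair loss?\nAnswer:"
+ # while an Arabic question gets the Arabic template with the same structure.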
+ # === ROUTES === #
+ @app.get("/")
+ async def root():
+     return {"message": "Medical QA API is running!"}
+
+ @app.post("/ask")
+ async def ask(query: Query):
+     try:
+         logger.debug(f"Received question: {query.question}")
+         prompt = generate_prompt(query.question)
+         timeout_callback = TimeoutCallback(timeout_seconds=60)
+
+         # qa_chain.run() is a blocking call, so run it in a worker thread and bound it
+         # with an overall timeout instead of awaiting the synchronous call directly.
+         answer = await asyncio.wait_for(
+             asyncio.to_thread(qa_chain.run, prompt, callbacks=[timeout_callback]),
+             timeout=60
+         )
+
+         if not answer:
+             raise ValueError("Empty answer returned from model")
+
+         return {
+             "status": "success",
+             "response": str(answer),
+             "language": detect(query.question)
+         }
+
+     except (TimeoutError, asyncio.TimeoutError) as te:
+         logger.error("Request timed out", exc_info=True)
+         raise HTTPException(
+             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+             detail={"status": "error", "message": "Request timed out", "error": str(te)}
+         )
+
+     except Exception as e:
+         logger.error(f"Unexpected error: {e}", exc_info=True)
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail={"status": "error", "message": "Internal server error", "error": str(e)}
+         )
+
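+ # Example client call (a sketch; assumes the server is running locally on port 8000
+ # and the `requests` package is installed):
+ #
+ #     import requests
+ #     resp = requests.post(
+ #         "http://localhost:8000/ask",
+ #         json={"question": "ما هي اسباب تساقط الشعر ؟"},
+ #     )
+ #     print(resp.json())  # {"status": "success", "response": "...", "language": "ar"}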
+ # === ENTRYPOINT === #
+ if __name__ == "__main__":
+     def handle_exit(signum, frame):
+         print("Shutting down gracefully...")
+         exit(0)
+
+     # uvicorn is already imported at the top of the module.
+     signal.signal(signal.SIGINT, handle_exit)
+     uvicorn.run(app, host="0.0.0.0", port=8000)
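+ # Alternatively (for example in a Dockerfile or Space start command) the app can be
+ # served with the uvicorn CLI; the module path below assumes this file is app/main.py
+ # in an importable `app` package:
+ #
+ #     uvicorn app.main:app --host 0.0.0.0 --port 8000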