Spaces:
Build error
Build error
# Standard library
import concurrent.futures
import glob
import json
import os
import shutil
import subprocess
from typing import Any, Dict, List, Union

# Third-party
import redis
import torch
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# langchain
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from langchain.vectorstores import Chroma

# Local
from prompt_template_utils import get_prompt_template
from load_models import load_model
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY, SHOW_SOURCES
class Predict(BaseModel):
    """Request body for the prediction endpoint."""
    # The user's question/prompt text.
    prompt: str
class Delete(BaseModel):
    """Request body for the delete-source endpoint."""
    # Name of the uploaded source file to remove.
    filename: str
# Pick the best available torch device: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE_TYPE = "mps"
elif torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"

# Embedding model and persisted Chroma vector store, created once at import
# time and shared by all requests.
EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
RETRIEVER = DB.as_retriever()

# Redis connection used to stream LLM output to websocket clients via pub/sub.
# NOTE(review): host/port/db are hard-coded — consider reading from config.
redisClient = redis.Redis(host='localhost', port=6379, db=0)
class MyCustomSyncHandler(BaseCallbackHandler):
    """LangChain callback that relays generation progress over Redis pub/sub.

    Each new token is appended to an internal buffer and the whole buffer is
    published to the channel named by the first entry of the run's ``tags``
    (the websocket client id).  On end/error the literal string ``"end"`` is
    published so subscribers know the stream is finished.

    BUG FIX (file level): ``Union`` was used below without being imported,
    which raised ``NameError`` when this class body executed; the typing
    import now includes it.
    """

    def __init__(self):
        # Accumulated partial answer for the current generation.
        self.message = ''

    def on_llm_new_token(self, token: str, **kwargs) -> Any:
        # Publish the full message-so-far, not just the new token, so a late
        # subscriber still sees a complete answer.
        # NOTE(review): assumes callers always pass a non-empty `tags` list —
        # confirm against the websocket endpoint.
        self.message += token
        redisClient.publish(f'{kwargs["tags"][0]}', self.message)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        # Sentinel telling subscribers the stream completed normally.
        self.message = "end"
        redisClient.publish(f'{kwargs["tags"][0]}', self.message)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        # An error is signalled to subscribers exactly like a normal end.
        self.message = "end"
        redisClient.publish(f'{kwargs["tags"][0]}', self.message)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        # Chain completion also terminates the stream for subscribers.
        self.message = "end"
        redisClient.publish(f'{kwargs["tags"][0]}', self.message)
# Load the (streaming) LLM once at import time; all requests share it.
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)

# Retrieval-augmented QA chain over the Chroma store; rebuilt by
# run_ingest_route() after re-ingesting documents.
QA = RetrievalQA.from_chain_type(
    llm=LLM,
    chain_type="stuff",
    retriever=RETRIEVER,
    return_source_documents=SHOW_SOURCES,
    chain_type_kwargs={
        "prompt": prompt,
        "memory": memory
    },
)

# Root app serves the static frontend; the API app is mounted under /api.
app = FastAPI(title="homepage-app")
api_app = FastAPI(title="api app")
app.mount("/api", api_app, name="api")
app.mount("/", StaticFiles(directory="static",html = True), name="static")
def run_ingest_route():
    """Rebuild the vector store from the source documents.

    Deletes the existing persist directory, re-runs ``ingest.py`` as a
    subprocess, then reloads the Chroma store, retriever and QA chain
    globals so subsequent queries use the fresh index.

    Returns:
        dict: a success message under the ``"response"`` key.

    Raises:
        HTTPException: 500 when the persist directory is missing or cannot
            be removed (or on any unexpected error), 400 when the ingest
            script exits non-zero.
    """
    global DB
    global RETRIEVER
    global QA
    try:
        # A fresh ingest requires a clean persist directory.
        if os.path.exists(PERSIST_DIRECTORY):
            try:
                shutil.rmtree(PERSIST_DIRECTORY)
            except OSError as e:
                raise HTTPException(status_code=500, detail=f"Error: {e.filename} - {e.strerror}.")
        else:
            raise HTTPException(status_code=500, detail="The directory does not exist")

        run_langest_commands = ["python", "ingest.py"]
        if DEVICE_TYPE == "cpu":
            run_langest_commands.append("--device_type")
            run_langest_commands.append(DEVICE_TYPE)

        result = subprocess.run(run_langest_commands, capture_output=True)
        if result.returncode != 0:
            # BUG FIX: the original detail contained an unfilled "{}"
            # placeholder; surface the script's stderr instead.
            raise HTTPException(status_code=400, detail=f"Script execution failed: {result.stderr}")

        # Reload the vector store and rebuild the QA chain on top of it.
        DB = Chroma(
            persist_directory=PERSIST_DIRECTORY,
            embedding_function=EMBEDDINGS,
            client_settings=CHROMA_SETTINGS,
        )
        RETRIEVER = DB.as_retriever()
        QA = RetrievalQA.from_chain_type(
            llm=LLM,
            chain_type="stuff",
            retriever=RETRIEVER,
            return_source_documents=SHOW_SOURCES,
            chain_type_kwargs={
                "prompt": prompt,
                "memory": memory
            },
        )
        return {"response": "The training was successfully completed"}
    except HTTPException:
        # BUG FIX: the blanket handler below used to rewrap the deliberate
        # 400/500 responses raised above into a generic 500 — re-raise them.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
def get_files():
    """List the upload directory and every file currently inside it."""
    source_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    return {
        "directory": source_dir,
        "files": glob.glob(os.path.join(source_dir, '*')),
    }
def delete_source_route(data: Delete):
    """Delete one uploaded source document by name.

    Args:
        data: request body carrying the filename to remove.

    Raises:
        HTTPException: 400 when the file does not exist or cannot be removed.
    """
    # BUG FIX: the original built the path with a hard-coded placeholder
    # instead of the requested filename, so `data.filename` was ignored.
    # basename() also guards against path traversal in the client-supplied
    # name (e.g. "../../etc/passwd").
    filename = os.path.basename(data.filename)
    path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    file_to_delete = os.path.join(path_source_documents, filename)
    if os.path.exists(file_to_delete):
        try:
            os.remove(file_to_delete)
            print(f"{file_to_delete} has been deleted.")
            return {"message": f"{file_to_delete} has been deleted."}
        except OSError as e:
            # BUG FIX: detail was `print(...)`, which evaluates to None and
            # sent an empty error body; send the message itself.
            raise HTTPException(status_code=400, detail=f"error: {e}.")
    else:
        raise HTTPException(status_code=400, detail=f"The file {file_to_delete} does not exist.")
async def predict(data: Predict):
    """Answer a user prompt with the retrieval-QA chain, including sources."""
    global QA
    user_prompt = data.prompt
    # Guard clause: reject empty prompts up front.
    if not user_prompt:
        raise HTTPException(status_code=400, detail="Prompt Incorrect")
    res = QA(user_prompt)
    answer, docs = res["result"], res["source_documents"]
    prompt_response_dict = {
        "Prompt": user_prompt,
        "Answer": answer,
        # (source file name, excerpt) pairs for every retrieved document.
        "Sources": [
            (os.path.basename(str(document.metadata["source"])), str(document.page_content))
            for document in docs
        ],
    }
    return {"response": prompt_response_dict}
async def create_upload_file(file: UploadFile):
    """Validate and save an uploaded document into the source directory.

    Rejects files over 10 MB and content types outside the supported set,
    then streams the upload to disk.

    Raises:
        HTTPException: 400 for oversized files or unsupported content types.
    """
    # Measure the size by seeking to the end of the spooled temp file.
    file.file.seek(0, 2)
    file_size = file.file.tell()
    # Move the cursor back to the beginning before copying.
    await file.seek(0)
    if file_size > 10 * 1024 * 1024:
        # more than 10 MB
        raise HTTPException(status_code=400, detail="File too large")

    # Supported document types (plain text, markdown, CSV, Office, PDF, Python).
    allowed_types = {
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/csv",
        "application/msword",
        "application/pdf",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/x-python",
        "application/x-python-code",
    }
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    upload_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    # exist_ok avoids a race if two uploads create the directory concurrently.
    os.makedirs(upload_dir, exist_ok=True)
    # SECURITY FIX: take only the basename of the client-supplied filename so
    # a name like "../../etc/x" cannot escape the upload directory.
    safe_name = os.path.basename(file.filename)
    dest = os.path.join(upload_dir, safe_name)
    with open(dest, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"filename": safe_name}
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Stream QA answers for each prompt received on the websocket.

    The QA chain runs on a worker thread; ``MyCustomSyncHandler`` publishes
    partial answers to the Redis channel named after ``client_id``, which
    this coroutine consumes via pub/sub.
    """
    global QA
    await websocket.accept()
    try:
        while True:
            prompt = await websocket.receive_text()
            pubsub = redisClient.pubsub()
            pubsub.subscribe(f'{client_id}')
            with concurrent.futures.ThreadPoolExecutor() as executor:
                # BUG FIX: the original called QA(...) inline and passed its
                # *result* to submit() (a dict is not callable, and the chain
                # ran on the event-loop thread anyway).  Pass the callable
                # plus kwargs so the chain actually runs on the worker thread.
                # BUG FIX: tags must be a list — the callback reads tags[0],
                # which on a plain string would be its first character.
                executor.submit(
                    QA,
                    inputs=prompt,
                    return_only_outputs=True,
                    callbacks=[MyCustomSyncHandler()],
                    tags=[f'{client_id}'],
                    include_run_info=True,
                )
                i = 0
                # NOTE(review): pubsub.listen() blocks the event loop; consider
                # a thread or an async Redis client for production use.
                for item in pubsub.listen():
                    i += 1
                    print(item["type"] == "message")
                    if i > 10:
                        # BUG FIX: unsubscribe takes channel names, not a set.
                        pubsub.unsubscribe(f'{client_id}')
                    # await websocket.send_text(f'{item}')
    except WebSocketDisconnect:
        print('disconnect')
    except RuntimeError as error:
        print(error)