import asyncio
import logging
from dotenv import load_dotenv
from typing import Literal
from jinja2 import Environment, TemplateNotFound
import nltk
import warnings
import os
from fastapi import Depends, FastAPI, BackgroundTasks, HTTPException, Request, Response
from fastapi.staticfiles import StaticFiles
import api.solutions
from dependencies import get_llm_router, get_prompt_templates, init_dependencies
import api.docs
import api.requirements
from api.docs import docx_to_txt
from schemas import *
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from litellm.router import Router

# NOTE(review): several imported names (asyncio, Literal, BackgroundTasks,
# HTTPException, Request, FileResponse, StreamingResponse, Router,
# get_llm_router, docx_to_txt) are not referenced in this module. They are
# kept unchanged here; confirm nothing relies on them before pruning.

# Load .env first so that DEBUG_LOG (read just below) can come from the file.
load_dotenv()

logging.basicConfig(
    level=logging.DEBUG if os.environ.get("DEBUG_LOG", "0") == "1" else logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Initialize global dependencies
init_dependencies()

# Download required packages for NLTK (runs on every import of this module).
nltk.download('stopwords')
nltk.download('punkt_tab')
nltk.download('wordnet')

# Silences every warning category for the whole process.
warnings.filterwarnings("ignore")

app = FastAPI(title="Requirements Extractor", docs_url="/apidocs")

# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# very permissive; the CORS spec forbids a literal "*" origin with
# credentials. Confirm whether an explicit origin list is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

app.include_router(api.docs.router, prefix="/docs")
app.include_router(api.requirements.router, prefix="/requirements")
app.include_router(api.solutions.router, prefix="/solutions")


# INTERNAL ROUTE TO RETRIEVE PROMPT TEMPLATES FOR PRIVATE COMPUTE
# NOTE(review): described as internal, yet include_in_schema=True publishes it
# in the OpenAPI docs — confirm this is intended.
@app.get("/prompt/{task}", include_in_schema=True)
async def retrieve_prompt(task: str, prompt_env: Environment = Depends(get_prompt_templates)):
    """Retrieves a prompt for client-side private inference.

    Looks up the raw, unrendered source of ``private/{task}.txt`` through the
    injected Jinja environment's loader.

    Args:
        task: Name of the private task; selects the template file to read.
        prompt_env: Jinja environment supplied by ``get_prompt_templates``.

    Returns:
        The template source string (serialized by FastAPI as the response
        body), or an empty response with status 404 when no such template
        exists.
    """
    logging.debug("Retrieving template for on device private task %s.", task)
    try:
        # get_source returns (source, filename, uptodate); only the source is used.
        prompt, _filename, _uptodate = prompt_env.loader.get_source(
            prompt_env, f"private/{task}.txt")
    except TemplateNotFound:
        return Response(content="", status_code=404)
    return prompt


# Mounted last: routes are matched in registration order, so the API routes
# above take precedence over this catch-all static handler at "/".
app.mount("/", StaticFiles(directory="static", html=True), name="static")