# src/query_service/api.py
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from src.retrieval_handler.retriever import RetrievalHandler
from src.llm_integrator.llm import LLMIntegrator
from src.embedding_generator.embedder import EmbeddingGenerator
from src.vector_store_manager.chroma_manager import ChromaManager
import logging
from typing import Literal, Optional, Dict, Any, List
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
import shutil
import uuid
logger = logging.getLogger(__name__)
# Initialize core components (in a larger app these would be dependency injected;
# a commented sketch using FastAPI's Depends follows the CORS section below).
# For simplicity in this example, we initialize them globally.
embedding_generator: Optional[EmbeddingGenerator] = None
vector_store_manager: Optional[ChromaManager] = None
retrieval_handler: Optional[RetrievalHandler] = None
llm_integrator: Optional[LLMIntegrator] = None
try:
embedding_generator = EmbeddingGenerator()
vector_store_manager = ChromaManager(embedding_generator)
retrieval_handler = RetrievalHandler(embedding_generator, vector_store_manager)
llm_integrator = LLMIntegrator()
logger.info("Initialized core RAG components.")
except Exception as e:
logger.critical(f"Failed to initialize core RAG components: {e}")
    # Depending on severity you could re-raise here and abort startup; instead the
    # components are left as None and the endpoints below return a 500 error when
    # they are missing, which is usually preferable for a production API.
app = FastAPI(
title="Insight AI API",
description="API for querying financial information.",
version="1.0.0"
)
# --- CORS Middleware ---
# Add CORSMiddleware to allow cross-origin requests from your frontend.
# For development, you can allow all origins (*).
# For production, you should restrict this to your frontend's specific origin(s).
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins. Change this to your frontend's URL in production.
allow_credentials=True,
allow_methods=["*"], # Allows all methods (GET, POST, OPTIONS, etc.)
allow_headers=["*"], # Allows all headers
)
# -----------------------
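# --- Optional: dependency injection sketch ---
# A minimal sketch of the dependency injection mentioned near the top, using
# FastAPI's Depends plus lru_cache to keep single instances. Commented out
# because the endpoints below use the module-level components; it mirrors the
# constructor calls above rather than introducing any new API.
# from functools import lru_cache
# from fastapi import Depends
#
# @lru_cache
# def get_retrieval_handler() -> RetrievalHandler:
#     embedder = EmbeddingGenerator()
#     return RetrievalHandler(embedder, ChromaManager(embedder))
#
# An endpoint would then declare a parameter such as:
#     handler: RetrievalHandler = Depends(get_retrieval_handler)
# ---------------------------------------------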
class Message(BaseModel):
role: Literal['user', 'assistant', 'system']
content: str
class QueryRequest(BaseModel):
query: str
chat_history: Optional[List[Message]] = []
filters: Optional[Dict[str, Any]] = None # Allow passing metadata filters
# Define interfaces matching the backend response structure
class SourceMetadata(BaseModel):
source: Optional[str] = None
ruling_date: Optional[str] = None
# Add other expected metadata fields here
# Example: topic: Optional[str] = None
class RetrievedSource(BaseModel):
content_snippet: str
metadata: Optional[SourceMetadata] = None
class QueryResponse(BaseModel):
answer: str
retrieved_sources: Optional[List[RetrievedSource]] = None
class TitleResponse(BaseModel):
title: str
class TitleRequest(BaseModel):
query: str
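# Example /query payload and response shapes (illustrative values only):
#   request:  {"query": "What did the 2021 ruling decide?",
#              "chat_history": [{"role": "user", "content": "Hello"}],
#              "filters": {"source": "ruling_123.pdf"}}
#   response: {"answer": "...",
#              "retrieved_sources": [{"content_snippet": "...",
#                                     "metadata": {"source": "ruling_123.pdf",
#                                                  "ruling_date": "2021-06-01"}}]}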
@app.post("/query", response_model=QueryResponse)
async def query_rulings(request: QueryRequest):
"""
Receives a user query and returns a generated answer based on retrieved rulings.
"""
logger.info(f"Received query: {request.query}")
if request.filters:
logger.info(f"Received filters: {request.filters}")
# Check if RAG components were initialized successfully
if not retrieval_handler or not llm_integrator:
logger.error("RAG components not initialized.")
raise HTTPException(status_code=500, detail="System components not ready.")
try:
        # 1. Retrieve relevant documents based on the query and filters.
        # Filters are forwarded here, but the current simple RetrievalHandler does
        # not apply them during invoke; extend RetrievalHandler.retrieve_documents
        # if metadata filtering should constrain the search itself.
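        # Illustrative filter shape, assuming ChromaManager forwards it to Chroma's
        # "where" syntax: {"ruling_date": {"$gte": "2020-01-01"}}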
retrieved_docs = retrieval_handler.retrieve_documents(request.query, filters=request.filters)
if not retrieved_docs:
logger.warning("No relevant documents retrieved for query.")
return QueryResponse(answer="Could not find relevant rulings for your query.")
# Convert chat_history to appropriate LangChain message types
chat_history = []
logger.debug(f"Raw chat history input: {request.chat_history}")
        for msg in request.chat_history or []:  # "or []" guards against an explicit null
logger.debug(f"Processing message - Role: {msg.role}, Content: {msg.content[:50]}...")
if msg.role == "user":
new_msg = HumanMessage(content=msg.content)
elif msg.role == "assistant":
new_msg = AIMessage(content=msg.content)
elif msg.role == "system":
new_msg = SystemMessage(content=msg.content)
else:
logger.warning(f"Invalid message role: {msg.role}. Skipping message.")
continue
logger.debug(f"Converted to: {type(new_msg).__name__}")
chat_history.append(new_msg)
logger.debug(f"Final chat history types: {[type(m).__name__ for m in chat_history]}")
# 2. Generate response using the LLM based on the query, retrieved context, and chat history
answer = llm_integrator.generate_response(request.query, retrieved_docs, chat_history)
# 3. Prepare retrieved source information for the response
retrieved_sources = []
for doc in retrieved_docs:
# Ensure the structure matches the RetrievedSource Pydantic model
source_metadata = SourceMetadata(**doc.metadata) if doc.metadata else None
retrieved_sources.append(RetrievedSource(
                content_snippet=(doc.page_content[:500] + "...") if len(doc.page_content) > 500 else doc.page_content,  # truncated snippet
                metadata=source_metadata  # only fields declared on SourceMetadata are kept
))
logger.info("Successfully processed query and generated response.")
return QueryResponse(answer=answer, retrieved_sources=retrieved_sources)
except Exception as e:
logger.error(f"An error occurred during query processing: {e}")
# Provide a more informative but secure error message to the user.
raise HTTPException(status_code=500, detail="An internal error occurred while processing your query.")
@app.post("/generate-title", response_model=TitleResponse)
async def generate_chat_title(request: TitleRequest):
try:
title = llm_integrator.generate_chat_title(request.query)
return {"title": title}
except Exception as e:
logger.error(f"Title generation error: {e}")
return {"title": "New Chat"}
@app.post("/upload-docs")
async def upload_docs(files: list[UploadFile] = File(...)):
"""
Upload new documents and trigger ingestion for them.
"""
    # Imported lazily so the ingestion stack is only loaded when this endpoint is used
    import os
    from src.ingestion_orchestrator.orchestrator import IngestionOrchestrator
# Create a unique folder in /tmp
upload_id = str(uuid.uuid4())
upload_dir = f"/tmp/ingest_{upload_id}"
os.makedirs(upload_dir, exist_ok=True)
saved_files = []
    for file in files:
        # Use only the base name, to avoid path traversal via crafted filenames
        file_path = os.path.join(upload_dir, os.path.basename(file.filename))
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        saved_files.append(file.filename)
# Run the ingestion pipeline for the uploaded folder
try:
orchestrator = IngestionOrchestrator()
orchestrator.run_ingestion_pipeline(docs_folder=upload_dir)
logger.info(f"Ingested files: {saved_files}")
return {"status": "success", "files": saved_files}
except Exception as e:
logger.error(f"Ingestion failed: {e}")
raise HTTPException(status_code=500, detail="Ingestion failed.")
# You can add more endpoints here, e.g., /health for health checks
# @app.get("/health")
# async def health_check():
# # Check connectivity to ChromaDB, LLM service, etc.
# # This requires adding health check methods to your ChromaManager and LLMIntegrator
# chroma_status = vector_store_manager.check_health() if vector_store_manager else "uninitialized"
# llm_status = llm_integrator.check_health() if llm_integrator else "uninitialized"
# return {"chroma": chroma_status, "llm": llm_status}