|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from pydantic import BaseModel |
|
from typing import List, Optional, Dict, Any |
|
import uvicorn |
|
import logging |
|
import time |
|
import os |
|
import asyncio |
|
from contextlib import asynccontextmanager |
|
from pathlib import Path |
|
|
|
|
|
# Configure logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level state written by the lifespan hook and read by the handlers.
# Component dict published after successful initialization, otherwise None.
rag_system: Optional[Dict[str, Any]] = None
# True only while the lifespan hook is initializing the components.
system_loading: bool = False
# Message of the last initialization failure, or None if there was none.
system_load_error: Optional[str] = None
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the RAG components on startup and log on shutdown.

    Startup builds the vector store, retriever, prompt engine, SQL generator
    and data processor, publishes them through the module-level
    ``rag_system`` dict and then seeds the vector store. A failure during
    initialization is logged with its traceback and stored in
    ``system_load_error`` instead of being raised, so the application still
    starts and the endpoints can report why the system is unavailable.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control returns to the server while the application runs.
    """
    global rag_system, system_loading, system_load_error

    logger.info("Starting Text-to-SQL RAG API with CodeLlama for HF Spaces...")

    system_loading = True
    system_load_error = None

    try:
        # Imported inside the try block so that an import failure is captured
        # in system_load_error like any other initialization error.
        from rag_system import (
            DataProcessor,
            PromptEngine,
            SQLGenerator,
            SQLRetriever,
            VectorStore,
        )

        logger.info("Initializing RAG system components...")

        logger.info("Initializing vector store...")
        vector_store = VectorStore()

        logger.info("Initializing SQL retriever...")
        sql_retriever = SQLRetriever(vector_store)

        logger.info("Initializing prompt engine...")
        prompt_engine = PromptEngine()

        logger.info("Initializing SQL generator with CodeLlama...")
        sql_generator = SQLGenerator(sql_retriever, prompt_engine)

        logger.info("Initializing data processor...")
        data_processor = DataProcessor()

        rag_system = {
            "vector_store": vector_store,
            "sql_retriever": sql_retriever,
            "prompt_engine": prompt_engine,
            "sql_generator": sql_generator,
            "data_processor": data_processor,
        }

        logger.info("Loading sample data...")
        await load_or_create_sample_data(data_processor, vector_store)

        logger.info("All RAG system components initialized successfully!")
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error message
        # alone does not provide when a component fails to load.
        logger.exception(f"Failed to initialize RAG system: {e}")
        system_load_error = str(e)
    finally:
        system_loading = False

    yield

    logger.info("Shutting down Text-to-SQL RAG API...")
|
|
|
async def load_or_create_sample_data(data_processor, vector_store):
    """Load existing data or create sample dataset.

    Previously processed examples are added to the vector store when
    ``data_processor.load_processed_data()`` returns a non-empty result.
    When nothing is returned, or loading/adding raises, a sample dataset is
    created and added instead. The fallback runs at most once, so a failure
    while adding sample data is not followed by a second insertion attempt.

    Args:
        data_processor: Object providing ``load_processed_data()`` and
            ``create_sample_dataset()``.
        vector_store: Object providing ``add_examples(examples)``.

    Returns:
        None. Failures are logged and never raised to the caller.
    """
    try:
        examples = data_processor.load_processed_data()
        if examples:
            logger.info(f"Loaded {len(examples)} existing examples")
            vector_store.add_examples(examples)
            return
    except Exception as e:
        logger.warning(f"Could not load sample data: {e}")

    # Reached when no processed data exists or loading it failed.
    try:
        logger.info("Creating sample dataset...")
        sample_data = data_processor.create_sample_dataset()
        vector_store.add_examples(sample_data)
        logger.info(f"Added {len(sample_data)} sample examples to vector store")
    except Exception as e2:
        logger.exception(f"Failed to create sample data: {e2}")
|
|
|
|
|
# Application object; the lifespan hook performs RAG initialization on startup.
app = FastAPI(
    lifespan=lifespan,
    title="Text-to-SQL RAG API with CodeLlama",
    version="2.0.0",
    description="Advanced API for converting natural language questions to SQL queries using RAG and CodeLlama",
)
|
|
|
|
|
# Request body for a single text-to-SQL conversion.
class SQLRequest(BaseModel):
    # Natural language question to convert.
    question: str
    # Table headers passed to the SQL generator together with the question.
    table_headers: List[str]
|
|
|
# Result of one text-to-SQL conversion, echoing the request fields.
class SQLResponse(BaseModel):
    # Question and headers copied from the originating request.
    question: str
    table_headers: List[str]
    # Generated SQL text; the batch handler uses "" for failed queries.
    sql_query: str
    # Model identifier reported by the generator ("none" on batch errors).
    model_used: str
    # Processing time reported for this query.
    processing_time: float
    # Examples the generator reports having retrieved for this query.
    retrieved_examples: List[Dict[str, Any]]
    # Status string reported by the generator, or "error" on batch failures.
    status: str
|
|
|
# Request body for the batch endpoint: a list of single-query requests.
class BatchRequest(BaseModel):
    # Queries to process, each with its own question and table headers.
    queries: List[SQLRequest]
|
|
|
# Response body for the batch endpoint.
class BatchResponse(BaseModel):
    # One entry per submitted query, including entries for failed queries.
    results: List[SQLResponse]
    # Number of queries received in the request.
    total_queries: int
    # Number of queries whose generator status was "success".
    successful_queries: int
|
|
|
# Response body for the health check endpoint.
class HealthResponse(BaseModel):
    # "healthy" or "unhealthy", as computed by the health handler.
    status: str
    # Whether the component dict has been published.
    system_loaded: bool
    # Whether initialization is currently in progress.
    system_loading: bool
    # Message of the initialization failure, if one was recorded.
    system_error: Optional[str] = None
    # Information returned by the SQL generator's get_model_info(), if any.
    model_info: Optional[Dict[str, Any]] = None
    # time.time() value taken when the response was built.
    timestamp: float
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML interface.

    ``index.html`` is looked up in the current working directory first and
    then in the directory containing this module, so the page is still found
    when the server is started from a different directory.

    Returns:
        The contents of ``index.html``, or a small placeholder page when the
        file is found in neither location.
    """
    candidates = (
        Path("index.html"),
        Path(__file__).resolve().parent / "index.html",
    )
    for candidate in candidates:
        try:
            return HTMLResponse(content=candidate.read_text(encoding="utf-8"))
        except FileNotFoundError:
            continue

    return HTMLResponse(content="""
        <html>
            <body>
                <h1>Text-to-SQL RAG API with CodeLlama</h1>
                <p>Advanced SQL generation using RAG and CodeLlama models</p>
                <p>index.html not found. Please ensure the file exists in the same directory.</p>
            </body>
        </html>
        """)
|
|
|
@app.get("/api", response_model=dict)
async def api_info():
    """Describe the API: name, version, feature list and available endpoints."""
    feature_list = [
        "RAG-enhanced SQL generation",
        "CodeLlama as primary model",
        "Vector-based example retrieval",
        "Advanced prompt engineering",
    ]

    # Maps each route to a short "<METHOD> - <purpose>" description.
    endpoint_map = {
        "/": "GET - Web interface",
        "/api": "GET - API information",
        "/predict": "POST - Generate SQL from single question",
        "/batch": "POST - Generate SQL from multiple questions",
        "/health": "GET - Health check",
        "/docs": "GET - API documentation",
    }

    info = {
        "message": "Text-to-SQL RAG API with CodeLlama",
        "version": "2.0.0",
        "features": feature_list,
        "endpoints": endpoint_map,
    }
    return info
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report load state, any recorded load error and model information."""
    components = rag_system

    # Model details are best-effort: a failure here only produces a warning.
    details = None
    if components and "sql_generator" in components:
        try:
            details = components["sql_generator"].get_model_info()
        except Exception as e:
            logger.warning(f"Could not get model info: {e}")

    is_ready = bool(components) and not system_loading
    if is_ready:
        state = "healthy"
    else:
        state = "unhealthy"

    return HealthResponse(
        status=state,
        system_loaded=components is not None,
        system_loading=system_loading,
        system_error=system_load_error,
        model_info=details,
        timestamp=time.time(),
    )
|
|
|
@app.post("/predict", response_model=SQLResponse)
async def predict_sql(request: SQLRequest):
    """
    Generate SQL query from a natural language question using RAG and CodeLlama

    Args:
        request: SQLRequest containing question and table headers

    Returns:
        SQLResponse with generated SQL query and metadata

    Raises:
        HTTPException: 503 while the system is loading or when it failed to
            load; 500 when SQL generation raises.
    """
    if system_loading:
        raise HTTPException(status_code=503, detail="System is still loading, please try again in a few minutes")

    if rag_system is None:
        error_msg = system_load_error or "RAG system not loaded"
        raise HTTPException(status_code=503, detail=f"System not available: {error_msg}")

    start_time = time.time()

    try:
        result = rag_system["sql_generator"].generate_sql(
            question=request.question,
            table_headers=request.table_headers,
        )

        # Wall-clock time of the generation call as seen by this handler.
        processing_time = time.time() - start_time

        return SQLResponse(
            question=request.question,
            table_headers=request.table_headers,
            sql_query=result["sql_query"],
            model_used=result["model_used"],
            processing_time=processing_time,
            retrieved_examples=result["retrieved_examples"],
            status=result["status"],
        )
    except HTTPException:
        # Pass through HTTP errors unchanged instead of rewrapping them as 500.
        raise
    except Exception as e:
        # logger.exception records the traceback for server-side debugging.
        logger.exception(f"Error generating SQL: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating SQL: {e}") from e
|
|
|
@app.post("/batch", response_model=BatchResponse)
async def batch_predict(request: BatchRequest):
    """
    Generate SQL queries from multiple questions using RAG and CodeLlama

    Queries are processed one after another. A query that fails is reported
    as an entry with status "error" and does not abort the rest of the batch.

    Args:
        request: BatchRequest containing list of questions and table headers

    Returns:
        BatchResponse with generated SQL queries

    Raises:
        HTTPException: 503 while the system is loading or when it failed to
            load; 500 when the batch as a whole cannot be processed.
    """
    if system_loading:
        raise HTTPException(status_code=503, detail="System is still loading, please try again in a few minutes")

    if rag_system is None:
        error_msg = system_load_error or "RAG system not loaded"
        raise HTTPException(status_code=503, detail=f"System not available: {error_msg}")

    start_time = time.time()

    try:
        generator = rag_system["sql_generator"]
        results = []
        successful_count = 0

        for query in request.queries:
            query_start = time.time()
            try:
                result = generator.generate_sql(
                    question=query.question,
                    table_headers=query.table_headers,
                )
                elapsed = time.time() - query_start

                # Use the generator's own timing when it provides one and
                # otherwise the locally measured time, matching how the
                # single-query endpoint measures processing time.
                try:
                    processing_time = result["processing_time"]
                except KeyError:
                    processing_time = elapsed

                sql_response = SQLResponse(
                    question=query.question,
                    table_headers=query.table_headers,
                    sql_query=result["sql_query"],
                    model_used=result["model_used"],
                    processing_time=processing_time,
                    retrieved_examples=result["retrieved_examples"],
                    status=result["status"],
                )

                results.append(sql_response)
                if result["status"] == "success":
                    successful_count += 1
            except Exception as e:
                logger.exception(f"Error processing query '{query.question}': {e}")

                # Placeholder entry so results stay aligned with the request.
                error_response = SQLResponse(
                    question=query.question,
                    table_headers=query.table_headers,
                    sql_query="",
                    model_used="none",
                    processing_time=0.0,
                    retrieved_examples=[],
                    status="error",
                )
                results.append(error_response)

        total_time = time.time() - start_time
        logger.info(
            f"Processed batch of {len(results)} queries in {total_time:.2f}s "
            f"({successful_count} successful)"
        )

        return BatchResponse(
            results=results,
            total_queries=len(request.queries),
            successful_queries=successful_count,
        )
    except Exception as e:
        logger.exception(f"Error in batch processing: {e}")
        raise HTTPException(status_code=500, detail=f"Error in batch processing: {e}") from e
|
|
|
if __name__ == "__main__":
    # The port defaults to the previously hard-coded 8000; setting the PORT
    # environment variable overrides it without a code change.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))
|
|