"""FastAPI application exposing the BuffaloRAG question-answering service.

The module wires together the scraper, chunker, vector store and RAG model,
exposes JSON endpoints under ``/api`` and serves a single-page frontend from
the ``static`` directory that sits next to this file.
"""

import os
import json
from typing import List, Dict, Any, Optional
from datetime import datetime

from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

from buffalo_rag.scraper.scraper import BuffaloScraper
from buffalo_rag.embeddings.chunker import DocumentChunker
from buffalo_rag.vector_store.db import VectorStore
from buffalo_rag.model.rag import BuffaloRAG

# Initialize FastAPI app
app = FastAPI(
    title="BuffaloRAG API",
    description="API for BuffaloRAG - AI Assistant for International Students at University at Buffalo",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before deploying beyond development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components. Both globals are rebound by refresh_index() once the
# embeddings have been regenerated.
vector_store = VectorStore()
rag = BuffaloRAG(vector_store=vector_store)

# Append-only JSON-lines file that records every answered query.
QUERY_LOG_PATH = "data/query_log.jsonl"


# Pydantic models
class QueryRequest(BaseModel):
    """Request body for ``POST /api/ask``."""

    query: str
    k: int = 5
    categories: Optional[List[str]] = None


class QueryResponse(BaseModel):
    """Response body for ``POST /api/ask``."""

    query: str
    response: str
    sources: List[Dict[str, Any]]
    timestamp: str


class ScrapeRequest(BaseModel):
    """Request body for ``POST /api/scrape``."""

    seed_url: str = "https://www.buffalo.edu/international-student-services.html"
    max_pages: int = 100


class ScrapeResponse(BaseModel):
    """Status payload returned by the scrape and refresh-index endpoints."""

    status: str
    message: str


# Background tasks
def refresh_index():
    """Refresh the vector index in the background.

    Regenerates chunks and embeddings, then rebinds the module-level
    ``vector_store`` and ``rag`` to freshly constructed instances.
    """
    global vector_store, rag

    chunker = DocumentChunker()
    chunks = chunker.create_chunks()
    chunker.create_embeddings(chunks)

    # Reload the vector store, then rebuild the RAG model on top of it.
    vector_store = VectorStore()
    rag = BuffaloRAG(vector_store=vector_store)


def run_scraper(seed_url: str, max_pages: int):
    """Run the web scraper in the background, then refresh the index.

    Args:
        seed_url: URL handed to ``BuffaloScraper`` as its starting point.
        max_pages: Page limit passed to ``BuffaloScraper.scrape``.
    """
    scraper = BuffaloScraper(seed_url=seed_url)
    scraper.scrape(max_pages=max_pages)

    # After scraping, update the embeddings and reload the store and model.
    refresh_index()


# Setup static files directory
static_dir = os.path.join(os.path.dirname(__file__), "static")
os.makedirs(static_dir, exist_ok=True)

# Mounted before the catch-all frontend route is registered further below.
app.mount("/static", StaticFiles(directory=static_dir), name="static")


def _safe_static_path(path: str) -> Optional[str]:
    """Map a request path to a location inside ``static_dir``.

    Returns the normalised absolute path, or ``None`` when the request would
    resolve outside ``static_dir`` (``..`` segments, a leading ``/``) or cannot
    be interpreted as a path at all.
    """
    root = os.path.abspath(static_dir)
    try:
        candidate = os.path.abspath(os.path.join(root, path))
        inside = os.path.commonpath([root, candidate]) == root
    except ValueError:
        # Raised for embedded NUL bytes or, on Windows, for mixed drives.
        return None
    return candidate if inside else None


# API endpoints
@app.post("/api/ask", response_model=QueryResponse)
async def ask(request: QueryRequest):
    """Ask a question to the RAG system.

    Raises:
        HTTPException: 500 with the error text if answering or logging fails.
    """
    try:
        response = rag.answer(
            query=request.query,
            k=request.k,
            filter_categories=request.categories
        )

        # Add timestamp
        response['timestamp'] = datetime.now().isoformat()

        # Log the query for analytics. The directory is created on demand so
        # that a fresh checkout does not fail every request with a 500.
        log_dir = os.path.dirname(QUERY_LOG_PATH)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(QUERY_LOG_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(response) + "\n")

        return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/api/scrape", response_model=ScrapeResponse)
async def scrape(request: ScrapeRequest, background_tasks: BackgroundTasks):
    """Trigger web scraping.

    The scrape is queued as a background task; the response only confirms
    that it was scheduled.
    """
    try:
        background_tasks.add_task(run_scraper, request.seed_url, request.max_pages)
        return {
            "status": "success",
            "message": f"Started scraping from {request.seed_url} (max {request.max_pages} pages)"
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/api/refresh-index", response_model=ScrapeResponse)
async def refresh(background_tasks: BackgroundTasks):
    """Refresh the vector index.

    The refresh is queued as a background task; the response only confirms
    that it was scheduled.
    """
    try:
        background_tasks.add_task(refresh_index)
        return {
            "status": "success",
            "message": "Started refreshing the vector index"
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# Add a route to serve the React app
@app.get("/", include_in_schema=False)
async def serve_frontend():
    """Serve the frontend entry point."""
    return FileResponse(os.path.join(static_dir, "index.html"))


@app.get("/{path:path}", include_in_schema=False)
async def serve_frontend_paths(path: str):
    """Serve a static file if it exists, otherwise the frontend entry point.

    Requests that would escape ``static_dir`` are treated like unknown paths
    and receive ``index.html``.
    """
    # First check if the file exists in static directory
    file_path = _safe_static_path(path)
    if file_path is not None and os.path.isfile(file_path):
        return FileResponse(file_path)

    # Otherwise, return index.html for client-side routing
    return FileResponse(os.path.join(static_dir, "index.html"))


# Run the API server
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("buffalo_rag.api.main:app", host="localhost", port=8000, reload=True)