# rag_hydro / app.py
# Source: Hugging Face Space "rag_hydro" by baderanas, commit 67e19f4 ("return contexts").
import os
from fastapi import FastAPI, HTTPException
from langchain.prompts import PromptTemplate
from pydantic import BaseModel
from typing import Optional
from dotenv import load_dotenv
from embeddings.embeddings import generate_embeddings
from elastic.retrieval import search_certification_chunks
from prompting.rewrite_question import classify_certification, initialize_llms, process_query
# Load environment variables (API keys, Elasticsearch config, etc.) from a .env file.
load_dotenv()

app = FastAPI(
    title="Hydrogen Certification RAG System",
    description="API for querying hydrogen certification documents using RAG",
    version="0.1.0"
)

# Initialize LLMs and Elasticsearch client
# NOTE(review): initialize_llms() appears to return a dict of LLM clients
# (at least a "rewrite_llm" key, used below) — confirm in prompting/rewrite_question.
llms = initialize_llms()
# Request models
class QueryRequest(BaseModel):
    """Body of POST /query: the user's natural-language question."""

    # Free-text question; certification is inferred from it by the pipeline.
    query: str


# Bug fix: the original called initialize_llms() a second time here,
# paying for a duplicate (potentially expensive) client initialization.
# Reuse the module-level `llms` dict built above instead.
llm = llms["rewrite_llm"]
# Endpoints
@app.post("/query")
async def handle_query(request: QueryRequest):
    """
    Process a query through the full RAG pipeline.

    Steps:
        1. Classify which certification the query refers to.
        2. Rewrite/optimize the query and embed it.
        3. Run the same hybrid (text + vector) search against two indexes
           ("certif_index" and "certification_index") and generate one
           answer per index so the results can be compared side by side.

    Raises:
        HTTPException 400: when no certification is identified in the query.
        HTTPException 500: on any unexpected pipeline failure.
    """
    try:
        # Step 1: determine which certification the question is about.
        certification = classify_certification(request.query, llms["rewrite_llm"])
        # NOTE(review): assumes classify_certification returns this sentinel
        # phrase when it finds no certification; lower() guards against
        # casing variations in the LLM output — confirm the exact contract.
        if "no certification mentioned" in certification.lower():
            raise HTTPException(
                status_code=400,
                detail="No certification specified in query and none provided"
            )

        # Step 2: rewrite the query and embed it for hybrid search.
        processed_query = process_query(request.query, llms)
        question_vector = generate_embeddings(processed_query)

        # Step 3: identical search parameters, two indexes.
        search_kwargs = {
            "certification_name": certification,
            "text_query": processed_query,
            "vector_query": question_vector,
        }
        results = search_certification_chunks(index_name="certif_index", **search_kwargs)
        results_ = search_certification_chunks(index_name="certification_index", **search_kwargs)

        results_list = [result["text"] for result in results]
        results_list_ = [result["text"] for result in results_]
        # Join the already-extracted texts instead of re-running the
        # comprehensions (the original recomputed them).
        results_merged = ". ".join(results_list)
        results_merged_ = ". ".join(results_list_)

        template = """
You are an AI assistant tasked with providing answers based on the given context about a specific hydrogen certification.
Provide a clear, concise response that directly addresses the question without unnecessary information.
Question: {question}
Certification: {certification}
Context: {context}
Answer:
"""
        prompt = PromptTemplate(
            input_variables=["question", "certification", "context"],
            template=template
        )
        chain = prompt | llm
        answer = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged}).content
        answer_ = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged_}).content

        return {
            "certification": certification,
            "certif_index": answer,
            "certification_index": answer_,
            "context_certif": results_list,
            "context_certifications": results_list_
        }
    except HTTPException:
        # Bug fix: the generic handler below used to swallow the 400 raised
        # above and re-raise it as a 500. Propagate HTTPExceptions unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/certifications", response_model=list[str])
async def list_certifications():
    """
    List all available certifications.

    Each certification is represented by a sub-directory of
    ``docs/processed``; the directory names are the certification ids.

    Raises:
        HTTPException 404: when the processed-docs directory is missing.
        HTTPException 500: on any other filesystem error.
    """
    certs_dir = "docs/processed"
    try:
        # scandir does one stat per entry, vs listdir + os.path.isdir (two).
        with os.scandir(certs_dir) as entries:
            return [entry.name for entry in entries if entry.is_dir()]
    except FileNotFoundError:
        # Bug fix: a missing directory used to surface as an opaque 500.
        raise HTTPException(
            status_code=404,
            detail=f"Certifications directory not found: {certs_dir}"
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Run a local development server when executed directly
    # (in deployment an external ASGI host imports `app` instead).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)