File size: 3,996 Bytes
4cbe4e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df95e65
4cbe4e9
 
 
 
 
 
 
 
67e19f4
 
 
 
4cbe4e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67e19f4
 
4cbe4e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import os
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from langchain.prompts import PromptTemplate
from pydantic import BaseModel

from embeddings.embeddings import generate_embeddings
from elastic.retrieval import search_certification_chunks
from prompting.rewrite_question import classify_certification, initialize_llms, process_query

# Load variables from a local .env file into the process environment first,
# so that settings read from the environment (presumably by initialize_llms)
# are available before the LLM clients below are created.
load_dotenv()

_APP_METADATA = dict(
    title="Hydrogen Certification RAG System",
    description="API for querying hydrogen certification documents using RAG",
    version="0.1.0",
)
app = FastAPI(**_APP_METADATA)

# LLM clients shared by the whole process (no Elasticsearch client is created
# here). ``llms`` is a mapping that includes at least ``"rewrite_llm"``.
llms = initialize_llms()

# Request models
class QueryRequest(BaseModel):
    # Body of POST /query.
    # ``query``: the user's free-text question; the certification it refers
    # to is inferred from this text.
    query: str


    
# Module-level handle on the rewrite LLM. It reuses the clients already held in
# ``llms`` instead of calling initialize_llms() a second time, which used to
# build a duplicate, unused set of clients at import time.
llm = llms["rewrite_llm"]
    

# Endpoints
# Prompt that turns retrieved chunks into an answer about one certification.
# The text is kept byte-identical to the previous inline template so the LLM
# sees exactly the same prompt.
_ANSWER_TEMPLATE = """
        You are an AI assistant tasked with providing answers based on the given context about a specific hydrogen certification. 

        Provide a clear, concise response that directly addresses the question without unnecessary information.

        Question: {question}
        Certification: {certification}
        Context: {context}

        Answer:
        """

# Built once at import time instead of on every request.
_ANSWER_PROMPT = PromptTemplate(
    input_variables=["question", "certification", "context"],
    template=_ANSWER_TEMPLATE,
)


def _answer_from_index(index_name, certification, question, question_vector, chat_model):
    """Retrieve chunks for *certification* from *index_name* and answer *question*.

    Args:
        index_name: Name of the search index passed to ``search_certification_chunks``.
        certification: Certification name used to filter the search and fill the prompt.
        question: The processed query, used as the text query and as the prompt question.
        question_vector: Embedding of *question*, used as the vector query.
        chat_model: LLM piped after the answer prompt.

    Returns:
        A ``(texts, answer)`` tuple: the ``"text"`` field of every retrieved
        chunk in retrieval order, and the ``.content`` of the LLM reply.
    """
    hits = search_certification_chunks(
        index_name=index_name,
        certification_name=certification,
        text_query=question,
        vector_query=question_vector,
    )
    # Materialise once: the texts are both returned and joined into the
    # context, and a one-shot iterator would be empty on the second pass.
    texts = [hit["text"] for hit in hits]
    chain = _ANSWER_PROMPT | chat_model
    answer = chain.invoke(
        {"question": question, "certification": certification, "context": ". ".join(texts)}
    ).content
    return texts, answer


def _run_rag_pipeline(user_query):
    """Run the synchronous RAG pipeline for *user_query* and build the response body.

    Raises:
        HTTPException: 400 if the classifier reports that no certification is
            mentioned in the query.
    """
    rewrite_llm = llms["rewrite_llm"]

    # Step 1: determine which certification the query is about.
    certification = classify_certification(user_query, rewrite_llm)
    if "no certification mentioned" in certification:
        raise HTTPException(
            status_code=400,
            detail="No certification specified in query and none provided",
        )

    # Step 2: rewrite the query and embed it.
    processed_query = process_query(user_query, llms)
    question_vector = generate_embeddings(processed_query)

    # Step 3: retrieve and answer against both indexes.
    certif_texts, certif_answer = _answer_from_index(
        "certif_index", certification, processed_query, question_vector, rewrite_llm
    )
    certification_texts, certification_answer = _answer_from_index(
        "certification_index", certification, processed_query, question_vector, rewrite_llm
    )

    return {
        "certification": certification,
        "certif_index": certif_answer,
        "certification_index": certification_answer,
        "context_certif": certif_texts,
        "context_certifications": certification_texts,
    }


@app.post("/query")
async def handle_query(request: QueryRequest):
    """
    Process a query through the full RAG pipeline:
    1. Classify which certification the query is about (always inferred from
       the query text; HTTP 400 if none is mentioned)
    2. Rewrite the query and embed it
    3. Search both "certif_index" and "certification_index" and answer from
       each index's retrieved chunks

    The pipeline is made of blocking calls, so it runs in a worker thread to
    keep the event loop free for other requests.

    Raises:
        HTTPException: 400 when no certification is found in the query;
            500 with the error message for any other failure.
    """
    try:
        return await run_in_threadpool(_run_rag_pipeline, request.query)
    except HTTPException:
        # Deliberate client errors (the 400 above) keep their status code
        # instead of being re-wrapped as a 500 by the generic handler below.
        raise
    except Exception as exc:
        logging.getLogger(__name__).exception("RAG pipeline failed for /query")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

@app.get("/certifications", response_model=list[str])
async def list_certifications():
    """Return the names of the sub-directories of ``docs/processed``.

    Each sub-directory is treated as one available certification. The path is
    relative, so it is resolved against the current working directory, and
    the order is whatever ``os.listdir`` yields.

    Raises:
        HTTPException: 500 carrying the underlying error message if the
            directory cannot be listed.
    """
    processed_root = "docs/processed"
    try:
        certifications = []
        for entry in os.listdir(processed_root):
            if os.path.isdir(os.path.join(processed_root, entry)):
                certifications.append(entry)
        return certifications
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

if __name__ == "__main__":
    # When this file is executed directly, serve the API with uvicorn on all
    # interfaces (0.0.0.0), port 8000. uvicorn is imported here so that merely
    # importing this module does not require it.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)