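"""Query preprocessing utilities for a certification-focused RAG pipeline.

Uses Groq-hosted LLMs to classify which certification a query refers to,
judge query specificity, refine queries, and generate broader step-back
queries; includes a small test harness for hydrogen certifications.
"""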
import os
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from typing import Literal

# Load environment variables
load_dotenv()

# Initialize LLMs
def initialize_llms():
    """Initialize and return the LLM instances used by the pipeline."""
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        raise EnvironmentError("GROQ_API_KEY is not set; add it to your environment or .env file.")

    return {
        # Low temperature for near-deterministic query rewriting
        "rewrite_llm": ChatGroq(
            temperature=0.1,
            model="llama-3.3-70b-versatile",
            api_key=groq_api_key
        ),
        # Smaller, deterministic model for step-back query generation
        "step_back_llm": ChatGroq(
            temperature=0,
            model="gemma2-9b-it",
            api_key=groq_api_key
        )
    }

# Certification classification
def classify_certification(
    query: str,
    llm: ChatGroq,
    certs_dir: str = "docs/processed"
) -> str:
    """
    Classify which certification a query is referring to.
    Certification names are read from the sub-folders of `certs_dir` when it
    exists, otherwise from a built-in list.
    Returns the certification name or 'no certification mentioned'.
    """
    default_certs = "2BSvs, CertifHy - National Green Certificate (NGC), CertifHy - RFNBO, Certified_Hydrogen_Producer, GH2_Standard, Green_Hydrogen_Certification, ISCC CORSIA, ISCC EU (International Sustainability & Carbon Certification), ISCC PLUS, ISO_19880_Hydrogen_Quality, REDcert-EU, RSB, Scottish Quality Farm Assured Combinable Crops (SQC), TUV Rheinland H2.21, UK RTFO_regulation"
    # Prefer folder names from certs_dir so classifications match the corpus on disk
    cert_folders = sorted(
        d for d in os.listdir(certs_dir) if os.path.isdir(os.path.join(certs_dir, d))
    ) if os.path.isdir(certs_dir) else []
    available_certs = ", ".join(cert_folders) if cert_folders else default_certs
    
    template = """
    You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
    Classify the given query into one of the following certifications:
    - {available_certifications}
    
    Do not provide any explanation; return only the name of the certification.
    
    Use the exact name of the certification as it appears in the directory.
    If the query refers to multiple certifications, return the most relevant one.

    If the query doesn't mention any certification, respond with "no certification mentioned".

    Original query: {original_query}

    Classification:
    """
    
    prompt = PromptTemplate(
        input_variables=["original_query", "available_certifications"],
        template=template
    )
    
    chain = prompt | llm
    response = chain.invoke({
        "original_query": query,
        "available_certifications": available_certs
    }).content.strip()
    
    return response

# Query specificity classification
def classify_query_specificity(
    query: str, 
    llm: ChatGroq
) -> Literal["specific", "general", "too narrow"]:
    """
    Classify query specificity.
    Returns one of: 'specific', 'general', or 'too narrow'.
    """
    template = """
    You are an AI assistant classifying user queries based on their specificity for a RAG system.
    Classify the given query into one of:
    - "specific" → If it asks for exact values, certifications, or well-defined facts.
    - "general" → If it is broad and needs refinement for better retrieval.
    - "too narrow" → If it is very specific and might need broader context.
    DO NOT output explanations, only return one of: "specific", "general", or "too narrow".

    Original query: {original_query}

    Classification:
    """
    
    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    
    chain = prompt | llm
    response = chain.invoke({"original_query": query}).content.strip().lower()
    # Keep only the first line and drop any quotes the model may echo from the prompt
    return response.split("\n")[0].strip().strip('"\'')  # type: ignore[return-value]

# Query refinement
def refine_query(
    query: str, 
    llm: ChatGroq
) -> str:
    """Rewrite a query to be clearer and more detailed while keeping the original intent"""
    template = """
    You are an AI assistant that improves queries for retrieving precise certification and compliance data.
    Rewrite the query to be clearer while keeping the intent unchanged.

    Original query: {original_query}

    Refined query:
    """
    
    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    
    chain = prompt | llm
    return chain.invoke({"original_query": query}).content.strip()

# Step-back query generation
def generate_step_back_query(
    query: str, 
    llm: ChatGroq
) -> str:
    """Generate a broader step-back query to retrieve relevant background information"""
    template = """
    You are an AI assistant generating broader queries to improve retrieval context.
    Given the original query, generate a more general step-back query to retrieve relevant background information.

    Original query: {original_query}

    Step-back query:
    """
    
    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    
    chain = prompt | llm
    return chain.invoke({"original_query": query}).content.strip()

# Main query processing pipeline
def process_query(
    original_query: str,
    llms: dict
) -> str:
    """
    Process a query through the full pipeline:
    1. Classify specificity
    2. Refine specific/general queries; broaden too-narrow ones with a step-back query
    """
    specificity = classify_query_specificity(original_query, llms["rewrite_llm"])

    if specificity in ("specific", "general"):
        return refine_query(original_query, llms["rewrite_llm"])
    if specificity == "too narrow":
        return generate_step_back_query(original_query, llms["step_back_llm"])
    # Unrecognized classification: pass the query through unchanged
    return original_query

# Test setup
def test_hydrogen_certification_functions():
    # Initialize LLMs
    llms = initialize_llms()
    
    # Create a test directory with hydrogen certifications
    test_certs_dir = "docs/processed"
    os.makedirs(test_certs_dir, exist_ok=True)
    
    # Create some dummy certification folders
    hydrogen_certifications = [
        "GH2_Standard", 
        "Certified_Hydrogen_Producer",
        "Green_Hydrogen_Certification",
        "ISO_19880_Hydrogen_Quality"
    ]
    
    for cert in hydrogen_certifications:
        os.makedirs(os.path.join(test_certs_dir, cert), exist_ok=True)
    
    # Test queries
    test_queries = [
        ("What are the purity requirements in GH2 Standard?", "specific"),
        ("How does hydrogen certification work?", "general"),
        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
        ("What safety protocols exist for hydrogen storage?", "general")
    ]
    
    print("=== Testing Certification Classification ===")
    for query, _ in test_queries:
        cert = classify_certification(query, llms["rewrite_llm"], test_certs_dir)
        print(f"Query: {query}\nClassification: {cert}\n")
    
    print("\n=== Testing Specificity Classification ===")
    for query, expected_type in test_queries:
        specificity = classify_query_specificity(query, llms["rewrite_llm"])
        print(f"Query: {query}\nExpected: {expected_type}, Got: {specificity}\n")
    
    print("\n=== Testing Full Query Processing ===")
    for query, _ in test_queries:
        processed = process_query(query, llms)
        print(f"Original: {query}\nProcessed: {processed}\n")

# Run the tests
if __name__ == "__main__":
    test_hydrogen_certification_functions()