|
import os |
|
from dotenv import load_dotenv |
|
from langchain.prompts import PromptTemplate |
|
from langchain_groq import ChatGroq |
|
from typing import Literal |
|
|
|
|
|
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()
|
|
|
|
|
def initialize_llms():
    """Initialize and return the LLM instances used by the query pipeline.

    Returns:
        dict: Two ChatGroq clients keyed by role —
            "rewrite_llm" (temperature 0.1) for classification/rewriting and
            "step_back_llm" (temperature 0) for step-back query generation.

    Raises:
        EnvironmentError: if GROQ_API_KEY is not set, so misconfiguration
            fails fast here instead of on the first model invocation.
    """
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        raise EnvironmentError(
            "GROQ_API_KEY is not set; add it to the environment or a .env file."
        )

    return {
        "rewrite_llm": ChatGroq(
            temperature=0.1,
            model="llama-3.3-70b-versatile",
            api_key=groq_api_key
        ),
        "step_back_llm": ChatGroq(
            temperature=0,
            # NOTE(review): Groq model ids are normally lowercase
            # ("gemma2-9b-it") — confirm this casing is accepted by the API.
            model="Gemma2-9B-IT",
            api_key=groq_api_key
        )
    }
|
|
|
|
|
def classify_certification(
    query: str,
    llm: ChatGroq,
    certs_dir: str = "docs/processed"
) -> str:
    """
    Classify which certification a query is referring to.

    Args:
        query: The user's original question.
        llm: ChatGroq instance that runs the classification prompt.
        certs_dir: Directory whose sub-directories name the available
            certifications. Previously this parameter was accepted but
            ignored; it is now used to build the candidate list, with a
            fallback to the original hard-coded list when the directory
            is missing or empty.

    Returns:
        Certification name, or 'no certification mentioned'.
    """
    # Original hard-coded list, kept as a fallback so behavior is unchanged
    # when certs_dir cannot be read.
    default_certs = "2BSvs, CertifHy - National Green Certificate (NGC), CertifHy - RFNBO, Certified_Hydrogen_Producer, GH2_Standard, Green_Hydrogen_Certification, ISCC CORSIA, ISCC EU (International Sustainability & Carbon Certification), ISCC PLUS, ISO_19880_Hydrogen_Quality, REDcert-EU, RSB, Scottish Quality Farm Assured Combinable Crops (SQC), TUV Rheinland H2.21, UK RTFO_regulation"

    # Fix: derive the candidate certifications from certs_dir (the prompt
    # itself instructs "use the exact name ... as it appears in the
    # directory"), so newly added certifications are picked up automatically.
    try:
        cert_names = sorted(
            entry for entry in os.listdir(certs_dir)
            if os.path.isdir(os.path.join(certs_dir, entry))
        )
    except OSError:
        cert_names = []
    available_certs = ", ".join(cert_names) if cert_names else default_certs

    template = """
    You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
    Classify the given query into one of the following certifications:
    - {available_certifications}

    Don't need any explanation, just return the name of the certification.

    Use the exact name of the certification as it appears in the directory.
    If the query refers to multiple certifications, return the most relevant one.

    If the query doesn't mention any certification, respond with "no certification mentioned".

    Original query: {original_query}

    Classification:
    """

    prompt = PromptTemplate(
        input_variables=["original_query", "available_certifications"],
        template=template
    )

    chain = prompt | llm
    response = chain.invoke({
        "original_query": query,
        "available_certifications": available_certs
    }).content.strip()

    return response
|
|
|
|
|
def classify_query_specificity(
    query: str,
    llm: ChatGroq
) -> Literal["specific", "general", "too narrow"]:
    """
    Classify query specificity.

    Args:
        query: The user's original question.
        llm: ChatGroq instance that runs the classification prompt.

    Returns:
        One of: 'specific', 'general', or 'too narrow'. If the model
        produces anything else, the normalized raw label is returned and
        callers (process_query) fall back to the untouched query.
    """
    template = """
    You are an AI assistant classifying user queries based on their specificity for a RAG system.
    Classify the given query into one of:
    - "specific" → If it asks for exact values, certifications, or well-defined facts.
    - "general" → If it is broad and needs refinement for better retrieval.
    - "too narrow" → If it is very specific and might need broader context.
    DO NOT output explanations, only return one of: "specific", "general", or "too narrow".

    Original query: {original_query}

    Classification:
    """

    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )

    chain = prompt | llm
    raw = chain.invoke({"original_query": query}).content
    # Fix: models frequently echo the label wrapped in quotes or with a
    # trailing period ('"specific"', 'general.'), which made the equality
    # checks downstream never match. Keep only the first line and strip
    # surrounding quote/punctuation decoration.
    label = raw.strip().lower().split("\n")[0].strip()
    return label.strip("\"'`. ")
|
|
|
|
|
def refine_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Return a clearer, more detailed rewrite of *query*, preserving its intent.

    Args:
        query: The user's original question.
        llm: ChatGroq instance that performs the rewrite.

    Returns:
        The refined query text produced by the model.
    """
    rewrite_template = """

    You are an AI assistant that improves queries for retrieving precise certification and compliance data.

    Rewrite the query to be clearer while keeping the intent unchanged.



    Original query: {original_query}



    Refined query:
    """

    rewrite_prompt = PromptTemplate(
        template=rewrite_template,
        input_variables=["original_query"]
    )

    result = (rewrite_prompt | llm).invoke({"original_query": query})
    return result.content
|
|
|
|
|
def generate_step_back_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Produce a broader "step-back" version of *query* for background retrieval.

    Args:
        query: The user's original question.
        llm: ChatGroq instance that generates the broader query.

    Returns:
        The step-back query text produced by the model.
    """
    step_back_template = """

    You are an AI assistant generating broader queries to improve retrieval context.

    Given the original query, generate a more general step-back query to retrieve relevant background information.



    Original query: {original_query}



    Step-back query:
    """

    step_back_prompt = PromptTemplate(
        template=step_back_template,
        input_variables=["original_query"]
    )

    result = (step_back_prompt | llm).invoke({"original_query": query})
    return result.content
|
|
|
|
|
def process_query(
    original_query: str,
    llms: dict
) -> str:
    """
    Run a query through the full refinement pipeline.

    1. Classify specificity with the rewrite LLM.
    2. Dispatch to the matching refinement strategy.

    NOTE(review): "specific" and "general" share the same handling (plain
    refinement), exactly as in the original branches — confirm that is
    intended rather than a copy-paste leftover.

    Args:
        original_query: The user's raw question.
        llms: Mapping with "rewrite_llm" and "step_back_llm" ChatGroq clients.

    Returns:
        The refined / broadened query, or the original query when the
        specificity label is unrecognized.
    """
    specificity = classify_query_specificity(original_query, llms["rewrite_llm"])

    if specificity in ("specific", "general"):
        return refine_query(original_query, llms["rewrite_llm"])
    if specificity == "too narrow":
        return generate_step_back_query(original_query, llms["step_back_llm"])
    # Unknown label: pass the query through untouched.
    return original_query
|
|
|
|
|
def test_hydrogen_certification_functions():
    """Smoke-test the full pipeline end to end, printing results (needs GROQ_API_KEY)."""
    llms = initialize_llms()

    # Ensure the certification directory layout the classifier expects.
    certs_dir = "docs/processed"
    os.makedirs(certs_dir, exist_ok=True)

    for cert_name in (
        "GH2_Standard",
        "Certified_Hydrogen_Producer",
        "Green_Hydrogen_Certification",
        "ISO_19880_Hydrogen_Quality",
    ):
        os.makedirs(os.path.join(certs_dir, cert_name), exist_ok=True)

    # (query, expected specificity label) pairs.
    cases = [
        ("What are the purity requirements in GH2 Standard?", "specific"),
        ("How does hydrogen certification work?", "general"),
        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
        ("What safety protocols exist for hydrogen storage?", "general"),
    ]

    print("=== Testing Certification Classification ===")
    for question, _ in cases:
        label = classify_certification(question, llms["rewrite_llm"], certs_dir)
        print(f"Query: {question}\nClassification: {label}\n")

    print("\n=== Testing Specificity Classification ===")
    for question, expected in cases:
        got = classify_query_specificity(question, llms["rewrite_llm"])
        print(f"Query: {question}\nExpected: {expected}, Got: {got}\n")

    print("\n=== Testing Full Query Processing ===")
    for question, _ in cases:
        rewritten = process_query(question, llms)
        print(f"Original: {question}\nProcessed: {rewritten}\n")
|
|
|
|
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":

    test_hydrogen_certification_functions()