# rag_hydro / prompting / rewrite_question.py
# Author: Anas Bader
# Last commit: 4cbe4e9 ("redo")
import os
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from typing import Literal
# Load environment variables (python-dotenv reads a .env file by default) into
# os.environ at import time, so that os.getenv("GROQ_API_KEY") in
# initialize_llms can find the key.
load_dotenv()
# Initialize LLMs
def initialize_llms():
    """Build the two Groq chat models used by the query pipeline.

    Both models authenticate with the value of the ``GROQ_API_KEY``
    environment variable (``None`` if it is unset).

    Returns:
        dict with two entries:
            ``"rewrite_llm"``   -- llama-3.3-70b-versatile at temperature 0.1
            ``"step_back_llm"`` -- Gemma2-9B-IT at temperature 0
    """
    key = os.getenv("GROQ_API_KEY")
    rewriter = ChatGroq(temperature=0.1, model="llama-3.3-70b-versatile", api_key=key)
    step_back = ChatGroq(temperature=0, model="Gemma2-9B-IT", api_key=key)
    return {"rewrite_llm": rewriter, "step_back_llm": step_back}
# Certification classification
def classify_certification(
    query: str,
    llm: ChatGroq,
    certs_dir: str = "docs/processed"
) -> str:
    """
    Classify which certification a query is referring to.

    The candidate certifications are the names of the sub-directories of
    ``certs_dir``, sorted and with hidden entries skipped. If that directory
    does not exist or has no sub-directories, a built-in list of known
    certifications is used instead.

    Args:
        query: The user question to classify.
        llm: Chat model that performs the classification.
        certs_dir: Directory whose sub-directory names are the valid
            certification labels.

    Returns:
        The certification name chosen by the model, or
        'no certification mentioned'. Only the first line of the model's
        answer is kept, with surrounding whitespace and quotes removed.
    """
    # Fallback used when certs_dir is missing or empty.
    fallback_certs = [
        "2BSvs",
        "CertifHy - National Green Certificate (NGC)",
        "CertifHy - RFNBO",
        "Certified_Hydrogen_Producer",
        "GH2_Standard",
        "Green_Hydrogen_Certification",
        "ISCC CORSIA",
        "ISCC EU (International Sustainability & Carbon Certification)",
        "ISCC PLUS",
        "ISO_19880_Hydrogen_Quality",
        "REDcert-EU",
        "RSB",
        "Scottish Quality Farm Assured Combinable Crops (SQC)",
        "TUV Rheinland H2.21",
        "UK RTFO_regulation",
    ]
    certs = fallback_certs
    if os.path.isdir(certs_dir):
        with os.scandir(certs_dir) as entries:
            found = sorted(
                entry.name
                for entry in entries
                if entry.is_dir() and not entry.name.startswith(".")
            )
        if found:
            certs = found
    # One bullet per certification so the model sees them as distinct options;
    # the template supplies the first "- ".
    available_certs = "\n- ".join(certs)
    template = """
You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
Classify the given query into one of the following certifications:
- {available_certifications}
Don't need any explanation, just return the name of the certification.
Use the exact name of the certification as it appears in the directory.
If the query refers to multiple certifications, return the most relevant one.
If the query doesn't mention any certification, respond with "no certification mentioned".
Original query: {original_query}
Classification:
"""
    prompt = PromptTemplate(
        input_variables=["original_query", "available_certifications"],
        template=template
    )
    chain = prompt | llm
    response = chain.invoke({
        "original_query": query,
        "available_certifications": available_certs
    }).content.strip()
    # The prompt shows the fallback answer in quotes, so the model may echo
    # quotes. Keep only the first line and drop any wrapping quote characters.
    first_line = response.split("\n")[0].strip()
    return first_line.strip("\"'`").strip()
# Query specificity classification
def classify_query_specificity(
    query: str,
    llm: ChatGroq
) -> Literal["specific", "general", "too narrow"]:
    """
    Classify query specificity.

    Args:
        query: The user question to classify.
        llm: Chat model that performs the classification.

    Returns:
        The model's label, lower-cased and taken from the first line of its
        answer, with surrounding quotes, backticks and periods removed.
        Normally this is one of 'specific', 'general', or 'too narrow'. The
        model output is not otherwise validated, so any other text is
        returned as-is.
    """
    template = """
You are an AI assistant classifying user queries based on their specificity for a RAG system.
Classify the given query into one of:
- "specific" → If it asks for exact values, certifications, or well-defined facts.
- "general" → If it is broad and needs refinement for better retrieval.
- "too narrow" → If it is very specific and might need broader context.
DO NOT output explanations, only return one of: "specific", "general", or "too narrow".
Original query: {original_query}
Classification:
"""
    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    chain = prompt | llm
    response = chain.invoke({"original_query": query}).content.strip().lower()
    first_line = response.split("\n")[0].strip()
    # The prompt lists the labels in double quotes, so the model may echo them
    # as '"general"' or 'general.'. Strip such wrappers so the result compares
    # equal to the plain label strings.
    return first_line.strip("\"'`.").strip()  # type: ignore
# Query refinement
def refine_query(
    query: str,
    llm: ChatGroq
) -> str:
    """
    Rewrite a query to be clearer and more detailed while keeping the original intent.

    Args:
        query: The user question to rewrite.
        llm: Chat model that performs the rewrite.

    Returns:
        The rewritten query text, with surrounding whitespace removed.
    """
    # The output is used directly as a retrieval query, so explicitly ask for
    # the query alone (matching the "no explanations" rule of the
    # classification prompts) and trim stray whitespace.
    template = """
You are an AI assistant that improves queries for retrieving precise certification and compliance data.
Rewrite the query to be clearer while keeping the intent unchanged.
Return only the rewritten query, without any explanation.
Original query: {original_query}
Refined query:
"""
    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    chain = prompt | llm
    return chain.invoke({"original_query": query}).content.strip()
# Step-back query generation
def generate_step_back_query(
    query: str,
    llm: ChatGroq
) -> str:
    """
    Generate a broader step-back query to retrieve relevant background information.

    Args:
        query: The (overly narrow) user question to broaden.
        llm: Chat model that generates the step-back query.

    Returns:
        The step-back query text, with surrounding whitespace removed.
    """
    # The output is used directly as a retrieval query, so explicitly ask for
    # the query alone (matching the "no explanations" rule of the
    # classification prompts) and trim stray whitespace.
    template = """
You are an AI assistant generating broader queries to improve retrieval context.
Given the original query, generate a more general step-back query to retrieve relevant background information.
Return only the step-back query, without any explanation.
Original query: {original_query}
Step-back query:
"""
    prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template
    )
    chain = prompt | llm
    return chain.invoke({"original_query": query}).content.strip()
# Main query processing pipeline
def process_query(
    original_query: str,
    llms: dict
) -> str:
    """
    Run a query through the rewrite pipeline.

    The query is first labelled by specificity using ``llms["rewrite_llm"]``:
      * "specific" or "general" -> rewritten by ``refine_query`` (rewrite_llm)
      * "too narrow"            -> broadened by ``generate_step_back_query``
                                   (``llms["step_back_llm"]``)
      * any other label         -> the original query is returned unchanged

    Args:
        original_query: The user question.
        llms: Mapping as returned by ``initialize_llms``.

    Returns:
        The processed query text.
    """
    rewriter = llms["rewrite_llm"]
    label = classify_query_specificity(original_query, rewriter)
    # Both labels take the same path: a clarity rewrite.
    if label in ("specific", "general"):
        return refine_query(original_query, rewriter)
    if label == "too narrow":
        return generate_step_back_query(original_query, llms["step_back_llm"])
    return original_query
# Test setup
def test_hydrogen_certification_functions():
    """Manually exercise the query helpers against the live Groq models.

    Makes sure ``docs/processed`` and four dummy hydrogen certification
    sub-folders exist (created if missing, never removed). Then, for each
    sample query, prints:
      * the certification classification,
      * the specificity label next to the expected one,
      * the result of ``process_query``.
    Nothing is asserted; the output is for visual inspection.
    """
    llm_map = initialize_llms()

    # The dummy certification folders go in the same directory that is passed
    # to classify_certification below.
    certs_root = "docs/processed"
    os.makedirs(certs_root, exist_ok=True)
    for folder in (
        "GH2_Standard",
        "Certified_Hydrogen_Producer",
        "Green_Hydrogen_Certification",
        "ISO_19880_Hydrogen_Quality",
    ):
        os.makedirs(os.path.join(certs_root, folder), exist_ok=True)

    # (query, expected specificity label)
    cases = (
        ("What are the purity requirements in GH2 Standard?", "specific"),
        ("How does hydrogen certification work?", "general"),
        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
        ("What safety protocols exist for hydrogen storage?", "general"),
    )

    print("=== Testing Certification Classification ===")
    for text, _label in cases:
        result = classify_certification(text, llm_map["rewrite_llm"], certs_root)
        print(f"Query: {text}\nClassification: {result}\n")

    print("\n=== Testing Specificity Classification ===")
    for text, expected in cases:
        got = classify_query_specificity(text, llm_map["rewrite_llm"])
        print(f"Query: {text}\nExpected: {expected}, Got: {got}\n")

    print("\n=== Testing Full Query Processing ===")
    for text, _label in cases:
        rewritten = process_query(text, llm_map)
        print(f"Original: {text}\nProcessed: {rewritten}\n")
# Run the manual smoke test only when this file is executed directly
# (python rewrite_question.py), not when it is imported.
if __name__ == "__main__":
    test_hydrogen_certification_functions()