"""Query pre-processing helpers for a certification-focused RAG pipeline.

The module classifies which certification a query refers to, classifies how
specific the query is, and rewrites the query (refinement or "step-back")
before retrieval. All LLM calls go through Groq-hosted chat models.
"""

import os
from typing import Literal

from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq

# Load environment variables
load_dotenv()

# Certification names offered to the classifier. They are joined with ", "
# before being substituted into the prompt.
_AVAILABLE_CERTIFICATIONS = (
    "2BSvs",
    "CertifHy - National Green Certificate (NGC)",
    "CertifHy - RFNBO",
    "Certified_Hydrogen_Producer",
    "GH2_Standard",
    "Green_Hydrogen_Certification",
    "ISCC CORSIA",
    "ISCC EU (International Sustainability & Carbon Certification)",
    "ISCC PLUS",
    "ISO_19880_Hydrogen_Quality",
    "REDcert-EU",
    "RSB",
    "Scottish Quality Farm Assured Combinable Crops (SQC)",
    "TUV Rheinland H2.21",
    "UK RTFO_regulation",
)

# Labels that process_query() knows how to dispatch on.
_SPECIFICITY_LABELS = ("specific", "general", "too narrow")

# Characters a chat model commonly wraps around a short answer.
_QUOTE_CHARS = "\"'`"

_CERTIFICATION_TEMPLATE = """
You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
Classify the given query into one of the following certifications:
- {available_certifications}
Don't need any explanation, just return the name of the certification.
Use the exact name of the certification as it appears in the directory.
If the query refers to multiple certifications, return the most relevant one.
If the query doesn't mention any certification, respond with "no certification mentioned".

Original query: {original_query}

Classification:
"""

_SPECIFICITY_TEMPLATE = """
You are an AI assistant classifying user queries based on their specificity for a RAG system.
Classify the given query into one of:
- "specific" → If it asks for exact values, certifications, or well-defined facts.
- "general" → If it is broad and needs refinement for better retrieval.
- "too narrow" → If it is very specific and might need broader context.
DO NOT output explanations, only return one of: "specific", "general", or "too narrow".

Original query: {original_query}

Classification:
"""

_REFINE_TEMPLATE = """
You are an AI assistant that improves queries for retrieving precise certification and compliance data.
Rewrite the query to be clearer while keeping the intent unchanged.

Original query: {original_query}

Refined query:
"""

_STEP_BACK_TEMPLATE = """
You are an AI assistant generating broader queries to improve retrieval context.
Given the original query, generate a more general step-back query to retrieve relevant background information.

Original query: {original_query}

Step-back query:
"""


def _run_prompt(template: str, llm: ChatGroq, variables: dict) -> str:
    """Fill *template* with *variables*, send it to *llm*, return stripped text."""
    prompt = PromptTemplate(
        input_variables=list(variables),
        template=template,
    )
    chain = prompt | llm
    return chain.invoke(variables).content.strip()


def _normalize_specificity(raw: str) -> str:
    """Reduce a raw model reply to a specificity label when possible.

    Only the first line of the reply is considered. Surrounding quotes,
    backticks, asterisks and a trailing period are removed before matching,
    because the prompt itself shows the labels in quotes and the model may
    echo them. If the cleaned text is not a known label, the lower-cased
    first line is returned unchanged so callers can apply their own fallback.
    """
    first_line = raw.strip().lower().split("\n")[0].strip()
    cleaned = first_line.strip(_QUOTE_CHARS + "*. \t")
    return cleaned if cleaned in _SPECIFICITY_LABELS else first_line


# Initialize LLMs
def initialize_llms():
    """Initialize and return the LLM instances.

    The API key is read from the ``GROQ_API_KEY`` environment variable.

    Returns:
        A dict with two ``ChatGroq`` instances:
        ``"rewrite_llm"`` (llama-3.3-70b-versatile, temperature 0.1) and
        ``"step_back_llm"`` (Gemma2-9B-IT, temperature 0).
    """
    groq_api_key = os.getenv("GROQ_API_KEY")

    return {
        "rewrite_llm": ChatGroq(
            temperature=0.1,
            model="llama-3.3-70b-versatile",
            api_key=groq_api_key,
        ),
        "step_back_llm": ChatGroq(
            temperature=0,
            model="Gemma2-9B-IT",
            api_key=groq_api_key,
        ),
    }


# Certification classification
def classify_certification(
    query: str,
    llm: ChatGroq,
    certs_dir: str = "docs/processed"
) -> str:
    """Classify which certification a query is referring to.

    Args:
        query: The user query to classify.
        llm: Chat model used for the classification.
        certs_dir: Not read by this function; the candidate names come from a
            fixed list. Kept so existing callers that pass it keep working.

    Returns:
        The model's answer with surrounding whitespace and quote characters
        removed. The prompt asks for a certification name or
        'no certification mentioned', but the reply is not validated.
    """
    answer = _run_prompt(
        _CERTIFICATION_TEMPLATE,
        llm,
        {
            "original_query": query,
            "available_certifications": ", ".join(_AVAILABLE_CERTIFICATIONS),
        },
    )
    # The prompt shows the fallback answer in quotes, so the model may quote
    # its reply; drop the wrapping so callers can compare against plain names.
    return answer.strip(_QUOTE_CHARS).strip()


# Query specificity classification
def classify_query_specificity(
    query: str,
    llm: ChatGroq
) -> Literal["specific", "general", "too narrow"]:
    """Classify query specificity.

    Args:
        query: The user query to classify.
        llm: Chat model used for the classification.

    Returns:
        One of 'specific', 'general', or 'too narrow' when the model's reply
        can be reduced to one of them. Otherwise the lower-cased first line
        of the reply is returned as-is.
    """
    reply = _run_prompt(_SPECIFICITY_TEMPLATE, llm, {"original_query": query})
    return _normalize_specificity(reply)  # type: ignore


# Query refinement
def refine_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Rewrite a query to be clearer while keeping the original intent.

    Args:
        query: The user query to rewrite.
        llm: Chat model used for the rewrite.

    Returns:
        The model's rewritten query with surrounding whitespace removed.
    """
    return _run_prompt(_REFINE_TEMPLATE, llm, {"original_query": query})


# Step-back query generation
def generate_step_back_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Generate a broader step-back query to retrieve background information.

    Args:
        query: The user query to broaden.
        llm: Chat model used to generate the step-back query.

    Returns:
        The model's step-back query with surrounding whitespace removed.
    """
    return _run_prompt(_STEP_BACK_TEMPLATE, llm, {"original_query": query})


# Main query processing pipeline
def process_query(
    original_query: str,
    llms: dict
) -> str:
    """Process a query through the full pipeline.

    1. Classify specificity with ``llms["rewrite_llm"]``.
    2. Apply the matching rewrite:
       'specific' / 'general' -> refinement with ``llms["rewrite_llm"]``;
       'too narrow' -> step-back query with ``llms["step_back_llm"]``.

    Args:
        original_query: The user query.
        llms: Mapping with the keys ``"rewrite_llm"`` and ``"step_back_llm"``,
            as returned by ``initialize_llms()``.

    Returns:
        The rewritten query, or ``original_query`` unchanged when the
        specificity label is not one of the three expected values.
    """
    specificity = classify_query_specificity(original_query, llms["rewrite_llm"])

    if specificity in ("specific", "general"):
        return refine_query(original_query, llms["rewrite_llm"])
    if specificity == "too narrow":
        return generate_step_back_query(original_query, llms["step_back_llm"])
    return original_query


# Test setup
def test_hydrogen_certification_functions():
    """Manual smoke test: run every pipeline stage on sample queries.

    Creates ``docs/processed`` with a few dummy certification folders, calls
    the live LLMs, and prints the results. Nothing is asserted; the expected
    specificity is printed next to the actual one for visual comparison.
    """
    # Initialize LLMs
    llms = initialize_llms()

    # Create a test directory with hydrogen certifications
    test_certs_dir = "docs/processed"
    os.makedirs(test_certs_dir, exist_ok=True)

    # Create some dummy certification folders
    hydrogen_certifications = [
        "GH2_Standard",
        "Certified_Hydrogen_Producer",
        "Green_Hydrogen_Certification",
        "ISO_19880_Hydrogen_Quality",
    ]
    for cert in hydrogen_certifications:
        os.makedirs(os.path.join(test_certs_dir, cert), exist_ok=True)

    # Test queries, each paired with the specificity label we expect
    test_queries = [
        ("What are the purity requirements in GH2 Standard?", "specific"),
        ("How does hydrogen certification work?", "general"),
        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
        ("What safety protocols exist for hydrogen storage?", "general"),
    ]

    print("=== Testing Certification Classification ===")
    for query, _ in test_queries:
        cert = classify_certification(query, llms["rewrite_llm"], test_certs_dir)
        print(f"Query: {query}\nClassification: {cert}\n")

    print("\n=== Testing Specificity Classification ===")
    for query, expected_type in test_queries:
        specificity = classify_query_specificity(query, llms["rewrite_llm"])
        print(f"Query: {query}\nExpected: {expected_type}, Got: {specificity}\n")

    print("\n=== Testing Full Query Processing ===")
    for query, _ in test_queries:
        processed = process_query(query, llms)
        print(f"Original: {query}\nProcessed: {processed}\n")


# Run the tests
if __name__ == "__main__":
    test_hydrogen_certification_functions()