# classification_chain.py
import os

from langchain.chains import LLMChain
from langchain_groq import ChatGroq

# We'll import the classification_prompt from prompts.py
from prompts import classification_prompt


def get_classification_chain() -> LLMChain:
    """
    Builds the classification chain (LLMChain) using ChatGroq and the classification prompt.
    """
    # Initialize the ChatGroq model (Gemma2-9b-It) with your GROQ_API_KEY
    chat_groq_model = ChatGroq(
        model="Gemma2-9b-It",
        groq_api_key=os.environ["GROQ_API_KEY"],  # must be set in the environment
    )

    # Build an LLMChain
    classification_chain = LLMChain(
        llm=chat_groq_model,
        prompt=classification_prompt,
    )
    return classification_chain
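

# --- Usage sketch (illustrative only) ---
# A minimal example of invoking the chain built above. It assumes that
# classification_prompt (defined in prompts.py) declares a single input
# variable named "query"; adjust the key to whatever variable name the
# prompt actually uses, and make sure GROQ_API_KEY is set beforehand.
if __name__ == "__main__":
    chain = get_classification_chain()
    # LLMChain.invoke returns a dict with the prompt inputs plus a "text" key
    # holding the model's response.
    result = chain.invoke({"query": "How do I reset my password?"})
    print(result["text"])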