ChatBotAgenticRAG1 / classification_chain.py
Phoenix21's picture
Update classification_chain.py
6c91d6e verified
raw
history blame contribute delete
810 Bytes
# classification_chain.py
import os

from langchain.chains import LLMChain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_groq import ChatGroq

from prompts import classification_prompt
def get_classification_chain(*, model: str = "Gemma2-9b-It") -> "Runnable":
    """Build the classification pipeline: prompt -> ChatGroq -> string parser.

    The result is an LCEL composition (``classification_prompt | ChatGroq |
    StrOutputParser()``), i.e. a ``RunnableSequence`` -- NOT a legacy
    ``LLMChain``. Call it with ``.invoke({...})``, passing the input
    variables that ``classification_prompt`` expects (presumably a
    ``ChatPromptTemplate`` defined in ``prompts`` -- confirm there). It
    returns the model's reply as a plain ``str``, because the last stage is
    ``StrOutputParser``.

    Args:
        model: Groq model identifier handed to ``ChatGroq``. Defaults to the
            model this module has always used, so existing callers are
            unaffected.

    Returns:
        The composed runnable.

    Raises:
        KeyError: If the ``GROQ_API_KEY`` environment variable is unset or
            empty.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        # Fail fast with an actionable message instead of a bare
        # KeyError('GROQ_API_KEY'). An empty value is treated as missing too,
        # rather than being forwarded to ChatGroq.
        raise KeyError(
            "GROQ_API_KEY environment variable is not set (or is empty); "
            "export it before building the classification chain."
        )

    chat_groq_model = ChatGroq(model=model, groq_api_key=api_key)

    # `|` composes runnables left to right: the prompt formats the inputs,
    # the chat model generates a message, and the parser extracts its text.
    return classification_prompt | chat_groq_model | StrOutputParser()