Phoenix21 committed
Commit c06a9ab · verified · Parent(s): 0b20500

Update pipeline.py

Files changed (1):
  1. pipeline.py +21 -20
pipeline.py CHANGED
@@ -13,7 +13,7 @@ from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 from pydantic import BaseModel, ValidationError, validator
 from mistralai import Mistral
 from langchain.prompts import PromptTemplate
-
+from langchain_google_genai import ChatGoogleGenerativeAI
 # Import chains and tools
 from classification_chain import get_classification_chain
 from cleaner_chain import get_cleaner_chain
@@ -25,6 +25,13 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
+gemini_llm = ChatGoogleGenerativeAI(
+    model="gemini-1.5-pro",
+    temperature=0.5,
+    max_retries=2,
+    google_api_key=os.environ.get("GEMINI_API_KEY"),
+    # Additional parameters or safety_settings can be added here if needed
+)
 # Load spaCy model for NER and download it if not already installed
 def install_spacy_model():
     try:
@@ -131,25 +138,19 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     return vectorstore
 
 # Function to build RAG chain
-def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
-    class GeminiLangChainLLM(LLM):
-        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
-            messages = [{"role": "user", "content": prompt}]
-            return llm_model(messages, stop_sequences=stop)
-
-        @property
-        def _llm_type(self) -> str:
-            return "custom_gemini"
-
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
-    gemini_as_llm = GeminiLangChainLLM()
-    rag_chain = RetrievalQA.from_chain_type(
-        llm=gemini_as_llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-    return rag_chain
+def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
+    """Build RAG chain using the Gemini LLM directly without a custom class."""
+    try:
+        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+        chain = RetrievalQA.from_chain_type(
+            llm=gemini_llm,  # Directly use the ChatGoogleGenerativeAI instance
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True
+        )
+        return chain
+    except Exception as e:
+        raise RuntimeError(f"Error building RAG chain: {str(e)}")
 
 # Function to perform web search using DuckDuckGo
 def do_web_search(query: str) -> str:
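
For context, this is roughly how the updated pieces fit together after the change: the module-level gemini_llm replaces the old GeminiLangChainLLM wrapper, and build_rag_chain now only needs the FAISS vectorstore. A minimal usage sketch, assuming GEMINI_API_KEY is set; the file paths and query text below are placeholders, not part of this commit:

# Hypothetical usage of the updated pipeline.py (paths and query are placeholders)
from pipeline import build_or_load_vectorstore, build_rag_chain

vectorstore = build_or_load_vectorstore("knowledge_base.csv", "faiss_store")  # hypothetical CSV and store dir
rag_chain = build_rag_chain(vectorstore)  # uses the module-level ChatGoogleGenerativeAI instance

response = rag_chain.invoke({"query": "example question"})
print(response["result"])                 # generated answer
print(len(response["source_documents"]))  # retrieved docs (return_source_documents=True)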