Spaces:
Sleeping
Sleeping
Update pipeline.py
Browse files- pipeline.py +21 -20
pipeline.py
CHANGED
@@ -13,7 +13,7 @@ from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMMod
|
|
13 |
from pydantic import BaseModel, ValidationError, validator
|
14 |
from mistralai import Mistral
|
15 |
from langchain.prompts import PromptTemplate
|
16 |
-
|
17 |
# Import chains and tools
|
18 |
from classification_chain import get_classification_chain
|
19 |
from cleaner_chain import get_cleaner_chain
|
@@ -25,6 +25,13 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
|
|
25 |
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
|
26 |
client = Mistral(api_key=mistral_api_key)
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# Load spaCy model for NER and download it if not already installed
|
29 |
def install_spacy_model():
|
30 |
try:
|
@@ -131,25 +138,19 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
|
|
131 |
return vectorstore
|
132 |
|
133 |
# Function to build RAG chain
|
134 |
-
def build_rag_chain(
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
llm=gemini_as_llm,
|
148 |
-
chain_type="stuff",
|
149 |
-
retriever=retriever,
|
150 |
-
return_source_documents=True
|
151 |
-
)
|
152 |
-
return rag_chain
|
153 |
|
154 |
# Function to perform web search using DuckDuckGo
|
155 |
def do_web_search(query: str) -> str:
|
|
|
13 |
from pydantic import BaseModel, ValidationError, validator
|
14 |
from mistralai import Mistral
|
15 |
from langchain.prompts import PromptTemplate
|
16 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
17 |
# Import chains and tools
|
18 |
from classification_chain import get_classification_chain
|
19 |
from cleaner_chain import get_cleaner_chain
|
|
|
25 |
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
|
26 |
client = Mistral(api_key=mistral_api_key)
|
27 |
|
28 |
+
gemini_llm = ChatGoogleGenerativeAI(
|
29 |
+
model="gemini-1.5-pro",
|
30 |
+
temperature=0.5,
|
31 |
+
max_retries=2,
|
32 |
+
google_api_key=os.environ.get("GEMINI_API_KEY"),
|
33 |
+
# Additional parameters or safety_settings can be added here if needed
|
34 |
+
)
|
35 |
# Load spaCy model for NER and download it if not already installed
|
36 |
def install_spacy_model():
|
37 |
try:
|
|
|
138 |
return vectorstore
|
139 |
|
140 |
# Function to build RAG chain
|
141 |
+
def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
|
142 |
+
"""Build RAG chain using the Gemini LLM directly without a custom class."""
|
143 |
+
try:
|
144 |
+
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
|
145 |
+
chain = RetrievalQA.from_chain_type(
|
146 |
+
llm=gemini_llm, # Directly use the ChatGoogleGenerativeAI instance
|
147 |
+
chain_type="stuff",
|
148 |
+
retriever=retriever,
|
149 |
+
return_source_documents=True
|
150 |
+
)
|
151 |
+
return chain
|
152 |
+
except Exception as e:
|
153 |
+
raise RuntimeError(f"Error building RAG chain: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
# Function to perform web search using DuckDuckGo
|
156 |
def do_web_search(query: str) -> str:
|