reab5555 commited on
Commit
57e1f9e
·
verified ·
1 Parent(s): a755f9a

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +10 -8
processing.py CHANGED
@@ -7,8 +7,9 @@ from config import openai_api_key
7
  from langchain.chains import RetrievalQA
8
  from langchain.prompts import PromptTemplate
9
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
10
- from typing import List
11
  from pydantic import Field
 
12
  import os
13
  import json
14
 
@@ -51,20 +52,21 @@ class CombinedRetriever(BaseRetriever):
51
  class Config:
52
  arbitrary_types_allowed = True
53
 
54
- def get_relevant_documents(self, query: str) -> List[Document]:
55
- return self.invoke(query)
56
-
57
- async def aget_relevant_documents(self, query: str) -> List[Document]:
58
  combined_docs = []
59
  for retriever in self.retrievers:
60
- docs = await retriever.aget_relevant_documents(query)
61
  combined_docs.extend(docs)
62
  return combined_docs
63
 
64
- def invoke(self, query: str) -> List[Document]:
 
 
65
  combined_docs = []
66
  for retriever in self.retrievers:
67
- docs = retriever.get_relevant_documents(query)
68
  combined_docs.extend(docs)
69
  return combined_docs
70
 
 
7
  from langchain.chains import RetrievalQA
8
  from langchain.prompts import PromptTemplate
9
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
10
+ from typing import List, Any, Optional
11
  from pydantic import Field
12
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
13
  import os
14
  import json
15
 
 
52
  class Config:
53
  arbitrary_types_allowed = True
54
 
55
+ def _get_relevant_documents(
56
+ self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
57
+ ) -> List[Document]:
 
58
  combined_docs = []
59
  for retriever in self.retrievers:
60
+ docs = retriever.get_relevant_documents(query, run_manager=run_manager)
61
  combined_docs.extend(docs)
62
  return combined_docs
63
 
64
+ async def _aget_relevant_documents(
65
+ self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
66
+ ) -> List[Document]:
67
  combined_docs = []
68
  for retriever in self.retrievers:
69
+ docs = await retriever.aget_relevant_documents(query, run_manager=run_manager)
70
  combined_docs.extend(docs)
71
  return combined_docs
72