arxiv-rag-mvp / retrieval.py
donb-hf's picture
initial commit
8c3a73e
raw
history blame
897 Bytes
# File: retrieval.py
import functools

from langchain.chains import RetrievalQA
from langchain_groq import ChatGroq
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import Qdrant

from config import *
# Module-level clients shared by every call to rag_query().
# NOTE(review): both are created at import time and presumably read their API
# keys from the environment prepared by `config` — confirm.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# Groq has no "llama3-70b-4096" model (that suffix belonged to the old
# llama2-70b-4096); Llama 3 70B is served as "llama3-70b-8192", where the
# suffix is the context window. Re-check against Groq's current model list,
# since model IDs are periodically retired.
llm = ChatGroq(model="llama3-70b-8192", temperature=0.3)
@functools.lru_cache(maxsize=1)
def _get_qa_chain() -> RetrievalQA:
    """Build the Qdrant-backed RetrievalQA chain once and reuse it.

    ``Qdrant.from_existing_collection`` constructs a new Qdrant client, so
    building the chain per query repeats connection setup for no benefit.
    The chain is memoised here instead. If building raises, nothing is
    cached and the next call retries.
    """
    vector_store = Qdrant.from_existing_collection(
        embedding=embeddings,
        collection_name=COLLECTION_NAME,
        url=QDRANT_API_URL,
        api_key=QDRANT_API_KEY,
        prefer_grpc=True,
    )
    # Fetch the 5 most similar chunks for each query.
    retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    # "stuff" places all retrieved documents into a single LLM prompt.
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )


def rag_query(query: str) -> str:
    """Answer *query* with retrieval-augmented generation over Qdrant.

    Args:
        query: Natural-language question to answer.

    Returns:
        The generated answer text (the chain's ``"result"`` output).
    """
    # Chain.__call__ is deprecated in LangChain; invoke() is the supported API.
    result = _get_qa_chain().invoke({"query": query})
    return result["result"]