import os
import ssl
import mimetypes
import urllib.request

import cohere
import weaviate
import gradio as gr
from dotenv import load_dotenv
from langchain.embeddings import CohereEmbeddings
from langchain.document_loaders import UnstructuredFileLoader
from langchain.vectorstores import Weaviate
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
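# Assumed dependency set for these import paths (pin as appropriate): gradio, a pre-0.1 langchain,
# weaviate-client 3.x, cohere 4.x, python-dotenv, and unstructured (for UnstructuredFileLoader).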
# Load environment variables
load_dotenv()
openai_api_key = os.getenv('OPENAI')
cohere_api_key = os.getenv('COHERE')
weaviate_api_key = os.getenv('WEAVIATE')
weaviate_url = os.getenv('WEAVIATE_URL')
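# Expected .env keys (illustrative values, not real credentials):
#   OPENAI=sk-...
#   COHERE=<cohere-api-key>
#   WEAVIATE=<weaviate-api-key>
#   WEAVIATE_URL=https://<cluster-name>.weaviate.network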
# Weaviate connection (the Cohere key is forwarded so Weaviate can call Cohere server-side)
auth_config = weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
client = weaviate.Client(url=weaviate_url, auth_client_secret=auth_config,
                         additional_headers={"X-Cohere-Api-Key": cohere_api_key})
# Initialize the LangChain Weaviate vectorstore over the existing "Articles" index
embeddings = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)
vectorstore = Weaviate(client, index_name="Articles", text_key="text", embedding=embeddings)
vectorstore._query_attrs = ["text", "title", "url", "views", "lang", "_additional {distance}"]

# Initialize the OpenAI LLM and the RetrievalQA chain (source documents are needed for reranking)
llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
qa = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever(), return_source_documents=True)
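# Note: the "Articles" class is assumed to already exist in the Weaviate schema with a "text"
# property (plus the title/url/views/lang properties referenced in _query_attrs above).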
def embed_pdf(file, collection_name):
    # Save a local copy of the uploaded file (Gradio passes a temporary file object)
    filename = os.path.basename(file.name)
    file_path = os.path.join('./', filename)
    with open(file.name, 'rb') as src, open(file_path, 'wb') as dst:
        dst.write(src.read())

    # Check the filetype before parsing the document
    mime_type = mimetypes.guess_type(file_path)[0]
    loader = UnstructuredFileLoader(file_path)
    docs = loader.load()

    # Generate Cohere embeddings for each parsed chunk
    embeddings = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)

    # Store each chunk and its vector in Weaviate
    for doc in docs:
        vector = embeddings.embed_documents([doc.page_content])[0]
        weaviate_document = {"text": doc.page_content}
        client.data_object.create(data_object=weaviate_document, class_name=collection_name, vector=vector)

    # Clean up the local copy
    os.remove(file_path)
    return f"Documents embedded in Weaviate collection '{collection_name}'"
def embed_pdf_from_url(file_url, collection_name):
    # Variant of embed_pdf that ingests a file by URL; not wired into the Gradio UI below
    # Download the file into the working directory
    folder_path = './'
    os.makedirs(folder_path, exist_ok=True)
    filename = file_url.split('/')[-1]
    file_path = os.path.join(folder_path, filename)
    ssl._create_default_https_context = ssl._create_unverified_context  # disable certificate verification for the download
    urllib.request.urlretrieve(file_url, file_path)

    # Check the filetype before parsing the document
    mime_type = mimetypes.guess_type(file_path)[0]
    loader = UnstructuredFileLoader(file_path)
    docs = loader.load()

    # Generate Cohere embeddings for each parsed chunk
    embeddings = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)

    # Store each chunk and its vector in Weaviate
    for doc in docs:
        vector = embeddings.embed_documents([doc.page_content])[0]
        weaviate_document = {"text": doc.page_content}
        client.data_object.create(data_object=weaviate_document, class_name=collection_name, vector=vector)

    # Clean up the downloaded file
    os.remove(file_path)
    return f"Documents embedded in Weaviate collection '{collection_name}'"
# Initialize Cohere client
co = cohere.Client(api_key=cohere_api_key)
def retrieve_info(query):
    # Run the RetrievalQA chain to fetch an answer plus the supporting source documents
    initial_results = qa({"query": query})

    # Take up to 25 of the retrieved source documents as rerank candidates
    # (raise k via vectorstore.as_retriever(search_kwargs={"k": ...}) for a larger candidate pool)
    top_docs = [doc.page_content for doc in initial_results["source_documents"][:25]]

    # Rerank the candidates with Cohere and keep the best 3
    reranked_results = co.rerank(query=query, documents=top_docs, top_n=3, model='rerank-english-v2.0')

    # Format the reranked results
    formatted_results = []
    for idx, r in enumerate(reranked_results):
        formatted_result = {
            "Document Rank": idx + 1,
            "Document Index": r.index,
            "Document": r.document['text'],
            "Relevance Score": f"{r.relevance_score:.2f}"
        }
        formatted_results.append(formatted_result)
    return {"results": formatted_results}
# Gradio interface for querying the indexed documents
search_interface = gr.Interface(
    fn=retrieve_info,
    inputs=gr.Textbox(label="Query"),
    outputs="text",
    allow_flagging="never"
)

# Gradio interface for embedding an uploaded PDF into a named collection
embed_interface = gr.Interface(
    fn=embed_pdf,
    inputs=[
        gr.File(label="PDF File"),
        gr.Textbox(label="Collection Name")
    ],
    outputs="text",
    allow_flagging="never"
)

# Expose both functions as tabs of a single app
demo = gr.TabbedInterface([search_interface, embed_interface], ["Search", "Embed PDF"])
demo.launch()