Spaces:
Runtime error
Runtime error
File size: 5,285 Bytes
5f625b7 91bb6b8 5f625b7 699d13a 5f625b7 699d13a 2baebe5 699d13a 6407a06 699d13a 6407a06 2baebe5 6407a06 699d13a 2baebe5 699d13a 6407a06 2baebe5 5f625b7 699d13a 5f625b7 91bb6b8 5f625b7 6407a06 91bb6b8 5f625b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes.ranker import BaseRanker
from haystack.pipelines import Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.schema import Document
from typing import Optional, List
import gradio as gr
import numpy as np
import requests
import os
# Remote endpoints and credentials are read from the environment once, at
# import time. Each value is None when the corresponding variable is unset.
_env = os.environ
RETRIEVER_URL = _env.get("RETRIEVER_URL")  # URL the Retriever posts to
RANKER_URL = _env.get("RANKER_URL")  # URL the Ranker posts to
HF_TOKEN = _env.get("HF_TOKEN")  # sent as the bearer token on requests
class Retriever(EmbeddingRetriever):
    """Embedding retriever that delegates embedding to a remote HTTP endpoint.

    Queries and documents are POSTed as JSON to ``RETRIEVER_URL`` with
    ``HF_TOKEN`` as bearer token; the JSON reply is converted with
    ``np.array``.
    """

    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
        """Store the retrieval settings on the instance.

        NOTE(review): ``EmbeddingRetriever.__init__`` is intentionally not
        called here -- presumably to avoid loading a local embedding model.
        Confirm that nothing the parent initialiser sets is required.
        """
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    @staticmethod
    def _request_embeddings(payload: dict) -> np.ndarray:
        """POST *payload* to the retriever endpoint and return the embeddings.

        Raises:
            Exception: if the endpoint replies with an ``{"error": ...}``
                object instead of embeddings.
        """
        response = requests.post(
            RETRIEVER_URL,
            json=payload,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        )
        result = response.json()
        # Surface endpoint failures here (the same convention Ranker uses)
        # rather than wrapping an error object in an array and failing later.
        if isinstance(result, dict) and "error" in result:
            raise Exception(result["error"])
        return np.array(result)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Return the endpoint's embeddings for *queries* as an array."""
        # The empty "inputs" field is sent alongside the real payload;
        # presumably the hosting endpoint requires that key -- TODO confirm.
        return self._request_embeddings({"queries": queries, "inputs": ""})

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Return the endpoint's embeddings for *documents* as an array.

        Each document is sent in its ``Document.to_dict()`` form.
        """
        return self._request_embeddings(
            {"documents": [d.to_dict() for d in documents], "inputs": ""}
        )
class Ranker(BaseRanker):
    """Ranker that delegates re-ranking to the remote ``RANKER_URL`` endpoint.

    Documents are sent as JSON dictionaries and the endpoint's reply is
    turned back into ``Document`` objects with ``Document.from_dict``.
    """

    @staticmethod
    def _serialize(documents: List[Document]) -> List[dict]:
        """Convert *documents* to JSON-serialisable dictionaries.

        Embeddings exposing ``tolist()`` (e.g. NumPy arrays) are converted to
        plain lists. Documents whose embedding is ``None`` are passed through
        unchanged instead of raising ``AttributeError``.
        """
        serialized = []
        for document in documents:
            payload = document.to_dict()
            embedding = payload.get("embedding")
            if hasattr(embedding, "tolist"):
                payload["embedding"] = embedding.tolist()
            serialized.append(payload)
        return serialized

    @staticmethod
    def _post(payload: dict):
        """POST *payload* to the ranker endpoint and return the decoded JSON.

        Raises:
            Exception: if the endpoint replies with an ``{"error": ...}``
                object.
        """
        response = requests.post(
            RANKER_URL,
            json=payload,
            # Sent on every call: predict_batch used to omit this header,
            # unlike predict.
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        if isinstance(response, dict) and "error" in response:
            raise Exception(response["error"])
        return response

    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """Return the endpoint's ranking of *documents* for a single *query*.

        Raises:
            Exception: if the endpoint reports an error.
        """
        ranked = self._post(
            {
                "query": query,
                "documents": self._serialize(documents),
                "top_k": top_k,
                "inputs": "",
            }
        )
        return [Document.from_dict(d) for d in ranked]

    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
        """Return the endpoint's ranking for several queries in one request.

        *documents* holds one list of candidate documents per entry; the
        reply is decoded as one list of ``Document`` objects per entry.

        Raises:
            Exception: if the endpoint reports an error.
        """
        ranked = self._post(
            {
                "queries": queries,
                "documents": [self._serialize(docs) for docs in documents],
                "batch_size": batch_size,
                "top_k": top_k,
                "inputs": "",
            }
        )
        return [[Document.from_dict(d) for d in docs] for docs in ranked]
TOP_K = 2
BATCH_SIZE = 16
EXAMPLES = [
    "There is a blue house on Oxford Street.",
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "The Louvre is in Paris.",
    "London is the capital of England.",
    "Cairo is the capital of Egypt.",
    "The pyramids are in Egypt.",
    "The Sphinx is in Egypt.",
]

# All index artefacts live under the absolute /data directory. The same
# absolute paths are used for the existence check, load and save; the script
# previously checked "/data/..." but loaded from and saved to "./data/...".
FAISS_INDEX_PATH = "/data/faiss_index"
FAISS_CONFIG_PATH = "/data/faiss_index.json"
FAISS_DB_PATH = "/data/faiss_document_store.db"
FAISS_FILES = (FAISS_INDEX_PATH, FAISS_CONFIG_PATH, FAISS_DB_PATH)

if all(os.path.exists(path) for path in FAISS_FILES):
    # A complete index is already on disk: reuse it.
    document_store = FAISSDocumentStore.load(FAISS_INDEX_PATH)
else:
    # Incomplete or missing index: delete whatever is left and rebuild.
    # Each file is removed independently so one missing file does not leave
    # the remaining stale files in place.
    for stale_path in FAISS_FILES:
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            pass
    document_store = FAISSDocumentStore(
        # "sqlite:///" + absolute path -> "sqlite:////data/faiss_document_store.db"
        sql_url=f"sqlite:///{FAISS_DB_PATH}",
        return_embedding=True,
        embedding_dim=384,
    )
    document_store.write_documents(
        [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
    )

# Both paths (re)compute the embeddings through the remote retriever and
# persist the resulting index.
retriever = Retriever(
    document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
)
document_store.update_embeddings(retriever=retriever)
document_store.save(index_path=FAISS_INDEX_PATH)

# Query -> Retriever -> Ranker.
ranker = Ranker()
pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
def run(query: str) -> str:
    """Run *query* through the pipeline and describe the returned documents.

    Returns:
        A human-readable string listing the ``content`` of every document in
        the pipeline output (the function returns text, not a dict).
    """
    output = pipe.run(query=query)
    closest_documents = [d.content for d in output["documents"]]
    return f"Closest ({TOP_K}) document(s): {closest_documents}"
# Send one query through the pipeline at startup; the message below is only
# printed if that call completes without raising.
WARMUP_QUERY = "What is the capital of France?"
run(WARMUP_QUERY)
print("Warmed up successfully!")

# Text-in / text-out web UI around run(); the warm-up query doubles as the
# clickable example.
demo = gr.Interface(
    fn=run,
    inputs="text",
    outputs="text",
    title="Pipeline",
    examples=[WARMUP_QUERY],
    description="A pipeline for retrieving and ranking documents.",
)
demo.launch()
|