Spaces:
Runtime error
Runtime error
File size: 4,519 Bytes
5f625b7 91bb6b8 5f625b7 91bb6b8 5f625b7 91bb6b8 5f625b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes.ranker import BaseRanker
from haystack.pipelines import Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.schema import Document
from typing import Optional, List
import gradio as gr
import numpy as np
import requests
import os
# Endpoint URLs and the bearer token are read from the environment once at
# import time; each name is None when its variable is not set.
RETRIEVER_URL, RANKER_URL, HF_TOKEN = (
    os.environ.get(variable)
    for variable in ("RETRIEVER_URL", "RANKER_URL", "HF_TOKEN")
)
class Retriever(EmbeddingRetriever):
    """Embedding retriever that obtains embeddings from a remote HTTP endpoint.

    Queries and documents are POSTed as JSON to ``RETRIEVER_URL`` with
    ``HF_TOKEN`` sent as a bearer token; the JSON reply is converted to a
    NumPy array.
    """

    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
        """Store the retrieval settings on the instance.

        Args:
            document_store: Store the retriever reads documents from.
            top_k: Stored as ``self.top_k``.
            batch_size: Stored as ``self.batch_size``.
            scale_score: Stored as ``self.scale_score``.

        NOTE(review): ``EmbeddingRetriever.__init__`` is not called here,
        presumably to avoid loading a local embedding model. Confirm that no
        attribute normally set by the base constructor is needed later.
        """
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    @staticmethod
    def _request_embeddings(payload: dict) -> np.ndarray:
        """POST *payload* to ``RETRIEVER_URL`` and return the reply as an array.

        Raises:
            RuntimeError: If the endpoint replies with an ``{"error": ...}``
                object instead of embeddings.
        """
        reply = requests.post(
            RETRIEVER_URL,
            json=payload,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        # Surface endpoint failures here instead of wrapping the error object
        # in an array and failing later in a confusing place.
        if isinstance(reply, dict) and "error" in reply:
            raise RuntimeError(reply["error"])
        return np.array(reply)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Return the embeddings the remote endpoint produces for *queries*.

        Raises:
            RuntimeError: If the endpoint reports an error.
        """
        # The empty "inputs" field is sent with every request; presumably the
        # endpoint's request schema requires it (TODO confirm).
        return self._request_embeddings({"queries": queries, "inputs": ""})

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Return the embeddings the remote endpoint produces for *documents*.

        Raises:
            RuntimeError: If the endpoint reports an error.
        """
        serialised = []
        for document in documents:
            as_dict = document.to_dict()
            embedding = as_dict.get("embedding")
            # NumPy arrays are not JSON-serialisable; send plain lists.
            if embedding is not None:
                as_dict["embedding"] = np.asarray(embedding).tolist()
            serialised.append(as_dict)
        return self._request_embeddings({"documents": serialised, "inputs": ""})
class Ranker(BaseRanker):
    """Ranker that delegates re-ranking to the HTTP endpoint at ``RANKER_URL``.

    Documents are sent as JSON dictionaries and the reply is turned back into
    ``Document`` objects with ``Document.from_dict``.
    """

    @staticmethod
    def _serialise(document: Document) -> dict:
        """Return *document* as a JSON-serialisable dictionary."""
        as_dict = document.to_dict()
        embedding = as_dict.get("embedding")
        # A document may carry no embedding; only convert when one is present.
        if embedding is not None:
            as_dict["embedding"] = np.asarray(embedding).tolist()
        return as_dict

    @staticmethod
    def _request_ranking(payload: dict):
        """POST *payload* to ``RANKER_URL`` and return the decoded JSON reply.

        Raises:
            RuntimeError: If the endpoint replies with an ``{"error": ...}``
                object.
        """
        reply = requests.post(
            RANKER_URL,
            json=payload,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        # A successful reply is a list; only a dict can be an error object.
        if isinstance(reply, dict) and "error" in reply:
            raise RuntimeError(reply["error"])
        return reply

    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """Rank *documents* for a single *query* via the remote endpoint.

        Args:
            query: Query text forwarded to the endpoint.
            documents: Candidate documents to rank.
            top_k: Forwarded to the endpoint unchanged (may be ``None``).

        Returns:
            The documents rebuilt from the endpoint's reply.

        Raises:
            RuntimeError: If the endpoint reports an error.
        """
        reply = self._request_ranking(
            {
                "query": query,
                "documents": [self._serialise(d) for d in documents],
                "top_k": top_k,
                "inputs": "",
            }
        )
        return [Document.from_dict(d) for d in reply]

    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
        """Rank several document lists via the remote endpoint in one request.

        Args:
            queries: Query texts forwarded to the endpoint.
            documents: One list of candidate documents per entry.
            batch_size: Forwarded to the endpoint unchanged.
            top_k: Forwarded to the endpoint unchanged.

        Returns:
            The nested document lists rebuilt from the endpoint's reply.

        Raises:
            RuntimeError: If the endpoint reports an error.
        """
        # Goes through the same helper as predict(), so the Authorization
        # header is sent here as well.
        reply = self._request_ranking(
            {
                "queries": queries,
                "documents": [
                    [self._serialise(d) for d in docs] for docs in documents
                ],
                "batch_size": batch_size,
                "top_k": top_k,
                "inputs": "",
            }
        )
        return [[Document.from_dict(d) for d in docs] for docs in reply]
# Retrieval settings passed to the Retriever below; TOP_K also caps how many
# documents run() reports.
TOP_K = 2
BATCH_SIZE = 16

# Small demo corpus written to the document store at start-up.
EXAMPLES = [
    "There is a blue house on Oxford Street.",
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "The Louvre is in Paris.",
    "London is the capital of England.",
    "Cairo is the capital of Egypt.",
    "The pyramids are in Egypt.",
    "The Sphinx is in Egypt.",
]

# Remove any database file left by a previous run. Attempting the removal and
# ignoring "not found" avoids the gap between an existence check and the
# delete. NOTE(review): presumably this is the SQLite file FAISSDocumentStore
# creates by default - confirm.
try:
    os.remove("faiss_document_store.db")
except FileNotFoundError:
    pass

# NOTE(review): embedding_dim presumably has to match the vector size returned
# by the remote retriever endpoint - confirm.
document_store = FAISSDocumentStore(embedding_dim=384, return_embedding=True)
document_store.write_documents(
    [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
)

# Embed the stored documents through the remote retriever.
retriever = Retriever(document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE)
document_store.update_embeddings(retriever=retriever)
ranker = Ranker()

# Query -> Retriever -> Ranker
pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
def run(query: str) -> str:
    """Run *query* through the pipeline and describe the top documents.

    Args:
        query: Text passed to ``pipe.run`` as the query.

    Returns:
        A sentence listing the ``content`` of at most ``TOP_K`` documents
        from the pipeline output.
    """
    output = pipe.run(query=query)
    # Slice instead of indexing range(TOP_K): if the pipeline returns fewer
    # than TOP_K documents, indexing would raise IndexError.
    contents = [doc.content for doc in output["documents"][:TOP_K]]
    return f"Closest document(s): {contents}"
# Warm up: send one query through the pipeline before the UI starts.
run("What is the capital of France?")

# Text-in / text-out Gradio front end around run().
demo = gr.Interface(
    fn=run,
    inputs="text",
    outputs="text",
    title="Pipeline",
    description="A pipeline for retrieving and ranking documents.",
    examples=["What is the capital of France?"],
)
demo.launch()
|