πwπ
Browse files- README.md +1 -1
- app.py +161 -0
- requirements.txt +6 -0
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
---
|
| 2 |
title: RAG
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
|
|
|
| 1 |
---
|
| 2 |
title: RAG
|
| 3 |
+
emoji: πwπ
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
app.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
from sentence_transformers.quantization import quantize_embeddings
|
| 5 |
+
import faiss
|
| 6 |
+
from usearch.index import Index
|
| 7 |
+
import os
|
| 8 |
+
import spaces
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 10 |
+
import torch
|
| 11 |
+
from threading import Thread
|
| 12 |
+
|
| 13 |
+
token = os.environ["HF_TOKEN"]
|
| 14 |
+
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",
|
| 15 |
+
# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 16 |
+
torch_dtype=torch.float16,
|
| 17 |
+
token=token)
|
| 18 |
+
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it",token=token)
|
| 19 |
+
device = torch.device('cuda')
|
| 20 |
+
model = model.to(device)
|
| 21 |
+
|
| 22 |
+
# Load titles and texts
|
| 23 |
+
title_text_dataset = load_dataset(
|
| 24 |
+
"mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
|
| 25 |
+
).select_columns(["title", "text"])
|
| 26 |
+
|
| 27 |
+
# Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
|
| 28 |
+
int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
|
| 29 |
+
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
|
| 30 |
+
"wikipedia_ubinary_faiss_50m.index"
|
| 31 |
+
)
|
| 32 |
+
binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary(
|
| 33 |
+
"wikipedia_ubinary_ivf_faiss_50m.index"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Load the SentenceTransformer model for embedding the queries
|
| 37 |
+
model = SentenceTransformer(
|
| 38 |
+
"mixedbread-ai/mxbai-embed-large-v1",
|
| 39 |
+
prompts={
|
| 40 |
+
"retrieval": "Represent this sentence for searching relevant passages: ",
|
| 41 |
+
},
|
| 42 |
+
default_prompt_name="retrieval",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def search(
|
| 47 |
+
query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
|
| 48 |
+
):
|
| 49 |
+
# 1. Embed the query as float32
|
| 50 |
+
query_embedding = model.encode(query)
|
| 51 |
+
|
| 52 |
+
# 2. Quantize the query to ubinary
|
| 53 |
+
query_embedding_ubinary = quantize_embeddings(
|
| 54 |
+
query_embedding.reshape(1, -1), "ubinary"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# 3. Search the binary index (either exact or approximate)
|
| 58 |
+
index = binary_ivf if use_approx else binary_index
|
| 59 |
+
_scores, binary_ids = index.search(
|
| 60 |
+
query_embedding_ubinary, top_k * rescore_multiplier
|
| 61 |
+
)
|
| 62 |
+
binary_ids = binary_ids[0]
|
| 63 |
+
|
| 64 |
+
# 4. Load the corresponding int8 embeddings
|
| 65 |
+
int8_embeddings = int8_view[binary_ids].astype(int)
|
| 66 |
+
|
| 67 |
+
# 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
|
| 68 |
+
scores = query_embedding @ int8_embeddings.T
|
| 69 |
+
|
| 70 |
+
# 6. Sort the scores and return the top_k
|
| 71 |
+
indices = scores.argsort()[::-1][:top_k]
|
| 72 |
+
top_k_indices = binary_ids[indices]
|
| 73 |
+
top_k_scores = scores[indices]
|
| 74 |
+
top_k_titles, top_k_texts = zip(
|
| 75 |
+
*[
|
| 76 |
+
(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"])
|
| 77 |
+
for idx in top_k_indices.tolist()
|
| 78 |
+
]
|
| 79 |
+
)
|
| 80 |
+
df = {
|
| 81 |
+
"Score": [round(value, 2) for value in top_k_scores],
|
| 82 |
+
"Title": top_k_titles,
|
| 83 |
+
"Text": top_k_texts,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
return df
|
| 87 |
+
|
| 88 |
+
def prepare_prompt(query, df):
|
| 89 |
+
prompt = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
|
| 90 |
+
for data in df :
|
| 91 |
+
title = data["Title"]
|
| 92 |
+
text = data["Text"]
|
| 93 |
+
prompt+=f"Title: {title}, Text: {text}\n"
|
| 94 |
+
return prompt
|
| 95 |
+
|
| 96 |
+
@spaces.GPU
|
| 97 |
+
def talk(message, history):
|
| 98 |
+
df = search(message)
|
| 99 |
+
message = prepare_prompt(message,df)
|
| 100 |
+
resources = "\nRESOURCES:\n"
|
| 101 |
+
for title in df["Title"][:3] :
|
| 102 |
+
resources+=f"[{title}](https://huggingface.co/spaces/not-lain/RAG), "
|
| 103 |
+
chat = []
|
| 104 |
+
for item in history:
|
| 105 |
+
chat.append({"role": "user", "content": item[0]})
|
| 106 |
+
if item[1] is not None:
|
| 107 |
+
cleaned_past = item[1].split("\nRESOURCES:\n")[0]
|
| 108 |
+
chat.append({"role": "assistant", "content": cleaned_past})
|
| 109 |
+
chat.append({"role": "user", "content": message})
|
| 110 |
+
messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
| 111 |
+
# Tokenize the messages string
|
| 112 |
+
model_inputs = tok([messages], return_tensors="pt").to(device)
|
| 113 |
+
streamer = TextIteratorStreamer(
|
| 114 |
+
tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
|
| 115 |
+
generate_kwargs = dict(
|
| 116 |
+
model_inputs,
|
| 117 |
+
streamer=streamer,
|
| 118 |
+
max_new_tokens=1024,
|
| 119 |
+
do_sample=True,
|
| 120 |
+
top_p=0.95,
|
| 121 |
+
top_k=1000,
|
| 122 |
+
temperature=0.75,
|
| 123 |
+
num_beams=1,
|
| 124 |
+
)
|
| 125 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
| 126 |
+
t.start()
|
| 127 |
+
|
| 128 |
+
# Initialize an empty string to store the generated text
|
| 129 |
+
partial_text = ""
|
| 130 |
+
for new_text in streamer:
|
| 131 |
+
partial_text += new_text
|
| 132 |
+
yield partial_text
|
| 133 |
+
partial_text+= resources
|
| 134 |
+
yield partial_text
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
TITLE = "RAG"
|
| 141 |
+
|
| 142 |
+
DESCRIPTION = """
|
| 143 |
+
## Resources used to build this project
|
| 144 |
+
* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
|
| 145 |
+
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
|
| 146 |
+
## Retrival paramaters
|
| 147 |
+
```python
|
| 148 |
+
top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
|
| 149 |
+
```
|
| 150 |
+
## Models
|
| 151 |
+
the models used in this space are :
|
| 152 |
+
* google/gemma-7b-it
|
| 153 |
+
* mixedbread-ai/wikipedia-data-en-2023-11
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
demo = gr.ChatInterface(fn=talk,
|
| 157 |
+
chatbot=gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, likeable=True, layout="bubble", bubble_full_width=False),
|
| 158 |
+
theme="Soft",
|
| 159 |
+
examples=[["Write me a poem about Machine Learning."]],
|
| 160 |
+
title="Text Streaming")
|
| 161 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spaces
|
| 2 |
+
torch==2.2.0
|
| 3 |
+
git+https://github.com/huggingface/transformers/
|
| 4 |
+
git+https://github.com/tomaarsen/sentence-transformers@feat/quantization
|
| 5 |
+
usearch
|
| 6 |
+
faiss-cpu
|