TheDavidYoungblood committed · Commit fb75b53 · 1 Parent(s): 8e70e09
99 additions of files in the repo, 99 additions of files...

Files changed:
- FAISS-index.py +18 -20
- RAGbot.py +67 -36
- requirements.txt +9 -1
FAISS-index.py CHANGED

@@ -1,27 +1,25 @@
 from datasets import Dataset, load_from_disk
 import faiss
 import numpy as np
-from transformers import
+from transformers import RagTokenizer, RagSequenceForGeneration
 
-dataset.save_to_disk(dataset_path)
-passages = dataset["text"]
-tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
-passage_embeddings = model.get_encoder()(
-    tokenizer(passages, return_tensors="pt", padding=True, truncation=True)
-).last_hidden_state.mean(dim=1).detach().numpy()
-index_path = "path/to/your/index"
-faiss.write_index(index, index_path)
+def create_and_save_faiss_index(dataset_path, index_path):
+    dataset = load_from_disk(dataset_path)
+    passages = dataset["text"]
+
+    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
+
+    passage_embeddings = model.get_encoder()(
+        tokenizer(passages, return_tensors="pt", padding=True, truncation=True)
+    ).last_hidden_state.mean(dim=1).detach().numpy()
+
+    index = faiss.IndexFlatL2(passage_embeddings.shape[1])
+    index.add(passage_embeddings)
+
+    faiss.write_index(index, index_path)
+
+if __name__ == "__main__":
+    dataset_path = "path/to/your/hf_dataset"
+    index_path = "path/to/your/hf_index"
+    create_and_save_faiss_index(dataset_path, index_path)
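
For reference, a minimal sketch of the read path for the artifacts this script writes, mirroring the embedding call above. The function name query_faiss_index and k=5 are illustrative, not part of this commit:

import faiss
from datasets import load_from_disk
from transformers import RagTokenizer, RagSequenceForGeneration

def query_faiss_index(query, dataset_path, index_path, k=5):
    # Load the passages and the index written by create_and_save_faiss_index.
    dataset = load_from_disk(dataset_path)
    index = faiss.read_index(index_path)

    # Embed the query with the same mean-pooled encoder states used for the
    # passages, so it lives in the same space as the IndexFlatL2 entries.
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
    query_embedding = model.get_encoder()(
        **tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    ).last_hidden_state.mean(dim=1).detach().numpy()

    # search() returns L2 distances and row ids into the dataset.
    distances, indices = index.search(query_embedding, k)
    return [dataset[int(i)]["text"] for i in indices[0]]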
RAGbot.py CHANGED

@@ -10,7 +10,12 @@ from langchain.document_loaders import PyPDFLoader
 from langchain.prompts import PromptTemplate
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import spaces
-from langchain_text_splitters import
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
+from datasets import Dataset, load_from_disk
+import faiss
+import numpy as np
+from pastebin_api import get_protected_content
 
 class RAGbot:
     def __init__(self, config_path="config.yaml"):
@@ -20,7 +25,8 @@ class RAGbot:
         self.prompt = None
         self.documents = None
         self.embeddings = None
-        self.
+        self.zilliz_vectordb = None
+        self.hf_vectordb = None
         self.tokenizer = None
         self.model = None
         self.pipeline = None
@@ -38,22 +44,26 @@ class RAGbot:
         self.model_embeddings = config["modelEmbeddings"]
         self.auto_tokenizer = config["autoTokenizer"]
         self.auto_model_for_causal_lm = config["autoModelForCausalLM"]
+        self.zilliz_config = config["zilliz"]
+        self.persona_paste_key = config["personaPasteKey"]
+
+    def connect_to_zilliz(self):
+        connections.connect(
+            host=self.zilliz_config["host"],
+            port=self.zilliz_config["port"],
+            user=self.zilliz_config["user"],
+            password=self.zilliz_config["password"],
+            secure=True
+        )
+        self.zilliz_vectordb = Collection(self.zilliz_config["collection"])
 
     def load_embeddings(self):
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_embeddings)
-            chunk_size=self.chunk_size,
-            chunk_overlap=overlap,
-            length_function=len,
-            add_start_index=True,
-        )
-        docs = text_splitter.split_documents(self.documents)
-        self.vectordb = Chroma.from_documents(docs, self.embeddings)
-        print("Vector store created")
+
+    def load_hf_vectordb(self, dataset_path, index_path):
+        dataset = load_from_disk(dataset_path)
+        index = faiss.read_index(index_path)
+        self.hf_vectordb = (dataset, index)
 
     @spaces.GPU
     def load_tokenizer(self):
@@ -67,20 +77,34 @@ class RAGbot:
             model_kwargs={"torch_dtype": torch.bfloat16},
             device="cuda",
         )
-        print("Model pipeline loaded")
 
-    def get_organic_context(self, query):
+    def get_organic_context(self, query, use_hf=False):
+        if use_hf:
+            dataset, index = self.hf_vectordb
+            D, I = index.search(np.array([self.embeddings.embed_query(query)]), self.max_chunks_in_context)
+            context = self.format_seperator.join([dataset[i] for i in I[0]])
+        else:
+            result = self.zilliz_vectordb.search(
+                data=[self.embeddings.embed_query(query)],
+                anns_field="embeddings",
+                param={"metric_type": "IP", "params": {"nprobe": 10}},
+                limit=self.max_chunks_in_context,
+                expr=None,
+            )
+            context = self.format_seperator.join([hit.entity.get('text') for hit in result[0]])
+
         self.current_context = context
+
+    def load_persona_data(self):
+        persona_content = get_protected_content(self.persona_paste_key)
+        persona_data = yaml.safe_load(persona_content)
+        self.persona_text = persona_data["persona_text"]
 
     @spaces.GPU
-    def create_organic_response(self, history, query):
-        self.get_organic_context(query)
+    def create_organic_response(self, history, query, use_hf=False):
+        self.get_organic_context(query, use_hf=use_hf)
         messages = [
-            {"role": "system", "content": "
+            {"role": "system", "content": f"Based on the given context, answer the user's question while maintaining the persona:\n{self.persona_text}\n\nContext:\n{self.current_context}"},
             {"role": "user", "content": query},
         ]
@@ -97,17 +121,15 @@ class RAGbot:
             temperature=temp,
             top_p=0.9,
         )
-        print(outputs)
         return outputs[0]["generated_text"][len(prompt):]
 
     def process_file(self, file):
         self.documents = PyPDFLoader(file.name).load()
         self.load_embeddings()
-        self.
-        self.create_organic_pipeline()
+        self.connect_to_zilliz()
 
     @spaces.GPU
-    def generate_response(self, history, query, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
+    def generate_response(self, history, query, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context, use_hf_index=False, hf_dataset_path=None, hf_index_path=None):
         self.chunk_size = chunk_size
         self.overlap_percentage = chunk_overlap_percentage
         self.model_temperatue = model_temperature
@@ -115,19 +137,28 @@ class RAGbot:
 
         if not query:
             raise gr.Error(message='Submit a question')
-        if not file:
-            raise gr.Error(message='Upload a PDF')
-        if not self.processed:
-            self.process_file(file)
-            self.processed = True
 
+        if use_hf_index:
+            if not hf_dataset_path or not hf_index_path:
+                raise gr.Error(message='Provide HuggingFace dataset and index paths')
+            self.load_hf_vectordb(hf_dataset_path, hf_index_path)
+            result = self.create_organic_response(history="", query=query, use_hf=True)
+        else:
+            if not file:
+                raise gr.Error(message='Upload a PDF')
+            if not self.processed:
+                self.process_file(file)
+                self.processed = True
+            result = self.create_organic_response(history="", query=query)
+
+        self.load_persona_data()
+        result = f"{self.persona_text}\n\n{result}"
+
         for char in result:
             history[-1][-1] += char
         return history, ""
 
     def render_file(self, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
-        print(chunk_size)
         doc = fitz.open(file.name)
         page = doc[self.page]
         self.chunk_size = chunk_size
@@ -142,4 +173,4 @@ class RAGbot:
         if not text:
             raise gr.Error('Enter text')
         history.append((text, ''))
-        return history
+        return history
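
A note on the Zilliz path: the commit imports FieldSchema, CollectionSchema, and DataType but never creates a collection, so the collection named by config["zilliz"]["collection"] is assumed to already exist with a "text" field and an "embeddings" vector field. A sketch of a compatible schema; the collection name and dimension are illustrative (the dim must match whatever model HuggingFaceEmbeddings loads):

from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType

connections.connect(host="<host>", port="<port>", user="<user>", password="<password>", secure=True)

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=384),  # e.g. all-MiniLM-L6-v2
]
collection = Collection("ragbot_chunks", CollectionSchema(fields))

# An inner-product index to match the search call in get_organic_context,
# which passes param={"metric_type": "IP", "params": {"nprobe": 10}}.
collection.create_index(
    field_name="embeddings",
    index_params={"index_type": "IVF_FLAT", "metric_type": "IP", "params": {"nlist": 128}},
)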
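
And a sketch of exercising the new HuggingFace-index branch of generate_response, e.g. from the Gradio wiring; the query, paths, and parameter values are placeholders:

bot = RAGbot(config_path="config.yaml")
history = [["What does the document say about X?", ""]]

# With use_hf_index=True the PDF upload is skipped and retrieval goes through
# load_hf_vectordb / get_organic_context(use_hf=True) instead of Zilliz.
history, _ = bot.generate_response(
    history=history,
    query="What does the document say about X?",
    file=None,
    chunk_size=512,
    chunk_overlap_percentage=10,
    model_temperature=0.7,
    max_chunks_in_context=4,
    use_hf_index=True,
    hf_dataset_path="path/to/your/hf_dataset",
    hf_index_path="path/to/your/hf_index",
)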
requirements.txt CHANGED

@@ -6,7 +6,15 @@ langchain-community
 tqdm
 accelerate
 pypdf
+faiss-cpu
 protobuf>=3.20,<5
 poetry
+pymilvus
+chromadb
+gradio
+fitz
+PyYAML
+datasets
+numpy
 requests
+python-dotenv
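
One packaging note: the fitz module RAGbot.py imports for PDF rendering is provided by PyMuPDF, while the PyPI package actually named fitz is an unrelated project, so the dependency is usually declared as:

PyMuPDF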