Spaces: Running on Zero
Upload genra_incremental.py

genra_incremental.py ADDED (+253 -0)
@@ -0,0 +1,253 @@
import torch
import numpy as np
import pandas as pd

# Import the PyFLAGR modules for rank aggregation
import pyflagr.Linear as Linear
import pyflagr.Majoritarian as Majoritarian

from operator import itemgetter

from haystack import Document
# from haystack.pipeline import Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever

from sentence_transformers import SentenceTransformer

from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, pipeline


class LLMGenerator:
    def __init__(self, llm_model, tokenizer, llm_name):
        self.llm_model = llm_model
        self.tokenizer = tokenizer
        self.llm_name = llm_name

    def generate_answer(self, texts, query, mode='validate'):
        template_texts = ""
        for i, text in enumerate(texts):
            template_texts += f'{i+1}. {text} \n'

        if mode == 'validate':
            conversation = [{'role': 'user', 'content': f'Given the following query: "{query}"? \nIs the following document relevant to answer this query?\n{template_texts} \nResponse: Yes / No'}]
        elif mode == 'summarize':
            conversation = [{'role': 'user', 'content': f'For the following query and documents, try to answer the given query based on the documents.\nQuery: {query} \nDocuments: {template_texts}.'}]
        elif mode == 'h_summarize':
            conversation = [{'role': 'user', 'content': f'The documents below describe a developing disaster event. Based on these documents, write a brief summary in the form of a paragraph, highlighting the most crucial information. \nDocuments: {template_texts}'}]

        prompt = self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
        outputs = self.llm_model.generate(**inputs, use_cache=True, max_length=4096, do_sample=True, temperature=0.7, top_p=0.95, top_k=10, repetition_penalty=1.1)
        output_text = self.tokenizer.decode(outputs[0])
        # strip the prompt from the decoded output, depending on the chat template in use
        if self.llm_name == "solar":
            assistant_respond = output_text.split("Assistant:")[1]
        elif self.llm_name == "phi3mini":
            assistant_respond = output_text.split("<|assistant|>")[1]
            assistant_respond = assistant_respond[:-7]
        else:
            assistant_respond = output_text.split("[/INST]")[1]
        if mode == 'validate':
            if 'Yes' in assistant_respond:
                return True
            else:
                return False
        elif mode == 'summarize':
            return assistant_respond
        elif mode == 'h_summarize':
            return assistant_respond

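
# Usage sketch (illustrative only, not called anywhere in this file): validating a single
# document against a query with LLMGenerator. The Phi-3 checkpoint matches the 'phi3mini'
# branch used further below; the query and document strings are made-up placeholders.
def _example_validate_document():
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-128k-instruct",
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    generator = LLMGenerator(llm_model, tokenizer, "phi3mini")
    # mode='validate' returns True/False depending on whether the LLM answers "Yes"
    relevant = generator.generate_answer(
        ["A magnitude 6.4 earthquake struck the coastal region overnight."],
        "What damage did the earthquake cause?",
        mode='validate',
    )
    print(relevant)
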
class QAIndexer:
    def __init__(self, index_type, emb_model):
        self.document_embedder = SentenceTransformersDocumentEmbedder(model=emb_model)
        self.document_embedder.warm_up()
        if index_type == 'in_memory':
            self.document_store = InMemoryDocumentStore(embedding_similarity_function="cosine")

    def index(self, docs_to_index):
        documents_with_embeddings = self.document_embedder.run(docs_to_index)
        self.document_store.write_documents(documents_with_embeddings['documents'])

    def index_dataframe(self, data):
        docs_to_index = self.read_dataframe(data)
        self.index(docs_to_index)

    def index_stream(self, stream_data):
        docs_to_index = self.read_stream(stream_data)
        self.index(docs_to_index)

    def read_dataframe(self, data):
        # Convert DataFrame rows to Haystack Documents for the DocumentStore
        docs_to_index = [Document(content=row['text'], id=str(row['order'])) for idx, row in data.iterrows()]
        return docs_to_index

    def read_stream(self, stream_data):
        # The stream consists of single documents for now
        docs_to_index = [Document(content=doc['text'], id=str(doc['id'])) for doc in [stream_data]]
        return docs_to_index


class QARetriever:
    def __init__(self, document_store):
        self.retriever = InMemoryEmbeddingRetriever(document_store=document_store)

    def retrieve(self, query, topk):
        retrieval_results = self.retriever.run(query_embedding=query, top_k=topk)
        documents = [x.to_dict() for x in retrieval_results["documents"]]
        return documents

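
# Usage sketch (illustrative only): indexing a toy DataFrame and querying the in-memory
# store. The embedding model name and the example rows are assumptions for illustration;
# the column names ('text', 'order') follow what read_dataframe expects.
def _example_index_and_retrieve():
    data = pd.DataFrame([
        {'order': 1, 'text': 'Flooding reported downtown after the storm.'},
        {'order': 2, 'text': 'Power outages affect the northern suburbs.'},
    ])
    indexer = QAIndexer('in_memory', "sentence-transformers/all-MiniLM-L6-v2")
    indexer.index_dataframe(data)

    retriever = QARetriever(indexer.document_store)
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    # the retriever expects the query embedding as a plain list of floats
    query_emb = encoder.encode(["Where did flooding occur?"], convert_to_numpy=True)[0]
    hits = retriever.retrieve(query_emb.tolist(), topk=2)
    for hit in hits:
        print(hit['id'], hit['score'], hit['content'])
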
def rank_aggregation(aggregator, lists, k):
    if aggregator == 'linear':
        csum = Linear.CombSUM(norm='score')
        df_out, df_eval = csum.aggregate(input_file=lists)
    elif aggregator == 'outrank':
        outrank = Majoritarian.OutrankingApproach(eval_pts=7)
        df_out, df_eval = outrank.aggregate(input_file=lists)

    df_out['query_ids'] = df_out.index
    queries = list(df_out['query_ids'].unique())
    results = []
    for query in queries:
        df_query = df_out[df_out['query_ids'] == query][:k]
        rank = 0
        for index, r_q in df_query.iterrows():
            rank += 1
            doc_id = r_q['Voter']
            score_doc = r_q['Score']
            results.append({'qid': query, 'docid': doc_id, 'rank': rank, 'score': score_doc})
    return results

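
# Usage sketch (illustrative only): aggregating two toy rankings with rank_aggregation.
# Each line of the input file follows the same convention GenraPipeline.retrieval writes
# below: query id, voter name, document id, score, dataset tag. All values here are made up.
def _example_rank_aggregation():
    toy_rankings = [
        "q1,V0,doc_a,0.91,test",
        "q1,V0,doc_b,0.72,test",
        "q1,V1,doc_b,0.88,test",
        "q1,V1,doc_c,0.65,test",
    ]
    with open('toy_rankings_file', 'w') as f:
        f.write("\n".join(toy_rankings) + "\n")

    fused = rank_aggregation('linear', 'toy_rankings_file', k=3)
    for res in fused:
        print(res['qid'], res['docid'], res['rank'], res['score'])
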
class GenraPipeline:
    def __init__(self, llm_name, emb_model, aggregator, contexts):
        self.qa_indexer = QAIndexer('in_memory', emb_model)
        self.qa_retriever = QARetriever(self.qa_indexer.document_store)
        self.encoder = SentenceTransformer(emb_model)
        self.contexts = contexts
        self.aggregator = aggregator
        self.answers_store = {}
        if llm_name == 'solar':
            self.tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-Instruct-v1.0", use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                "Upstage/SOLAR-10.7B-Instruct-v1.0",
                device_map="auto",  # device_map="cuda"
                # torch_dtype=torch.float16,
            )
            self.llm_generator = LLMGenerator(self.llm_model, self.tokenizer, llm_name)
        elif llm_name == 'mistral':
            self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.2",
                device_map="auto",  # device_map="cuda"
                # torch_dtype=torch.float16,
            )
            self.llm_generator = LLMGenerator(self.llm_model, self.tokenizer, llm_name)
        elif llm_name == 'phi3mini':
            self.tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Phi-3-mini-128k-instruct",
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            self.llm_generator = LLMGenerator(self.llm_model, self.tokenizer, llm_name)

    def retrieval(self, batch_number, queries, topk, summarize_results=True):
        for qid, question in tqdm(queries[['id', 'query']].values):
            if len(self.contexts) < 1:
                self.contexts.append(question)
            # build the query embedding as the average of all context embeddings
            all_emb_c = []
            for c in self.contexts:
                c_emb = self.encoder.encode([c], convert_to_numpy=True)[0]
                all_emb_c.append(np.array(c_emb))
            all_emb_c = np.array(all_emb_c)
            avg_emb_c = np.mean(all_emb_c, axis=0)
            avg_emb_c = avg_emb_c.reshape((1, len(avg_emb_c)))
            # we want a list of floats for Haystack retrievers
            hits = self.qa_retriever.retrieve(avg_emb_c[0].tolist(), 20)  # topk or more?
            hyde_texts = []
            candidate_texts = []
            hit_count = 0
            # keep the first hits that the LLM validates as relevant (up to 5)
            while len(candidate_texts) < 5:
                if hit_count < len(hits):
                    json_doc = hits[hit_count]
                    doc_text = json_doc['content']
                    if self.llm_generator.generate_answer([doc_text], question, mode='validate'):
                        candidate_texts.append(doc_text)  # candidate_texts.append(doc_text[0])
                    hit_count += 1
                else:
                    break
            if len(candidate_texts) < 1:
                # no answerable result
                results = []
            else:
                all_emb_c = []
                all_hits = []
                for i, c in enumerate(candidate_texts):
                    c_emb = self.encoder.encode([c], convert_to_numpy=True)[0]
                    c_emb = c_emb.reshape((1, len(c_emb)))
                    c_hits = self.qa_retriever.retrieve(c_emb[0].tolist(), topk)  # changed to len(candidates)+1
                    rank = 0
                    for hit in c_hits:  # collect each ranking in PyFLAGR format
                        rank += 1
                        # penalize score w.r.t. hit counts (the smaller the better!)
                        all_hits.append({'qid': qid, 'voter': i, 'docid': hit['id'], 'rank': rank, 'score': hit['score']})
                # write the PyFLAGR aggregation file
                tempfile = 'temp_rankings_file'
                with open(tempfile, 'w') as f:
                    for res in all_hits:
                        f.write(f"{res['qid']},V{res['voter']},{res['docid']},{res['score']},test\n")
                # run aggregation
                results = rank_aggregation(self.aggregator, tempfile, topk)

                # enhance each result with doc info
                for res in results:
                    res['document'] = self.qa_indexer.document_store.filter_documents(filters={'id': str(res['docid'])})[0].content

            summary = "N/A"
            if summarize_results:
                summary = self.summarize_results(question, results, candidate_texts)

            self.store_results(batch_number, question, results, summary)

    def store_results(self, batch_number, question, results, summary):
        if results:
            tweets = [t['document'] for t in results]
            if question in self.answers_store:
                self.answers_store[question].append({'batch_number': batch_number, 'tweets': tweets, 'summary': summary})
            else:
                self.answers_store[question] = [{'batch_number': batch_number, 'tweets': tweets, 'summary': summary}]

    def summarize_results(self, question, results, candidate_texts):
        if results:
            texts = [t['document'] for t in results]  # + candidate_texts
            summary = self.llm_generator.generate_answer(texts, question, mode='summarize')
        else:
            summary = "N/A"
        return summary

    def summarize_history(self, queries):
        h_per_q = []
        for qid, question in tqdm(queries[['id', 'query']].values):
            if question in self.answers_store:
                q_history = self.answers_store[question]
                q_hist_docs = self.order_history(q_history)
                h_per_q.extend(q_hist_docs)
        historical_summary = self.llm_generator.generate_answer(h_per_q, question, mode='h_summarize')
        return historical_summary

    def order_history(self, query_history):
        ordered_history = sorted(query_history, key=itemgetter('batch_number'))
        ordered_docs = [hist['summary'] for hist in ordered_history]
        return ordered_docs
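
# Usage sketch (illustrative only): driving the pipeline over incoming batches. The example
# rows and the embedding model name are assumptions; the column names ('id', 'query', 'text',
# 'order') follow what retrieval and read_dataframe expect. The Space's real driver code is
# not part of this file.
if __name__ == "__main__":
    queries = pd.DataFrame([{'id': 'q1', 'query': 'What damage did the storm cause?'}])
    batches = [
        pd.DataFrame([
            {'order': 1, 'text': 'Heavy winds tore roofs off several houses.'},
            {'order': 2, 'text': 'Local shelters opened for displaced residents.'},
        ]),
    ]

    pipeline = GenraPipeline(
        llm_name='phi3mini',
        emb_model="sentence-transformers/all-MiniLM-L6-v2",
        aggregator='linear',
        contexts=[],
    )
    for batch_number, batch in enumerate(batches):
        # index the new batch, then retrieve, validate and summarize for every query
        pipeline.qa_indexer.index_dataframe(batch)
        pipeline.retrieval(batch_number, queries, topk=5)

    # summarize the per-batch summaries accumulated for each query over time
    print(pipeline.summarize_history(queries))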