jayebaku committed
Commit a5f1bbb · verified · 1 Parent(s): 78765c2

Upload genra_incremental.py

Files changed (1):
  genra_incremental.py +253 -0
genra_incremental.py ADDED
@@ -0,0 +1,253 @@
# GENRA incremental pipeline: retrieval-augmented QA over streamed documents,
# with LLM-based relevance validation, PyFLAGR rank aggregation and summarization.
import torch
import numpy as np
import pandas as pd

# Import the PyFLAGR modules for rank aggregation
import pyflagr.Linear as Linear
import pyflagr.Majoritarian as Majoritarian

from operator import itemgetter

from haystack import Document
# from haystack.pipeline import Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever

from sentence_transformers import SentenceTransformer

from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, pipeline

class LLMGenerator:
    def __init__(self, llm_model, tokenizer, llm_name):
        self.llm_model = llm_model
        self.tokenizer = tokenizer
        self.llm_name = llm_name

    def generate_answer(self, texts, query, mode='validate'):
        # Number the input texts so the prompt can refer to them as a list
        template_texts = ""
        for i, text in enumerate(texts):
            template_texts += f'{i+1}. {text} \n'

        if mode == 'validate':
            conversation = [{'role': 'user', 'content': f'Given the following query: "{query}"? \nIs the following document relevant to answer this query?\n{template_texts} \nResponse: Yes / No'}]
        elif mode == 'summarize':
            conversation = [{'role': 'user', 'content': f'For the following query and documents, try to answer the given query based on the documents.\nQuery: {query} \nDocuments: {template_texts}.'}]
        elif mode == 'h_summarize':
            conversation = [{'role': 'user', 'content': f'The documents below describe a developing disaster event. Based on these documents, write a brief summary in the form of a paragraph, highlighting the most crucial information. \nDocuments: {template_texts}'}]

        prompt = self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
        outputs = self.llm_model.generate(**inputs, use_cache=True, max_length=4096, do_sample=True, temperature=0.7, top_p=0.95, top_k=10, repetition_penalty=1.1)
        output_text = self.tokenizer.decode(outputs[0])
        # Keep only the assistant's part of the decoded conversation
        if self.llm_name == "solar":
            assistant_respond = output_text.split("Assistant:")[1]
        elif self.llm_name == "phi3mini":
            assistant_respond = output_text.split("<|assistant|>")[1]
            assistant_respond = assistant_respond[:-7]  # strip the trailing end-of-turn token
        else:
            assistant_respond = output_text.split("[/INST]")[1]
        if mode == 'validate':
            if 'Yes' in assistant_respond:
                return True
            else:
                return False
        elif mode == 'summarize':
            return assistant_respond
        elif mode == 'h_summarize':
            return assistant_respond


class QAIndexer:
    def __init__(self, index_type, emb_model):
        self.document_embedder = SentenceTransformersDocumentEmbedder(model=emb_model)
        self.document_embedder.warm_up()
        if index_type == 'in_memory':
            self.document_store = InMemoryDocumentStore(embedding_similarity_function="cosine")

    def index(self, docs_to_index):
        documents_with_embeddings = self.document_embedder.run(docs_to_index)
        self.document_store.write_documents(documents_with_embeddings['documents'])

    def index_dataframe(self, data):
        docs_to_index = self.read_dataframe(data)
        self.index(docs_to_index)

    def index_stream(self, stream_data):
        docs_to_index = self.read_stream(stream_data)
        self.index(docs_to_index)

    def read_dataframe(self, data):
        # Convert DataFrame rows to Haystack Documents for the DocumentStore
        docs_to_index = [Document(content=row['text'], id=str(row['order'])) for idx, row in data.iterrows()]
        return docs_to_index

    def read_stream(self, stream_data):
        # a stream item consists of a single doc for now
        docs_to_index = [Document(content=doc['text'], id=str(doc['id'])) for doc in [stream_data]]
        return docs_to_index


class QARetriever:
    def __init__(self, document_store):
        self.retriever = InMemoryEmbeddingRetriever(document_store=document_store)

    def retrieve(self, query, topk):
        # query is a query embedding given as a list of floats
        retrieval_results = self.retriever.run(query_embedding=query, top_k=topk)
        documents = [x.to_dict() for x in retrieval_results["documents"]]
        return documents


def rank_aggregation(aggregator, lists, k):
    if aggregator == 'linear':
        csum = Linear.CombSUM(norm='score')
        df_out, df_eval = csum.aggregate(input_file=lists)
    elif aggregator == 'outrank':
        outrank = Majoritarian.OutrankingApproach(eval_pts=7)
        df_out, df_eval = outrank.aggregate(input_file=lists)

    df_out['query_ids'] = df_out.index
    queries = list(df_out['query_ids'].unique())
    results = []
    for query in queries:
        df_query = df_out[df_out['query_ids'] == query][:k]
        rank = 0
        for index, r_q in df_query.iterrows():
            rank += 1
            doc_id = r_q['Voter']
            score_doc = r_q['Score']
            results.append({'qid': query, 'docid': doc_id, 'rank': rank, 'score': score_doc})
    return results


class GenraPipeline:
    def __init__(self, llm_name, emb_model, aggregator, contexts):
        self.qa_indexer = QAIndexer('in_memory', emb_model)
        self.qa_retriever = QARetriever(self.qa_indexer.document_store)
        self.encoder = SentenceTransformer(emb_model)
        self.contexts = contexts
        self.aggregator = aggregator
        self.answers_store = {}
        if llm_name == 'solar':
            self.tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-Instruct-v1.0", use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                "Upstage/SOLAR-10.7B-Instruct-v1.0",
                device_map="auto",  # device_map="cuda"
                # torch_dtype=torch.float16,
            )
            self.llm_generator = LLMGenerator(self.llm_model, self.tokenizer, llm_name)
        elif llm_name == 'mistral':
            self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.2",
                device_map="auto",  # device_map="cuda"
                # torch_dtype=torch.float16,
            )
            self.llm_generator = LLMGenerator(self.llm_model, self.tokenizer, llm_name)
        elif llm_name == 'phi3mini':
            self.tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Phi-3-mini-128k-instruct",
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            self.llm_generator = LLMGenerator(self.llm_model, self.tokenizer, llm_name)

    def retrieval(self, batch_number, queries, topk, summarize_results=True):
        for qid, question in tqdm(queries[['id', 'query']].values):
            if len(self.contexts) < 1:
                self.contexts.append(question)
            # Average the embeddings of all context strings into a single query vector
            all_emb_c = []
            for c in self.contexts:
                c_emb = self.encoder.encode([c], convert_to_numpy=True)[0]
                all_emb_c.append(np.array(c_emb))
            all_emb_c = np.array(all_emb_c)
            avg_emb_c = np.mean(all_emb_c, axis=0)
            avg_emb_c = avg_emb_c.reshape((1, len(avg_emb_c)))
            # we want a list of floats for haystack retrievers
            hits = self.qa_retriever.retrieve(avg_emb_c[0].tolist(), 20)  # topk or more?
            hyde_texts = []
            candidate_texts = []
            hit_count = 0
            # Keep the hits the LLM validates as relevant, up to 5 candidates
            while len(candidate_texts) < 5:
                if hit_count < len(hits):
                    json_doc = hits[hit_count]
                    doc_text = json_doc['content']
                    if self.llm_generator.generate_answer([doc_text], question, mode='validate'):
                        candidate_texts.append(doc_text)  # candidate_texts.append(doc_text[0])
                    hit_count += 1
                else:
                    break
            if len(candidate_texts) < 1:
                # no answerable result
                results = []
            else:
                # Use each validated candidate as a "voter": retrieve with its embedding
                # and aggregate the resulting rankings
                all_emb_c = []
                all_hits = []
                for i, c in enumerate(candidate_texts):
                    c_emb = self.encoder.encode([c], convert_to_numpy=True)[0]
                    c_emb = c_emb.reshape((1, len(c_emb)))
                    c_hits = self.qa_retriever.retrieve(c_emb[0].tolist(), topk)  # changed to len(candidates)+1
                    rank = 0
                    for hit in c_hits:  # get each ranking in pyflagr format
                        rank += 1
                        # penalize score wrt hit counts (the smaller the better!)
                        all_hits.append({'qid': qid, 'voter': i, 'docid': hit['id'], 'rank': rank, 'score': hit['score']})
                # write the pyflagr aggregation input file (query, voter, item, score, dataset tag)
                tempfile = 'temp_rankings_file'
                with open(tempfile, 'w') as f:
                    for res in all_hits:
                        f.write(f"{res['qid']},V{res['voter']},{res['docid']},{res['score']},test\n")
                # run aggregation
                results = rank_aggregation(self.aggregator, tempfile, topk)

            # enhance each result with doc info
            for res in results:
                res['document'] = self.qa_indexer.document_store.filter_documents(filters={'id': str(res['docid'])})[0].content

            # default the summary so store_results always receives one
            summary = self.summarize_results(question, results, candidate_texts) if summarize_results else "N/A"

            self.store_results(batch_number, question, results, summary)

    def store_results(self, batch_number, question, results, summary):
        if results:
            tweets = [t['document'] for t in results]
            if question in self.answers_store:
                self.answers_store[question].append({'batch_number': batch_number, 'tweets': tweets, 'summary': summary})
            else:
                self.answers_store[question] = [{'batch_number': batch_number, 'tweets': tweets, 'summary': summary}]

    def summarize_results(self, question, results, candidate_texts):
        if results:
            texts = [t['document'] for t in results]  # + candidate_texts
            summary = self.llm_generator.generate_answer(texts, question, mode='summarize')
        else:
            summary = "N/A"
        return summary

    def summarize_history(self, queries):
        # Collect the per-batch summaries stored for each query, in batch order,
        # then produce one summary over the whole history
        h_per_q = []
        for qid, question in tqdm(queries[['id', 'query']].values):
            if question in self.answers_store:
                q_history = self.answers_store[question]
                q_hist_docs = self.order_history(q_history)
                h_per_q.extend(q_hist_docs)
        historical_summary = self.llm_generator.generate_answer(h_per_q, question, mode='h_summarize')
        return historical_summary

    def order_history(self, query_history):
        ordered_history = sorted(query_history, key=itemgetter('batch_number'))
        ordered_docs = [hist['summary'] for hist in ordered_history]
        return ordered_docs
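
A minimal usage sketch of how the classes above fit together. The embedding model name, the sample documents, and the query values are illustrative assumptions rather than values from the commit, and running it will load the selected instruct model.

import pandas as pd

# Build the pipeline: an LLM for validation/summarization, an embedder,
# a PyFLAGR aggregator and an (initially empty) list of context strings.
genra = GenraPipeline(
    llm_name='phi3mini',
    emb_model='sentence-transformers/all-MiniLM-L6-v2',  # assumed embedding model
    aggregator='linear',
    contexts=[],
)

# Index one batch of streamed documents; each item needs 'text' and 'id'.
for doc in [{'id': 1, 'text': 'Flooding reported near the river bank.'},
            {'id': 2, 'text': 'Power outages affect the northern district.'}]:
    genra.qa_indexer.index_stream(doc)

# Queries are passed as a DataFrame with 'id' and 'query' columns.
queries = pd.DataFrame([{'id': 'q1', 'query': 'Where is flooding reported?'}])
genra.retrieval(batch_number=0, queries=queries, topk=5)

print(genra.answers_store)               # per-query tweets and per-batch summaries
print(genra.summarize_history(queries))  # summary over all batches seen so far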