jayebaku committed on
Commit
47df43c
·
verified ·
1 Parent(s): 77e3da1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -10
app.py CHANGED
@@ -35,14 +35,6 @@ def load_and_classify_csv(file, text_field, event_model):
35
  not_related = gr.CheckboxGroup(choices=df[df["model_label"]=="none"][text_field].to_list())
36
 
37
  return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df
38
-
39
- def qa_process(selections):
40
- selected_texts = selections
41
-
42
- analysis_results = [f"Word Count: {len(text.split())}" for text in selected_texts]
43
-
44
- result_df = pd.DataFrame({"Selected Text": selected_texts, "Analysis": analysis_results})
45
- return result_df
46
 
47
  def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
48
  posts = data_df[text_field].to_list()
@@ -94,6 +86,42 @@ def add_query(to_add, history):
94
  if to_add not in history:
95
  history.append(to_add)
96
  return gr.CheckboxGroup(choices=history), history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  with gr.Blocks() as demo:
99
  event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]
@@ -209,7 +237,9 @@ with gr.Blocks() as demo:
209
 
210
 
211
  addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
212
- qa_button.click(qa_process, inputs=selected_queries, outputs=analysis_output)
213
-
 
 
214
 
215
  demo.launch()
 
35
  not_related = gr.CheckboxGroup(choices=df[df["model_label"]=="none"][text_field].to_list())
36
 
37
  return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df
 
 
 
 
 
 
 
 
38
 
39
  def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
40
  posts = data_df[text_field].to_list()
 
86
  if to_add not in history:
87
  history.append(to_add)
88
  return gr.CheckboxGroup(choices=history), history
89
+
90
def qa_process(selected_queries, qa_llm_model, aggregator,
               batch_size, topk, text_field, data_df):
    """Run the GENRA pipeline over the loaded posts for the selected queries.

    The posts in ``data_df[text_field]`` are split into consecutive batches of
    ``batch_size`` rows. Each batch is indexed and queried through a
    ``GenraPipeline`` instance, and a summary over all batches is requested at
    the end.

    Args:
        selected_queries: Sequence of query strings chosen in the UI.
        qa_llm_model: Model identifier forwarded to ``GenraPipeline``.
        aggregator: Aggregator setting forwarded to ``GenraPipeline``.
        batch_size: Number of posts per batch; coerced to ``int`` because it
            is used as a ``range`` step and a slice bound.
        topk: Forwarded unchanged to ``genra.retrieval``.
        text_field: Name of the column in ``data_df`` holding the post text.
        data_df: DataFrame containing the posts.

    Returns:
        A tuple ``(result_df, summary)``: ``result_df`` has one row per
        selected query with its word count, and ``summary`` is whatever
        ``genra.summarize_history`` returns.

    NOTE(review): two values are returned, but the ``qa_button.click`` wiring
    visible alongside this function lists a single output component - confirm
    that the outputs list matches before relying on ``summary`` in the UI.
    """
    emb_model = 'multi-qa-mpnet-base-dot-v1'
    contexts = []

    queries_df = pd.DataFrame({
        'id': list(range(len(selected_queries))),
        'query': selected_queries,
    })

    # Build a fresh frame instead of mutating a column selection in place;
    # the original row index is kept as an explicit "order" column.
    tweets_df = (
        data_df[[text_field]]
        .reset_index()
        .rename(columns={"index": "order"})
    )

    gr.Info("Loading GENRA pipeline....")
    genra = GenraPipeline(qa_llm_model, emb_model, aggregator, contexts)
    gr.Info("Waiting for data...")

    # A numeric UI input may deliver a float; range() and slicing need an int.
    batch_size = int(batch_size)
    batches = [
        tweets_df[start:start + batch_size]
        for start in range(0, len(tweets_df), batch_size)
    ]

    summarize_batch = True
    for batch_number, tweets in enumerate(batches):
        gr.Info(f"Populating index for batch {batch_number}")
        genra.qa_indexer.index_dataframe(tweets)
        gr.Info(f"Performing retrieval for batch {batch_number}")
        genra.retrieval(batch_number, queries_df, topk, summarize_batch)

    gr.Info("Processed all batches!")
    # Per-batch answers are kept by the pipeline itself (genra.answers_store).

    summary = genra.summarize_history(queries_df)

    analysis_results = [f"Word Count: {len(text.split())}" for text in selected_queries]

    result_df = pd.DataFrame({"Selected Text": selected_queries, "Analysis": analysis_results})
    return result_df, summary
124
+
125
 
126
  with gr.Blocks() as demo:
127
  event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]
 
237
 
238
 
239
  addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
240
+ qa_button.click(qa_process,
241
+ inputs=[selected_queries, qa_llm_model, aggregator, batch_size, topk, text_field, data],
242
+ outputs=[analysis_output, ])
243
+
244
 
245
  demo.launch()