acmc committed on
Commit
65d4803
·
verified ·
1 Parent(s): 6a4d4fc

Upload visualize.py

Browse files
Files changed (1) hide show
  1. visualize.py +208 -0
visualize.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import datasets
3
+ import difflib
4
+ import transformers
5
+ import torch
6
+ import logging
7
+
8
+
9
# Columns that every view of the dataframe should show first.
_KEY_COLUMNS = ["question", "correct_answer_generated", "hallucinated_answer_generated"]

# Tokenizer used to align ground-truth hallucination spans with character offsets.
tokenizer = transformers.AutoTokenizer.from_pretrained("google/flan-t5-base")

# Load the first 10k rows of the test split and sort them for easier browsing.
dataset = (
    datasets.load_dataset(
        "shroom-semeval25/hallucinated_answer_generated_dataset",
        split="test",
    )
    .take(10000)
    .to_pandas()
    .sort_values("question")
)

# Show columns in this order: question, correct_answer_generated,
# hallucinated_answer_generated, then everything else in its original order.
_other_columns = [col for col in dataset.columns if col not in _KEY_COLUMNS]
dataset = dataset[_KEY_COLUMNS + _other_columns]
31
+
32
+
33
def show_hallucinations(evt: gr.SelectData):
    """Build the ground-truth hallucination view for the selected dataset row.

    Tokenizes the correct and hallucinated answers, diffs the two token-id
    sequences, and turns every non-"equal" opcode into a character-level
    "hal" entity over the hallucinated text.

    Returns a 3-element list matching the (json_example, original_text,
    highlighted_text) outputs: a debug payload, the correct answer string,
    and a dict suitable for gr.HighlightedText.
    """
    selected_row = evt.index[0]
    element = dataset.iloc[selected_row]
    original_text = element["correct_answer_generated"]
    hallucinated_text = element["hallucinated_answer_generated"]
    # Tokenize both texts; offset mappings translate token indices back to
    # character positions in the raw strings.
    original_tokens = tokenizer(
        original_text, return_offsets_mapping=True, add_special_tokens=False
    )
    hallucinated_tokens = tokenizer(
        hallucinated_text, return_offsets_mapping=True, add_special_tokens=False
    )
    # Opcodes describe how to turn the original token ids into the
    # hallucinated ones; anything that is not "equal" is a hallucination.
    diff = difflib.SequenceMatcher(
        None,
        original_tokens["input_ids"],
        hallucinated_tokens["input_ids"],
    ).get_opcodes()
    entities = []
    for tag, i1, i2, j1, j2 in diff:
        if tag == "equal":
            continue
        try:
            start_char = hallucinated_tokens["offset_mapping"][j1][0]
            # NOTE(review): tokenizer offset ends already look exclusive, so
            # the +1 may extend the highlight one character too far — confirm.
            end_char = hallucinated_tokens["offset_mapping"][j2 - 1][1] + 1
        except IndexError:
            # A "delete" opcode has j1 == j2, so the indices can fall outside
            # the hallucinated offsets. The original built a gr.Error here
            # without raising it (a silent no-op); log instead and keep the
            # remaining opcodes best-effort.
            logging.exception("Tokenization offset lookup failed for opcode %s", tag)
            continue
        entities.append({"entity": "hal", "start": start_char, "end": end_char})
    return [
        {
            "calculated_diffs": diff,
            "tokenized_original": original_tokens,
            "tokenized_hallucinated": hallucinated_tokens,
            **element.to_dict(),
        },
        element["correct_answer_generated"],
        {
            "text": hallucinated_text,
            "entities": entities,
        },
    ]
94
+
95
+
96
# Token-classification checkpoint fine-tuned to tag hallucinated spans;
# model and tokenizer come from the same repository.
_DETECTOR_ID = "shroom-semeval25/cogumelo-hallucinations-detector-roberta-base"
prediction_model = transformers.AutoModelForTokenClassification.from_pretrained(
    _DETECTOR_ID
)
prediction_tokenizer = transformers.AutoTokenizer.from_pretrained(_DETECTOR_ID)
102
+
103
+
104
def predict_hallucinations(evt: gr.SelectData):
    """Tag hallucinated spans in the selected row with the trained detector.

    The model will return 0 if it's not a hallucination, 1 if it is the
    beginning of a hallucination, and 2 if it's the continuation of a
    hallucination.

    Returns a dict for gr.HighlightedText: the hallucinated answer text plus
    the predicted character-level "hal" entities.
    """
    element = dataset.iloc[evt.index[0]]
    hallucinated_text = element["hallucinated_answer_generated"]
    encoded = prediction_tokenizer(
        hallucinated_text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = prediction_model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
    # Highest-scoring label per token.
    labels = outputs.logits.argmax(dim=-1).squeeze(0).tolist()

    offsets = encoded["offset_mapping"][0]

    def span_for(token_idx):
        # NOTE(review): offset ends already look exclusive, so the +1 may
        # overshoot by one character — kept to match the original output.
        return {
            "entity": "hal",
            "start": offsets[token_idx][0],
            "end": offsets[token_idx][1] + 1,
        }

    entities = []
    current = None
    for idx, label in enumerate(labels):
        if label == 0:
            # Outside a hallucination: close any open span.
            if current is not None:
                entities.append(current)
                current = None
        elif label == 1:
            # A new span begins; flush the previous one first.
            if current is not None:
                entities.append(current)
            current = span_for(idx)
        elif label == 2:
            if current is None:
                # Continuation without a beginning: start a span anyway.
                current = span_for(idx)
            else:
                current["end"] = offsets[idx][1] + 1
    if current is not None:
        entities.append(current)
    return {
        "text": hallucinated_text,
        "entities": entities,
    }
157
+
158
+
159
def update_selection(evt: gr.SelectData):
    """Handle a dataframe row selection: compute ground-truth and predicted views.

    Returns the four values wired to (json_example, original_text,
    highlighted_text, highlighted_text_predicted).
    """
    json_example, original_text, highlighted_text = show_hallucinations(evt)
    try:
        highlighted_text_predicted = predict_hallucinations(evt)
    except Exception as e:
        # Keep the UI responsive: log the failure and fall back to an empty
        # highlight instead of breaking the whole selection callback.
        # NOTE(review): the original also built `gr.Error(...)` without
        # raising it, which is a no-op in Gradio; `raise gr.Error(...)` would
        # surface it in the UI but would drop the fallback below.
        logging.exception(f"An error occurred: {e}")
        highlighted_text_predicted = {"text": "", "entities": []}
    return json_example, original_text, highlighted_text, highlighted_text_predicted
169
+
170
+
171
with gr.Blocks(title="Hallucinations Explorer") as demo:
    gr.Markdown(
        """# Cogumelo

_SHROOM '25: Detection of Hallucinated Content_

⚠️ These rows are part of the **test set** of the dataset, not the entire dataset (the model has therefore not seen them)"""
    )

    # Selectable dataframe: clicking any cell drives the widgets below.
    df = gr.Dataframe(dataset)

    # Both highlight widgets share one palette; the callbacks only ever emit
    # "hal" entities, so the "+"/"-" keys appear unused by the current code.
    hal_colors = {"+": "red", "-": "blue", "hal": "red"}

    original_text = gr.Textbox(label="Original Text")
    highlighted_text = gr.HighlightedText(
        label="Real Hallucinations (ground truth)",
        color_map=hal_colors,
        combine_adjacent=True,
    )
    highlighted_text_predicted = gr.HighlightedText(
        label="Predicted Hallucinations",
        color_map=hal_colors,
        combine_adjacent=True,
    )
    json_example = gr.JSON()

    # Row selection refreshes all four output widgets at once.
    df.select(
        update_selection,
        inputs=[],
        outputs=[
            json_example,
            original_text,
            highlighted_text,
            highlighted_text_predicted,
        ],
    )

demo.launch(show_error=True)