anselp committed on
Commit
7d62703
·
verified ·
1 Parent(s): 7d55365

Upload dipromats_evaluation.py

Browse files
Files changed (1) hide show
  1. dipromats_evaluation.py +298 -0
dipromats_evaluation.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import numpy as np
4
+ import warnings
5
+ # Suppress SettingWithCopyWarning
6
+ warnings.simplefilter(action='ignore', category=pd.errors.SettingWithCopyWarning)
7
+
8
+
9
+ gold_json_path='./gold_test.json'
10
+
11
+
12
+
13
+
14
+
15
+ import pandas as pd
16
+ import json
17
+ import numpy as np
18
+ import warnings
19
+ # Suppress SettingWithCopyWarning
20
+ warnings.simplefilter(action='ignore', category=pd.errors.SettingWithCopyWarning)
21
+
22
+ import json
23
+
24
+ # Read a JSON file
25
+
26
+ gold_json_path='./gold_test.json'
27
# Country names used in the gold/prediction files, mapped to the two-letter
# prefix of their narrative codes (e.g. 'China' -> 'CH1'..'CH6').
_COUNTRY_CODES = {'China': 'CH', 'Russia': 'RU', 'EU': 'EU', 'USA': 'US'}
# Gold-standard columns that hold the per-narrative labels.
_LABEL_COLUMNS = [f'N{k}' for k in range(1, 7)]


def _empty_matrix():
    """Return a zeroed 2x3 count matrix.

    Rows are the predicted label (yes, no); columns are the gold label
    (yes, anything else, no).
    """
    return [[0, 0, 0], [0, 0, 0]]


def _count_pair(matrix, predicted, gold):
    """Add one (predicted, gold) label pair to ``matrix`` in place."""
    row = 0 if predicted == 'yes' else 1
    if gold == 'yes':
        col = 0
    elif gold == 'no':
        col = 2
    else:
        # Any gold label other than 'yes'/'no' lands in the middle column.
        col = 1
    matrix[row][col] += 1


def _add_matrices(matrices):
    """Return the element-wise sum of several 2x3 count matrices."""
    total = _empty_matrix()
    for matrix in matrices:
        for r in range(2):
            for c in range(3):
                total[r][c] += matrix[r][c]
    return total


def _scores(matrix, lenient):
    """Compute precision, recall and F1 from a 2x3 count matrix.

    In lenient mode a predicted 'yes' on a middle-column gold label counts as
    a hit; in strict mode it counts against precision and is left out of the
    recall denominator. Empty denominators yield 0.0.
    """
    tp, yl, fp = matrix[0]
    fn = matrix[1][0]
    hits = tp + yl if lenient else tp
    predicted_yes = tp + yl + fp
    gold_yes = tp + fn + yl if lenient else tp + fn
    precision = hits / predicted_yes if predicted_yes else 0.0
    recall = hits / gold_yes if gold_yes else 0.0
    total = precision + recall
    f1 = (2 * precision * recall) / total if total else 0.0
    return {'precision': precision, 'recall': recall, 'f1-score': f1}


def _entry(matrix, lenient):
    """Build one result entry: the scores plus an independent copy of the counts."""
    return {'scores': _scores(matrix, lenient),
            'raw_data': [row[:] for row in matrix]}


def _load_predictions(file_path, lang):
    """Load the predictions for ``lang`` as ``{id: (country, [6 yes/no labels])}``.

    Duplicated (id, language) rows keep the last occurrence. A narrative is
    labelled 'yes' when its code (first two letters of the country,
    upper-cased, plus the narrative number) appears in the row's narratives.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as handle:
        frame = pd.DataFrame(json.load(handle))
    frame = frame.drop_duplicates(subset=['id', 'language'], keep='last')
    frame = frame[frame['language'] == lang]
    predictions = {}
    for row_id, country, narratives in zip(frame['id'], frame['country'],
                                           frame['narratives']):
        code = country[:2].upper()
        labels = ['yes' if f'{code}{k}' in narratives else 'no'
                  for k in range(1, 7)]
        predictions[row_id] = (country, labels)
    return predictions


def _load_gold_rows(gold_path, lang):
    """Load the gold rows for ``lang`` as ``[(id, country, [N1..N6 labels])]``.

    Duplicated (id, lang) rows keep the last occurrence.
    """
    frame = pd.read_json(gold_path)
    frame = frame.drop_duplicates(subset=['id', 'lang'], keep='last')
    frame = frame[frame['lang'] == lang]
    label_rows = zip(*(frame[col] for col in _LABEL_COLUMNS))
    return [(row_id, country, list(labels))
            for row_id, country, labels
            in zip(frame['id'], frame['country'], label_rows)]


def evaluate_results(lang, file_path, gold_path=None):
    """Score a predictions file against the gold standard for one language.

    Args:
        lang: Language value to evaluate; matched against the gold ``lang``
            column and the predictions ``language`` field.
        file_path: Path to the predictions JSON file. Each record needs
            ``id``, ``language``, ``country`` and ``narratives``.
        gold_path: Path to the gold JSON file. Defaults to the module-level
            ``gold_json_path``.

    Returns:
        A dict with the keys ``'strict'`` and ``'lenient'``. Each maps to a
        dict keyed by narrative (``'CH1'`` .. ``'US6'``), per-country micro
        (``'CH_micro'`` ..) and the global ``'micro'``. Every entry holds
        ``'scores'`` (``precision``, ``recall``, ``f1-score``) and
        ``'raw_data'``, a 2x3 count matrix whose rows are the prediction
        (yes, no) and whose columns are the gold label (yes, other, no).

    Gold rows without a prediction for the same id are skipped. Per-country
    entries only count rows whose gold and predicted country agree; the
    global micro counts every matched row. A country with no matched rows
    gets all-zero counts and scores.
    """
    if gold_path is None:
        gold_path = gold_json_path

    predictions = _load_predictions(file_path, lang)

    narrative_matrices = {f'{code}{k}': _empty_matrix()
                          for code in _COUNTRY_CODES.values()
                          for k in range(1, 7)}
    micro_matrix = _empty_matrix()

    for row_id, country, gold_labels in _load_gold_rows(gold_path, lang):
        match = predictions.get(row_id)
        if match is None:
            # No prediction submitted for this gold item.
            continue
        pred_country, pred_labels = match
        same_country = country in _COUNTRY_CODES and country == pred_country
        for k, (predicted, gold) in enumerate(zip(pred_labels, gold_labels),
                                              start=1):
            _count_pair(micro_matrix, predicted, gold)
            if same_country:
                key = f'{_COUNTRY_CODES[country]}{k}'
                _count_pair(narrative_matrices[key], predicted, gold)

    results = {}
    for mode in ('strict', 'lenient'):
        lenient = mode == 'lenient'
        section = {}
        for code in _COUNTRY_CODES.values():
            per_narrative = [narrative_matrices[f'{code}{k}']
                             for k in range(1, 7)]
            for k, matrix in enumerate(per_narrative, start=1):
                section[f'{code}{k}'] = _entry(matrix, lenient)
            # The country micro is scored on the summed narrative counts.
            section[f'{code}_micro'] = _entry(_add_matrices(per_narrative),
                                              lenient)
        section['micro'] = _entry(micro_matrix, lenient)
        results[mode] = section
    return results
254
+
255
+
256
+
257
+
258
+
259
+ """
260
+ strict
261
+ narrative_country (e.g. CH1)
262
+ scores
263
+ precision
264
+ recall
265
+ f1-score
266
+ raw_data
267
+ country_micro (e.g. CH_micro)
268
+ scores
269
+ precision
270
+ recall
271
+ f1-score
272
+ raw_data
273
+ micro (global micro)
274
+ scores
275
+ precision
276
+ recall
277
+ f1-score
278
+ raw_data
279
+
280
+ lenient
281
+ narrative_country (e.g. CH1)
282
+ scores
283
+ precision
284
+ recall
285
+ f1-score
286
+ raw_data
287
+ country_micro (e.g. CH_micro)
288
+ scores
289
+ precision
290
+ recall
291
+ f1-score
292
+ raw_data
293
+ micro (global micro)
294
+ scores
295
+ precision
296
+ recall
297
+ f1-score
298
+ raw_data"""