anselp committed on
Commit
c65e197
·
verified ·
1 Parent(s): a97354e

Upload dipromats_evaluation_v2.py

Browse files
Files changed (1) hide show
  1. dipromats_evaluation_v2.py +285 -0
dipromats_evaluation_v2.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import numpy as np
4
+ import warnings
5
# Silence pandas' SettingWithCopyWarning for the rest of the script.
# NOTE(review): the class is looked up defensively because some pandas
# releases may not expose it under pd.errors -- confirm against the pinned
# pandas version.
_setting_with_copy = getattr(pd.errors, 'SettingWithCopyWarning', None)
if _setting_with_copy is not None:
    warnings.simplefilter(action='ignore', category=_setting_with_copy)

# Location of the gold-standard annotations and of the predictions to score.
GOLD_PATH = '/home/jfraile/Programs/DIPROMATS_2024/evaluation_script/gold_test.json'
file_path = '/home/jfraile/Programs/DIPROMATS_2024/evaluation_script/test_en.json'
# Language whose items are evaluated.
lang = 'en'

# JSON text is UTF-8; pass the encoding explicitly so that reading does not
# depend on the platform's locale default.
with open(GOLD_PATH, 'r', encoding='utf-8') as g:
    gold = json.load(g)

with open(file_path, 'r', encoding='utf-8') as f:
    test = json.load(f)
16
+
17
+
18
+
19
def evaluate_results(lang, gold, test):
    """Score narrative predictions against the gold standard for one language.

    Gold records carry six label columns ``N1``..``N6``. A gold value of
    ``'yes'`` or ``'no'`` is a firm label; any other value is counted in a
    separate middle column (presumably a borderline / "leaning" annotation --
    confirm against the task guidelines). Prediction records list the
    narrative codes they assert (e.g. ``'CH1'``); these are expanded to
    ``'yes'``/``'no'`` per narrative using the first two letters of the
    record's country, upper-cased.

    For every narrative, every country and globally, a 2x3 count matrix is
    built::

                      gold yes   gold other   gold no
        pred yes         tp         yl          fp
        pred no          fn         nl          tn

    Two scorings are derived from the same matrix:

    * lenient: precision = (tp+yl)/(tp+yl+fp), recall = (tp+yl)/(tp+yl+fn)
    * strict:  precision = tp/(tp+yl+fp),      recall = tp/(tp+fn)

    A ratio with a zero denominator is reported as 0.0.

    Matching rules:

    * Gold rows are de-duplicated on (id, lang), predictions on
      (id, language), keeping the last occurrence; only rows of ``lang``
      are used. ``'European Union'`` is normalised to ``'EU'`` on both sides.
    * The global ``micro`` matrix pairs a gold row with the prediction that
      has the same id. Per-narrative and per-country matrices additionally
      require the prediction's country to equal the gold country.
    * Gold rows without any prediction are skipped (not counted as misses);
      a ``RuntimeWarning`` reports how many were skipped.

    Args:
        lang: Language code compared with the gold ``lang`` field and the
            prediction ``language`` field.
        gold: Records with keys ``id``, ``lang``, ``country``, ``N1``..``N6``.
        test: Records with keys ``id``, ``language``, ``country`` and
            ``narratives``; a ``tweet_id`` key, when present, is discarded.

    Returns:
        ``{'strict': {...}, 'lenient': {...}}``. Each inner dict maps
        ``'CH1'``..``'US6'``, ``'CH_micro'``..``'US_micro'`` and ``'micro'``
        to ``{'scores': {'precision', 'recall', 'f1-score'}, 'raw_data': m}``
        where ``m`` is the 2x3 matrix above as nested lists of ints.
    """
    country_codes = {'China': 'CH', 'Russia': 'RU', 'EU': 'EU', 'USA': 'US'}
    cols = [f'N{i}' for i in range(1, 7)]

    def normalize_labels(df):
        """Expand the ``narratives`` collection into 'yes'/'no' columns N1..N6."""
        data = df.copy()
        prefixes = [country[:2].upper() for country in data['country']]
        for i in range(1, 7):
            data[f'N{i}'] = [
                'yes' if f'{prefix}{i}' in narratives else 'no'
                for prefix, narratives in zip(prefixes, data['narratives'])
            ]
        # 'tweet_id' is optional in the input, so do not fail when it is absent.
        return data.drop(columns=['narratives', 'tweet_id'], errors='ignore')

    def new_matrix():
        """Return an empty 2x3 count matrix (rows: pred, columns: gold)."""
        return [[0, 0, 0], [0, 0, 0]]

    def cell(pred_label, gold_label):
        """Map a (prediction, gold) label pair to its (row, column) cell."""
        row = 0 if pred_label == 'yes' else 1
        if gold_label == 'yes':
            col = 0
        elif gold_label == 'no':
            col = 2
        else:
            col = 1  # anything that is neither 'yes' nor 'no'
        return row, col

    def safe_ratio(numerator, denominator):
        """Divide, reporting 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def entry(matrix, lenient):
        """Build the ``{'scores': ..., 'raw_data': ...}`` record for a matrix."""
        (tp, yl, fp), (fn, _nl, _tn) = matrix
        if lenient:
            # Predicting 'yes' on a middle-column gold label counts as a hit.
            precision = safe_ratio(tp + yl, tp + yl + fp)
            recall = safe_ratio(tp + yl, tp + fn + yl)
        else:
            # Predicting 'yes' on a middle-column gold label counts as an error.
            precision = safe_ratio(tp, tp + fp + yl)
            recall = safe_ratio(tp, tp + fn)
        total = precision + recall
        f1_score = (2 * precision * recall) / total if total else 0.0
        return {
            'scores': {'precision': precision, 'recall': recall, 'f1-score': f1_score},
            # Copy so that strict and lenient records never share a list.
            'raw_data': [list(row) for row in matrix],
        }

    # --- Gold standard: one row per (id, lang), restricted to ``lang``. ---
    df_gold = pd.DataFrame(gold)
    df_gold['country'] = df_gold['country'].replace('European Union', 'EU')
    df_gold = df_gold.drop_duplicates(subset=['id', 'lang'], keep='last')
    df_gold = df_gold[df_gold['lang'] == lang]

    # --- Predictions: one row per (id, language), restricted to ``lang``. ---
    df_test = pd.DataFrame(test)
    df_test['country'] = df_test['country'].replace('European Union', 'EU')
    df_test = normalize_labels(df_test)
    df_test = df_test.drop_duplicates(subset=['id', 'language'], keep='last')
    df_test = df_test[df_test['language'] == lang]

    # id -> (country, [N1..N6]) lookup; ids are unique after the steps above.
    predictions = {
        pred_id: (pred_country, pred_labels)
        for pred_id, pred_country, pred_labels in zip(
            df_test['id'].tolist(),
            df_test['country'].tolist(),
            df_test[cols].values.tolist(),
        )
    }

    narrative_counts = {
        code: [new_matrix() for _ in cols] for code in country_codes.values()
    }
    micro_counts = new_matrix()
    unmatched = []

    gold_rows = zip(
        df_gold['id'].tolist(),
        df_gold['country'].tolist(),
        df_gold[cols].values.tolist(),
    )
    for gold_id, gold_country, gold_labels in gold_rows:
        found = predictions.get(gold_id)
        if found is None:
            unmatched.append(gold_id)
            continue
        pred_country, pred_labels = found
        code = country_codes.get(gold_country)
        # Per-narrative counts need the prediction to be filed under the
        # same country as the gold row; the global micro counts do not.
        same_country = code is not None and pred_country == gold_country
        for position, (pred_label, gold_label) in enumerate(zip(pred_labels, gold_labels)):
            row, col = cell(pred_label, gold_label)
            micro_counts[row][col] += 1
            if same_country:
                narrative_counts[code][position][row][col] += 1

    if unmatched:
        warnings.warn(
            f"{len(unmatched)} gold item(s) for lang={lang!r} have no prediction "
            "and were skipped",
            RuntimeWarning,
            stacklevel=2,
        )

    results = {}
    for mode in ('strict', 'lenient'):
        lenient = mode == 'lenient'
        per_mode = {}
        for code in country_codes.values():
            country_total = new_matrix()
            for number, matrix in enumerate(narrative_counts[code], start=1):
                per_mode[f'{code}{number}'] = entry(matrix, lenient)
                for row in range(2):
                    for col in range(3):
                        country_total[row][col] += matrix[row][col]
            per_mode[f'{code}_micro'] = entry(country_total, lenient)
        per_mode['micro'] = entry(micro_counts, lenient)
        results[mode] = per_mode
    return results
241
+
242
# Evaluate the loaded predictions against the gold standard and print the
# nested result dictionary (layout described in the string below).
results = evaluate_results(lang=lang, gold=gold, test=test)
print(results)
244
+
245
+
246
"""
strict
    narrative_country (e.g. CH1)
        scores
            precision
            recall
            f1-score
        raw_data
    country_micro (e.g. CH_micro)
        scores
            precision
            recall
            f1-score
        raw_data
    micro (global micro)
        scores
            precision
            recall
            f1-score
        raw_data

lenient
    narrative_country (e.g. CH1)
        scores
            precision
            recall
            f1-score
        raw_data
    country_micro (e.g. CH_micro)
        scores
            precision
            recall
            f1-score
        raw_data
    micro (global micro)
        scores
            precision
            recall
            f1-score
        raw_data"""