Refactor ISCO_Hierarchical_Accuracy class to use weighted hierarchy dictionary
Browse files- isco_hierarchical_accuracy.py +42 -25
isco_hierarchical_accuracy.py
CHANGED
|
@@ -114,15 +114,14 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
| 114 |
|
| 115 |
def create_hierarchy_dict(self, file: str) -> dict:
|
| 116 |
"""
|
| 117 |
-
Creates a dictionary where keys are nodes and values are
|
| 118 |
-
|
| 119 |
-
A csv file with the ISCO-08 structure can be downloaded from the International Labour Organization (ILO) at [https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08 EN.csv](https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08%20EN.csv)
|
| 120 |
|
| 121 |
Args:
|
| 122 |
- file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
|
| 123 |
|
| 124 |
Returns:
|
| 125 |
-
- A dictionary where keys are ISCO-08 unit codes and values are
|
| 126 |
"""
|
| 127 |
|
| 128 |
try:
|
|
@@ -146,7 +145,12 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
| 146 |
minor_code = unit_code[0:3]
|
| 147 |
sub_major_code = unit_code[0:2]
|
| 148 |
major_code = unit_code[0]
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
return isco_hierarchy
|
| 152 |
|
|
@@ -192,40 +196,53 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
| 192 |
self,
|
| 193 |
reference_codes: List[str],
|
| 194 |
predicted_codes: List[str],
|
| 195 |
-
hierarchy: Dict[str,
|
| 196 |
) -> Tuple[float, float]:
|
| 197 |
"""
|
| 198 |
Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
|
| 199 |
|
| 200 |
Args:
|
| 201 |
-
|
| 202 |
predicted_codes (List[str]): The list of predicted codes.
|
| 203 |
hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
|
| 204 |
|
| 205 |
Returns:
|
| 206 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
| 207 |
"""
|
| 208 |
-
|
| 209 |
-
extended_real = set()
|
| 210 |
-
for code in reference_codes:
|
| 211 |
-
extended_real.add(code)
|
| 212 |
-
extended_real.update(hierarchy.get(code, set()))
|
| 213 |
|
| 214 |
-
|
| 215 |
-
for code in
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
-
|
| 220 |
-
correct_predictions = extended_real.intersection(extended_predicted)
|
| 221 |
|
| 222 |
-
#
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
return hP, hR
|
| 231 |
|
|
|
|
| 114 |
|
| 115 |
def create_hierarchy_dict(self, file: str) -> dict:
|
| 116 |
"""
|
| 117 |
+
Creates a dictionary where keys are nodes and values are dictionaries of their parent nodes with distance as weights,
|
| 118 |
+
representing the group level hierarchy of the ISCO-08 structure.
|
|
|
|
| 119 |
|
| 120 |
Args:
|
| 121 |
- file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
|
| 122 |
|
| 123 |
Returns:
|
| 124 |
+
- A dictionary where keys are ISCO-08 unit codes and values are dictionaries of their parent codes with distances.
|
| 125 |
"""
|
| 126 |
|
| 127 |
try:
|
|
|
|
| 145 |
minor_code = unit_code[0:3]
|
| 146 |
sub_major_code = unit_code[0:2]
|
| 147 |
major_code = unit_code[0]
|
| 148 |
+
|
| 149 |
+
# Assign weights, higher for closer ancestors
|
| 150 |
+
weights = {minor_code: 0.75, sub_major_code: 0.5, major_code: 0.25}
|
| 151 |
+
|
| 152 |
+
# Store ancestors with their weights
|
| 153 |
+
isco_hierarchy[unit_code] = weights
|
| 154 |
|
| 155 |
return isco_hierarchy
|
| 156 |
|
|
|
|
| 196 |
self,
|
| 197 |
reference_codes: List[str],
|
| 198 |
predicted_codes: List[str],
|
| 199 |
+
hierarchy: Dict[str, Dict[str, float]],
|
| 200 |
) -> Tuple[float, float]:
|
| 201 |
"""
|
| 202 |
Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
|
| 203 |
|
| 204 |
Args:
|
| 205 |
+
reference_codes (List[str]): The list of reference codes.
|
| 206 |
predicted_codes (List[str]): The list of predicted codes.
|
| 207 |
hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
|
| 208 |
|
| 209 |
Returns:
|
| 210 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
| 211 |
"""
|
| 212 |
+
extended_real = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
# Extend the sets of reference codes with their ancestors
|
| 215 |
+
for code in reference_codes:
|
| 216 |
+
weight = 1.0 # Full weight for exact match
|
| 217 |
+
extended_real[code] = weight
|
| 218 |
+
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
| 219 |
+
extended_real[ancestor] = max(
|
| 220 |
+
extended_real.get(ancestor, 0), ancestor_weight
|
| 221 |
+
)
|
| 222 |
|
| 223 |
+
extended_predicted = {}
|
|
|
|
| 224 |
|
| 225 |
+
# Extend the sets of predicted codes with their ancestors
|
| 226 |
+
for code in predicted_codes:
|
| 227 |
+
weight = 1.0
|
| 228 |
+
extended_predicted[code] = weight
|
| 229 |
+
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
| 230 |
+
extended_predicted[ancestor] = max(
|
| 231 |
+
extended_predicted.get(ancestor, 0), ancestor_weight
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Calculate weighted correct predictions
|
| 235 |
+
correct_weights = 0
|
| 236 |
+
for code, weight in extended_predicted.items():
|
| 237 |
+
if code in extended_real:
|
| 238 |
+
correct_weights += min(weight, extended_real[code])
|
| 239 |
+
|
| 240 |
+
total_predicted_weights = sum(extended_predicted.values())
|
| 241 |
+
total_real_weights = sum(extended_real.values())
|
| 242 |
+
|
| 243 |
+
# Calculate hierarchical precision and recall using weighted sums
|
| 244 |
+
hP = correct_weights / total_predicted_weights if total_predicted_weights else 0
|
| 245 |
+
hR = correct_weights / total_real_weights if total_real_weights else 0
|
| 246 |
|
| 247 |
return hP, hR
|
| 248 |
|