from typing import Dict, List

import pandas as pd


class FilterManager:
    """Manages filtering logic for all leaderboard types."""

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.valid_tasks = {
            'NUBES', 'NorSynthClinical-NER', 'MEDIQA 2023-sum-A', 'Medication extraction',
            'IMCS-V2-DAC', 'Cantemist-Coding', 'IFMIR-NER', 'EHRQA-QA', 'Ex4CDS', 'MedDG',
            'MTS-Temporal', 'CHIP-MDCFNPC', 'n2c2 2014-Diabetes', 'MIMIC-III Outcome.LoS',
            'n2c2 2014-Hypertension', 'RuCCoN', 'CARES-ICD10 Chapter', 'RuDReC-NER', 'MIMIC-IV DiReCT.Dis',
            'n2c2 2014-Medication', 'iCorpus', 'Brateca-Hospitalization', 'n2c2 2010-Assertion',
            'NorSynthClinical-PHI', 'IFMIR - NER&factuality', 'JP-STS', 'NorSynthClinical-RE',
            'n2c2 2010-Concept', 'BARR2', 'IMCS-V2-NER', 'IMCS-V2-MRG', 'cMedQA', 'MedSTS',
            'BRONCO150-NER&Status', 'n2c2 2018-ADE&medication', 'CLISTER', 'ClinicalNotes-UPMC',
            'PPTS', 'CLIP', 'IMCS-V2-SR', 'EHRQA-Sub department', 'BrainMRI-AIS', 'Brateca-Mortality',
            'meddocan', 'CHIP-CDEE', 'CAS-evidence', 'MEDIQA 2019-RQE', 'Cantemis-Norm', 'MEDIQA 2023-sum-B',
            'CHIP-CTC', 'C-EMRS', 'CARES ICD10 Block', 'Cantemis-NER', 'CLINpt-NER', 'MEDIQA 2023-chat-A',
            'n2c2 2014-De-identification', 'n2c2 2014-Hyperlipidemia', 'EHRQA-Primary department',
            'ADE-Drug dosage', 'IFMIR-Incident type', 'MIMIC-III Outcome.Mortality', 'n2c2 2006-De-identification',
            'CAS-label', 'MIMIC-IV CDM', 'CodiEsp-ICD-10-CM', 'n2c2 2010-Relation', 'CARES-ICD10 Subblock',
            'MIE', 'HealthCareMagic-100k', 'ADE-Identification', 'MIMIC-IV DiReCT.PDD', 'ADE-Extraction',
            'DialMed', 'GOUT-CC-Consensus', 'GraSSCo PHI', 'RuMedNLI', 'RuMedDaNet', 'CBLUE-CDN', 'icliniq-10k',
            'CARDIO-DE', 'CARES-Area', 'DiSMed-NER', 'CodiEsp-ICD-10-PCS', 'MedNLI', 'MTS', 'MIMIC-IV BHC',
            'n2c2 2014-CAD'
        }
        # Initialize filter states for each leaderboard type
        self.filter_states = {
            'zero_shot': self._create_empty_filter_state(),
            'few_shot': self._create_empty_filter_state(),
            'cot': self._create_empty_filter_state()
        }

    def _create_empty_filter_state(self) -> Dict[str, List]:
        """Create an empty filter state with no selections."""
        return {
            "Language": [],
            "Task Type": [],
            "Clinical Context": [],
            "Data Access": [],
            "Applications": [],
            "Clinical Stage": []
        }
    def get_filtered_columns(self, filter_selections: Dict[str, List]) -> List[str]:
        """
        Given the selected filter values, return the list of task columns
        that satisfy every non-empty filter.
        """
        valid_columns = []
        for task, task_info in self.data_loader.task_information.items():
            # A task stays valid until one of its attributes fails a filter.
            is_valid = True
            for attribute, value in task_info.items():
                # An empty filter places no constraint on this attribute.
                if not filter_selections[attribute]:
                    continue
                if "," in value:
                    # The task belongs to several categories; it matches if
                    # any selected value appears among them.
                    all_categories = value.split(", ")
                    if not any(category in filter_selections[attribute]
                               for category in all_categories):
                        is_valid = False
                # Treat Brazilian Portuguese as matching a plain 'Portuguese' selection.
                elif value == 'Portuguese\n(Brazilian)' and 'Portuguese' in filter_selections[attribute]:
                    continue
                elif value not in filter_selections[attribute]:
                    is_valid = False
            if task in self.valid_tasks and is_valid:
                valid_columns.append(task)
        return valid_columns
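
    # Illustrative use (hypothetical metadata): if task_information contained
    #   {'NUBES': {'Language': 'Spanish', ...}, 'MedNLI': {'Language': 'English', ...}}
    # then, with only the Language filter set,
    #   manager.get_filtered_columns({"Language": ["Spanish"], "Task Type": [],
    #                                 "Clinical Context": [], "Data Access": [],
    #                                 "Applications": [], "Clinical Stage": []})
    # would return ['NUBES'] and exclude 'MedNLI': every non-empty filter must
    # match, while empty filters impose no constraint.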
    def is_empty(self, filter_selections: Dict[str, List]) -> bool:
        """Check if there are no selected filters."""
        return all(not value for value in filter_selections.values())
    def update_average_performance(self, leaderboard_type: str, selected_columns: List[str]) -> Dict[str, float]:
        """
        Calculate each model's updated average performance over the selected columns.
        """
        updated_average_performance = {}
        leaderboard_json = self.data_loader.get_leaderboard_json(leaderboard_type)
        for i in range(self.data_loader.n_models):
            performance = 0
            num_tasks = 0
            for task in selected_columns:
                if task in leaderboard_json:
                    num_tasks += 1
                    performance += float(leaderboard_json[task][str(i)])
            # Avoid division by zero when none of the selected columns
            # appear in this leaderboard.
            if num_tasks == 0:
                num_tasks = 1
            updated_average_performance[f"{i}"] = float(round(performance / num_tasks, 2))
        return updated_average_performance
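
    # Worked example (hypothetical scores): if model 0 scores 80.0 on 'MedNLI'
    # and 60.0 on 'MTS', and those are the only selected columns present in the
    # leaderboard JSON, its updated average is round((80.0 + 60.0) / 2, 2) == 70.0.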
    def apply_filter(self, leaderboard_type: str, filter_type: str, filter_values: List[str]) -> pd.DataFrame:
        """
        Apply a filter to a specific leaderboard type and return the updated dataframe.
        """
        # Update the filter state
        self.filter_states[leaderboard_type][filter_type] = filter_values
        # Get the dataframe
        df = self.data_loader.get_dataframe(leaderboard_type).copy()
        # If no filters are applied, reset to the original performance
        if self.is_empty(self.filter_states[leaderboard_type]):
            df["Average Performance"] = self.data_loader.get_original_performance(leaderboard_type)
            return df
        # Get the columns that satisfy the current filters
        filtered_cols = self.get_filtered_columns(self.filter_states[leaderboard_type])
        # Recompute the average performance over those columns
        updated_performance = self.update_average_performance(leaderboard_type, filtered_cols)
        # Convert dictionary keys to integers to match the DataFrame index
        updated_performance_int = {int(k): v for k, v in updated_performance.items()}
        # Map the values to the 'Average Performance' column based on index
        df["Average Performance"] = df.index.map(updated_performance_int)
        # Return the dataframe restricted to the base and filtered columns
        base_columns = ['T', 'Model', 'Model: Domain', 'Model: Accessibility', 'Model: Size Range', 'Size (B)', 'Average Performance']
        return df[base_columns + filtered_cols]
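

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions flagged): FilterManager is normally fed by
# the app's real data loader; the stub below is hypothetical and only mimics
# the attributes and methods this module actually touches (task_information,
# n_models, get_leaderboard_json, get_dataframe, get_original_performance).
# The tasks, scores, and attribute values are illustrative, not real results.
if __name__ == "__main__":
    class _StubDataLoader:
        def __init__(self):
            self.task_information = {
                'MedNLI': {"Language": "English", "Task Type": "NLI",
                           "Clinical Context": "Note", "Data Access": "Regulated",
                           "Applications": "Decision support", "Clinical Stage": "Treatment"},
                'meddocan': {"Language": "Spanish", "Task Type": "NER",
                             "Clinical Context": "Note", "Data Access": "Open",
                             "Applications": "De-identification", "Clinical Stage": "Admission"},
            }
            self.n_models = 2
            self._scores = {
                'MedNLI': {"0": "80.0", "1": "55.0"},
                'meddocan': {"0": "70.0", "1": "65.0"},
            }

        def get_leaderboard_json(self, leaderboard_type):
            return self._scores

        def get_dataframe(self, leaderboard_type):
            return pd.DataFrame({
                'T': ['open', 'open'],
                'Model': ['model-a', 'model-b'],
                'Model: Domain': ['General', 'Clinical'],
                'Model: Accessibility': ['Open', 'Open'],
                'Model: Size Range': ['7-13B', '7-13B'],
                'Size (B)': [7.0, 13.0],
                'Average Performance': [75.0, 60.0],
                'MedNLI': [80.0, 55.0],
                'meddocan': [70.0, 65.0],
            })

        def get_original_performance(self, leaderboard_type):
            return [75.0, 60.0]

    manager = FilterManager(_StubDataLoader())
    # Selecting Language=English keeps only 'MedNLI' and recomputes the average.
    print(manager.apply_filter('zero_shot', 'Language', ['English']))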