File size: 7,019 Bytes
fd7d17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from typing import Dict, List, Any
import pandas as pd

class FilterManager:
    """Manages filtering logic for all leaderboard types"""
    
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.valid_tasks = {
            'NUBES', 'NorSynthClinical-NER', 'MEDIQA 2023-sum-A', 'Medication extraction', 
            'IMCS-V2-DAC', 'Cantemist-Coding', 'IFMIR-NER', 'EHRQA-QA', 'Ex4CDS', 'MedDG', 
            'MTS-Temporal', 'CHIP-MDCFNPC', 'n2c2 2014-Diabetes', 'MIMIC-III Outcome.LoS', 
            'n2c2 2014-Hypertension', 'RuCCoN', 'CARES-ICD10 Chapter', 'RuDReC-NER', 'MIMIC-IV DiReCT.Dis', 
            'n2c2 2014-Medication', 'iCorpus', 'Brateca-Hospitalization', 'n2c2 2010-Assertion', 
            'NorSynthClinical-PHI', 'IFMIR - NER&factuality', 'JP-STS', 'NorSynthClinical-RE', 
            'n2c2 2010-Concept', 'BARR2', 'IMCS-V2-NER', 'IMCS-V2-MRG', 'cMedQA', 'MedSTS', 
            'BRONCO150-NER&Status', 'n2c2 2018-ADE&medication', 'CLISTER', 'ClinicalNotes-UPMC', 
            'PPTS', 'CLIP', 'IMCS-V2-SR', 'EHRQA-Sub department', 'BrainMRI-AIS', 'Brateca-Mortality', 
            'meddocan', 'CHIP-CDEE', 'CAS-evidence', 'MEDIQA 2019-RQE', 'Cantemis-Norm', 'MEDIQA 2023-sum-B', 
            'CHIP-CTC', 'C-EMRS', 'CARES ICD10 Block', 'Cantemis-NER', 'CLINpt-NER', 'MEDIQA 2023-chat-A', 
            'n2c2 2014-De-identification', 'n2c2 2014-Hyperlipidemia', 'EHRQA-Primary department', 
            'ADE-Drug dosage', 'IFMIR-Incident type', 'MIMIC-III Outcome.Mortality', 'n2c2 2006-De-identification', 
            'CAS-label', 'MIMIC-IV CDM', 'CodiEsp-ICD-10-CM', 'n2c2 2010-Relation', 'CARES-ICD10 Subblock', 
            'MIE', 'HealthCareMagic-100k', 'ADE-Identification', 'MIMIC-IV DiReCT.PDD', 'ADE-Extraction', 
            'DialMed', 'GOUT-CC-Consensus', 'GraSSCo PHI', 'RuMedNLI', 'RuMedDaNet', 'CBLUE-CDN', 'icliniq-10k', 
            'CARDIO-DE', 'CARES-Area', 'DiSMed-NER', 'CodiEsp-ICD-10-PCS', 'MedNLI', 'MTS', 'MIMIC-IV BHC', 
            'n2c2 2014-CAD'
        }
        
        # Initialize filter states for each leaderboard type
        self.filter_states = {
            'zero_shot': self._create_empty_filter_state(),
            'few_shot': self._create_empty_filter_state(),
            'cot': self._create_empty_filter_state()
        }
    
    def _create_empty_filter_state(self) -> Dict[str, List]:
        """Create an empty filter state"""
        return {
            "Language": [],
            "Task Type": [],
            "Clinical Context": [],
            "Data Access": [],
            "Applications": [],
            "Clinical Stage": []
        }
    
    def get_filtered_columns(self, filter_selections: Dict[str, List]) -> List[str]:
        """
        Given an array of selected filters, return a list of all
        the columns that match the criteria.
        """
        valid_columns = []
        for task in self.data_loader.task_information:
            task_info = self.data_loader.task_information[task]
            
            # Flag to keep track of whether this task is valid
            is_valid = True

            # Iterate through each attribute of the task
            for attribute in task_info:
                # If the filter is empty
                if not filter_selections[attribute]:
                    continue

                value = task_info[attribute]

                # Handle edge case for multiple categories
                if "," in value:
                    all_categories = value.split(", ")
                    flag = False
                    for category in all_categories:
                        if category in filter_selections[attribute]:
                            flag = True
                            break
                    
                    if flag:  # one category matches
                        is_valid = True
                    else:  # none of the categories matched
                        is_valid = False

                # Handle Brazilian Edge Case
                elif (value == 'Portuguese\n(Brazilian)') and ('Portuguese' in filter_selections[attribute]):
                    is_valid = True
                    break
            
                elif value not in filter_selections[attribute]:
                    is_valid = False

            if task in self.valid_tasks and is_valid:
                valid_columns.append(task)

        return valid_columns

    def is_empty(self, filter_selections: Dict[str, List]) -> bool:
        """Check if there are no selected filters"""
        return all(not value for value in filter_selections.values())

    def update_average_performance(self, leaderboard_type: str, selected_columns: List[str]) -> Dict[str, float]:
        """
        Calculate updated average performance based on selected columns
        """
        updated_average_performance = {}
        leaderboard_json = self.data_loader.get_leaderboard_json(leaderboard_type)
        
        for i in range(self.data_loader.n_models):
            performance = 0
            num_tasks = 0
            
            for task in selected_columns:
                if task in leaderboard_json:
                    num_tasks += 1
                    performance += float(leaderboard_json[task][str(i)])

            if num_tasks == 0:
                num_tasks = 1
            
            updated_average_performance[f"{i}"] = float(round(performance / num_tasks, 2))

        return updated_average_performance

    def apply_filter(self, leaderboard_type: str, filter_type: str, filter_values: List[str]) -> pd.DataFrame:
        """
        Apply a filter to a specific leaderboard type and return updated dataframe
        """
        # Update the filter state
        self.filter_states[leaderboard_type][filter_type] = filter_values
        
        # Get the dataframe
        df = self.data_loader.get_dataframe(leaderboard_type).copy()
        
        # If no filters are applied, reset to original performance
        if self.is_empty(self.filter_states[leaderboard_type]):
            df["Average Performance"] = self.data_loader.get_original_performance(leaderboard_type)
            return df

        # Get filtered columns
        filtered_cols = self.get_filtered_columns(self.filter_states[leaderboard_type])
        
        # Update average performance
        updated_performance = self.update_average_performance(leaderboard_type, filtered_cols)
        
        # Convert dictionary keys to integers to match the DataFrame index
        updated_performance_int = {int(k): v for k, v in updated_performance.items()}
        
        # Map the values to the 'Average Performance' column based on index
        df["Average Performance"] = df.index.map(updated_performance_int)
        
        # Return dataframe with filtered columns
        base_columns = ['T', 'Model', 'Model: Domain', 'Model: Accessibility', 'Model: Size Range', 'Size (B)', 'Average Performance']
        return df[base_columns + filtered_cols]