import dsp
import tqdm
import threading
import pandas as pd

try:
    from IPython.display import display as ipython_display, HTML
except ImportError:
    ipython_display = print

    def HTML(x):  # minimal fallback when IPython isn't available
        return x
from concurrent.futures import ThreadPoolExecutor, as_completed

from dsp.utils import EM
from dsp.evaluation.utils import *

"""
TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end),
we print the number of failures, the first N examples that failed, and the first N exceptions raised.
"""


class Evaluate:
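    """Evaluate a program on a devset with a metric.

    Runs `program(**example.inputs())` on every example (optionally across
    `num_threads` threads), averages `metric(example, prediction)`, and can
    display a progress bar and a truncated results table.
    """
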
    def __init__(self, *, devset, metric=None, num_threads=1, display_progress=False,
                 display_table=False, display=True, max_errors=5, return_outputs=False):
        self.devset = devset
        self.metric = metric
        self.num_threads = num_threads
        self.display_progress = display_progress
        self.display_table = display_table
        self.display = display
        self.max_errors = max_errors
        self.error_count = 0
        self.error_lock = threading.Lock()
        self.return_outputs = return_outputs

    def _execute_single_thread(self, wrapped_program, devset, display_progress):
        ncorrect = 0
        ntotal = 0
        reordered_devset = []
        
        pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True, disable=not display_progress)
        for idx, arg in devset:
            example_idx, example, prediction, score = wrapped_program(idx, arg)
            reordered_devset.append((example_idx, example, prediction, score))
            ncorrect += score
            ntotal += 1
            self._update_progress(pbar, ncorrect, ntotal)
        pbar.close()
        
        return reordered_devset, ncorrect, ntotal

    def _execute_multi_thread(self, wrapped_program, devset, num_threads, display_progress):
        ncorrect = 0
        ntotal = 0
        reordered_devset = []
        
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = {executor.submit(wrapped_program, idx, arg) for idx, arg in devset}
            pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True, disable=not display_progress)

            for future in as_completed(futures):
                example_idx, example, prediction, score = future.result()
                reordered_devset.append((example_idx, example, prediction, score))
                ncorrect += score
                ntotal += 1
                self._update_progress(pbar, ncorrect, ntotal)
            pbar.close()

        return reordered_devset, ncorrect, ntotal

    def _update_progress(self, pbar, ncorrect, ntotal):
        pbar.set_description(f"Average Metric: {ncorrect} / {ntotal}  ({round(100 * ncorrect / ntotal, 1)}%)")
        pbar.update()

    def __call__(self, program, metric=None, devset=None, num_threads=None,
                 display_progress=None, display_table=None, display=None,
                 return_all_scores=False, return_outputs=False):
        metric = metric if metric is not None else self.metric
        devset = devset if devset is not None else self.devset
        num_threads = num_threads if num_threads is not None else self.num_threads
        display_progress = display_progress if display_progress is not None else self.display_progress
        display_table = display_table if display_table is not None else self.display_table

        display = self.display if display is None else display
        display_progress = display_progress and display
        display_table = display_table if display else False
        return_outputs = return_outputs if return_outputs is not False else self.return_outputs
        results = []
        
        def wrapped_program(example_idx, example):
            # NOTE: TODO: Won't work if threads create threads!
            creating_new_thread = threading.get_ident() not in dsp.settings.stack_by_thread
            if creating_new_thread:
                dsp.settings.stack_by_thread[threading.get_ident()] = list(dsp.settings.main_stack)
                # print(threading.get_ident(), dsp.settings.stack_by_thread[threading.get_ident()])

            # print(type(example), example)

            try:
                prediction = program(**example.inputs())
                score = metric(example, prediction)  # FIXME: TODO: What's the right order? Maybe force name-based kwargs!
                
                # increment assert and suggest failures to program's attributes
                if hasattr(program, '_assert_failures'):
                    program._assert_failures += dsp.settings.assert_failures
                if hasattr(program, '_suggest_failures'):
                    program._suggest_failures += dsp.settings.suggest_failures
                
                return example_idx, example, prediction, score
            except Exception as e:
                with self.error_lock:
                    self.error_count += 1
                    current_error_count = self.error_count
                if current_error_count >= self.max_errors:
                    raise e
                print(f"Error for example in dev set: \t\t {e}")
                return example_idx, example, dict(), 0.0
            finally:
                if creating_new_thread:
                    del dsp.settings.stack_by_thread[threading.get_ident()]

        devset = list(enumerate(devset))

        if num_threads == 1:
            reordered_devset, ncorrect, ntotal = self._execute_single_thread(wrapped_program, devset, display_progress)
        else:
            reordered_devset, ncorrect, ntotal = self._execute_multi_thread(wrapped_program, devset, num_threads, display_progress)
        if return_outputs:  # Handle the return_outputs logic
            results = [(example, prediction, score) for _, example, prediction, score in reordered_devset]

        if display:
            print(f"Average Metric: {ncorrect} / {ntotal}  ({round(100 * ncorrect / ntotal, 1)}%)")

        predicted_devset = sorted(reordered_devset, key=lambda item: item[0])  # restore original devset order

        # data = [{**example, **prediction, 'correct': score} for example, prediction, score in zip(reordered_devset, preds, scores)]
        data = [merge_dicts(example, prediction) | {'correct': score} for _, example, prediction, score in predicted_devset]

        df = pd.DataFrame(data)

        # Truncate every cell in the DataFrame (DataFrame.map supersedes the deprecated applymap in pandas 2.1+)
        df = df.map(truncate_cell) if hasattr(df, "map") else df.applymap(truncate_cell)

        # Rename the 'correct' column to the name of the metric function (or its class, for callable objects)
        metric_name = getattr(metric, '__name__', type(metric).__name__)
        df.rename(columns={'correct': metric_name}, inplace=True)

        if display_table:
            if display_table is True:
                df_to_display = df.copy()
                truncated_rows = 0
            else:  # an int: display only the first `display_table` rows
                df_to_display = df.head(display_table).copy()
                truncated_rows = len(df) - display_table

            styled_df = configure_dataframe_display(df_to_display, metric_name)
            
            ipython_display(styled_df)

            if truncated_rows > 0:
                # Simplified message about the truncated rows
                message = f"""
                <div style='
                    text-align: center; 
                    font-size: 16px; 
                    font-weight: bold; 
                    color: #555; 
                    margin: 10px 0;'>
                    ... {truncated_rows} more rows not displayed ...
                </div>
                """
                ipython_display(HTML(message))
                
        if return_all_scores and return_outputs:
            return round(100 * ncorrect / ntotal, 2), results, [score for *_, score in reordered_devset]
        elif return_all_scores:
            return round(100 * ncorrect / ntotal, 2), [score for *_, score in reordered_devset]
        elif return_outputs:
            return round(100 * ncorrect / ntotal, 2), results

        return round(100 * ncorrect / ntotal, 2)
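
# Example usage (a minimal sketch; `my_program`, `my_metric`, and `my_devset`
# are hypothetical stand-ins for a DSP program, a metric callable taking
# (example, prediction), and a list of examples exposing .inputs()):
#
#   evaluator = Evaluate(devset=my_devset, metric=my_metric,
#                        num_threads=8, display_progress=True, display_table=5)
#   score = evaluator(my_program)                            # average metric, as a percentage
#   score, outputs = evaluator(my_program, return_outputs=True)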


def merge_dicts(d1, d2):
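    """Merge two dicts, prefixing keys that appear in both:
    a colliding key from d1 becomes 'example_<key>'; from d2, 'pred_<key>'.

    For example, merge_dicts({'answer': 'gold'}, {'answer': 'pred'})
    returns {'example_answer': 'gold', 'pred_answer': 'pred'}.
    """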
    merged = {}
    for k, v in d1.items():
        if k in d2:
            merged[f"example_{k}"] = v
        else:
            merged[k] = v

    for k, v in d2.items():
        if k in d1:
            merged[f"pred_{k}"] = v
        else:
            merged[k] = v

    return merged


def truncate_cell(content):
    """Truncate content of a cell to 25 words."""
    words = str(content).split()
    if len(words) > 25:
        return ' '.join(words[:25]) + '...'
    return content

def configure_dataframe_display(df, metric_name):
    """Set pandas display options and return a left-aligned, styled DataFrame."""
    pd.set_option('display.max_colwidth', 20)  # Adjust the number as needed
    pd.set_option('display.width', 400)  # Adjust as needed

    # df[metric_name] = df[metric_name].apply(lambda x: f'✔️ [{x}]' if x is True else f'❌ [{x}]')
    df.loc[:, metric_name] = df[metric_name].apply(lambda x: f'✔️ [{x}]' if x is True else f'{x}')

    # Return styled DataFrame
    return df.style.set_table_styles([
        {'selector': 'th', 'props': [('text-align', 'left')]},
        {'selector': 'td', 'props': [('text-align', 'left')]}
    ]).set_properties(**{
        'text-align': 'left',
        'white-space': 'pre-wrap',
        'word-wrap': 'break-word',
        'max-width': '400px'
    })

# FIXME: TODO: The merge_dicts stuff above is way too quick and dirty.