import dsp import tqdm import threading import pandas as pd try: from IPython.display import display as ipython_display, HTML except ImportError: ipython_display = print HTML = lambda x: x from concurrent.futures import ThreadPoolExecutor, as_completed from dsp.utils import EM from dsp.evaluation.utils import * """ TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end), we print the number of failures, the first N examples that failed, and the first N exceptions raised. """ class Evaluate: def __init__(self, *, devset, metric=None, num_threads=1, display_progress=False, display_table=False, display=True, max_errors=5, return_outputs=False): self.devset = devset self.metric = metric self.num_threads = num_threads self.display_progress = display_progress self.display_table = display_table self.display = display self.max_errors = max_errors self.error_count = 0 self.error_lock = threading.Lock() self.return_outputs = return_outputs def _execute_single_thread(self, wrapped_program, devset, display_progress): ncorrect = 0 ntotal = 0 reordered_devset = [] pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True, disable=not display_progress) for idx, arg in devset: example_idx, example, prediction, score = wrapped_program(idx, arg) reordered_devset.append((example_idx, example, prediction, score)) ncorrect += score ntotal += 1 self._update_progress(pbar, ncorrect, ntotal) pbar.close() return reordered_devset, ncorrect, ntotal def _execute_multi_thread(self, wrapped_program, devset, num_threads, display_progress): ncorrect = 0 ntotal = 0 reordered_devset = [] with ThreadPoolExecutor(max_workers=num_threads) as executor: futures = {executor.submit(wrapped_program, idx, arg) for idx, arg in devset} pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True, disable=not display_progress) for future in as_completed(futures): example_idx, example, prediction, score = future.result() reordered_devset.append((example_idx, example, prediction, score)) ncorrect += score ntotal += 1 self._update_progress(pbar, ncorrect, ntotal) pbar.close() return reordered_devset, ncorrect, ntotal def _update_progress(self, pbar, ncorrect, ntotal): pbar.set_description(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)})") pbar.update() def __call__(self, program, metric=None, devset=None, num_threads=None, display_progress=None, display_table=None, display=None, return_all_scores=False, return_outputs=False): metric = metric if metric is not None else self.metric devset = devset if devset is not None else self.devset num_threads = num_threads if num_threads is not None else self.num_threads display_progress = display_progress if display_progress is not None else self.display_progress display_table = display_table if display_table is not None else self.display_table display = self.display if display is None else display display_progress = display_progress and display display_table = display_table if display else False return_outputs = return_outputs if return_outputs is not False else self.return_outputs results = [] def wrapped_program(example_idx, example): # NOTE: TODO: Won't work if threads create threads! creating_new_thread = threading.get_ident() not in dsp.settings.stack_by_thread if creating_new_thread: dsp.settings.stack_by_thread[threading.get_ident()] = list(dsp.settings.main_stack) # print(threading.get_ident(), dsp.settings.stack_by_thread[threading.get_ident()]) # print(type(example), example) try: prediction = program(**example.inputs()) score = metric(example, prediction) # FIXME: TODO: What's the right order? Maybe force name-based kwargs! # increment assert and suggest failures to program's attributes if hasattr(program, '_assert_failures'): program._assert_failures += dsp.settings.assert_failures if hasattr(program, '_suggest_failures'): program._suggest_failures += dsp.settings.suggest_failures return example_idx, example, prediction, score except Exception as e: with self.error_lock: self.error_count += 1 current_error_count = self.error_count if current_error_count >= self.max_errors: raise e print(f"Error for example in dev set: \t\t {e}") return example_idx, example, dict(), 0.0 finally: if creating_new_thread: del dsp.settings.stack_by_thread[threading.get_ident()] devset = list(enumerate(devset)) if num_threads == 1: reordered_devset, ncorrect, ntotal = self._execute_single_thread(wrapped_program, devset, display_progress) else: reordered_devset, ncorrect, ntotal = self._execute_multi_thread(wrapped_program, devset, num_threads, display_progress) if return_outputs: # Handle the return_outputs logic results = [(example, prediction, score) for _, example, prediction, score in reordered_devset] if display: print(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)") predicted_devset = sorted(reordered_devset) # data = [{**example, **prediction, 'correct': score} for example, prediction, score in zip(reordered_devset, preds, scores)] data = [merge_dicts(example, prediction) | {'correct': score} for _, example, prediction, score in predicted_devset] df = pd.DataFrame(data) # Truncate every cell in the DataFrame df = df.applymap(truncate_cell) # Rename the 'correct' column to the name of the metric function metric_name = metric.__name__ df.rename(columns={'correct': metric_name}, inplace=True) if display_table: if isinstance(display_table, int): df_to_display = df.head(display_table).copy() truncated_rows = len(df) - display_table else: df_to_display = df.copy() truncated_rows = 0 styled_df = configure_dataframe_display(df_to_display, metric_name) ipython_display(styled_df) if truncated_rows > 0: # Simplified message about the truncated rows message = f"""