Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import dataclasses | |
| import enum | |
| import functools | |
| import json | |
| import os | |
| import re | |
| import types | |
| from typing import Callable | |
| import einops | |
| import imageio | |
| import numpy as np | |
| import torch.utils.data | |
| import torchvision | |
| import tqdm | |
| from config import CONFIG | |
| from utils import load_pickle_or_build_object_and_save | |
class Source(enum.Enum):
    """Origin category stored in an annotation's ``source`` field.

    Training indices are partitioned by this value (extracted vs. everything else)
    when building the train/validation split.
    """

    generated = "generated"  # presumably synthetic charts — TODO confirm
    extracted = "extracted"
class ChartType(enum.Enum):
    """Kind of chart described by an annotation.

    Member names double as special-token text (``<dot>``, ``<line>``, ...) and each
    value equals its name, so the token text can be fed back to ``ChartType(...)``.
    """

    dot = "dot"
    horizontal_bar = "horizontal_bar"
    vertical_bar = "vertical_bar"
    line = "line"
    scatter = "scatter"
@dataclasses.dataclass
class PlotBoundingBox:
    """Axis-aligned rectangle of the plotting area inside a chart image.

    ``(x0, y0)`` is one corner; ``width`` and ``height`` extend it along x and y.
    """

    height: int
    width: int
    x0: int
    y0: int

    def get_bounds(self):
        """Return ``(xs, ys)`` for the rectangle outline as a closed 5-point path.

        The corners are visited in order and the first one is repeated at the end
        so the outline closes on itself.
        """
        x_far = self.x0 + self.width
        y_far = self.y0 + self.height
        corners = [
            (self.x0, self.y0),
            (x_far, self.y0),
            (x_far, y_far),
            (self.x0, y_far),
        ]
        corners.append(corners[0])
        xs = [x for x, _ in corners]
        ys = [y for _, y in corners]
        return xs, ys
@dataclasses.dataclass
class DataPoint:
    """One (x, y) point of an annotation's data series.

    Each coordinate is a ``float`` on a numerical axis or a ``str`` on a categorical
    axis once ``Annotation.__post_init__`` has coerced it.
    """

    # ``float or str`` (the previous hint) is a boolean expression that evaluates to
    # ``float``; ``float | str`` is the intended union.
    x: float | str
    y: float | str
class TextRole(enum.Enum):
    """Role of a text element (``Text.role``) within a chart."""

    axis_title = "axis_title"
    chart_title = "chart_title"
    legend_label = "legend_label"
    tick_grouping = "tick_grouping"
    tick_label = "tick_label"
    other = "other"  # catch-all for text not covered by the roles above
@dataclasses.dataclass
class Polygon:
    """Quadrilateral given by four vertices ``(x0, y0)`` .. ``(x3, y3)``."""

    x0: int
    x1: int
    x2: int
    x3: int
    y0: int
    y1: int
    y2: int
    y3: int

    def get_bounds(self):
        """Return ``(xs, ys)`` walking the four vertices and back to the first one."""
        xs = [self.x0, self.x1, self.x2, self.x3]
        ys = [self.y0, self.y1, self.y2, self.y3]
        # Repeat the starting vertex so the outline is closed.
        xs.append(xs[0])
        ys.append(ys[0])
        return xs, ys
@dataclasses.dataclass
class Text:
    """A text element of a chart annotation.

    Built from raw annotation data: ``polygon`` arrives as a mapping of Polygon
    fields and ``role`` as a TextRole value; ``__post_init__`` converts both.
    """

    id: int
    polygon: Polygon
    role: TextRole
    text: str

    def __post_init__(self):
        raw_polygon, raw_role = self.polygon, self.role
        self.polygon = Polygon(**raw_polygon)
        self.role = TextRole(raw_role)
class ValuesType(enum.Enum):
    """How an axis' data values are interpreted.

    ``numerical`` values are parsed to float; ``categorical`` values are kept as
    strings (see ``preprocess_value``).
    """

    categorical = "categorical"
    numerical = "numerical"
@dataclasses.dataclass
class Tick:
    """A single axis tick: its annotation id and pixel position ``(x, y)``."""

    id: int
    x: int
    y: int
class TickType(enum.Enum):
    """Tick style of an axis (``Axis.tick_type``)."""

    markers = "markers"
    separators = "separators"
@dataclasses.dataclass
class Axis:
    """One chart axis: its value type, tick style and tick positions.

    Built from raw annotation data; each raw tick is a mapping with an ``"id"`` and
    a ``"tick_pt"`` mapping holding ``"x"`` and ``"y"``.
    """

    values_type: ValuesType
    tick_type: TickType
    ticks: list[Tick]

    def __post_init__(self):
        self.values_type = ValuesType(self.values_type)
        self.tick_type = TickType(self.tick_type)
        parsed_ticks = []
        for raw_tick in self.ticks:
            tick_id = raw_tick["id"]
            point = raw_tick["tick_pt"]
            parsed_ticks.append(Tick(id=tick_id, x=point["x"], y=point["y"]))
        self.ticks = parsed_ticks

    def get_bounds(self):
        """Return ``(xs, ys)`` of the closed box spanning all tick positions.

        Raises ValueError (from ``min``) when the axis has no ticks.
        """
        tick_xs = [tick.x for tick in self.ticks]
        tick_ys = [tick.y for tick in self.ticks]
        x_lo, x_hi = min(tick_xs), max(tick_xs)
        y_lo, y_hi = min(tick_ys), max(tick_ys)
        return [x_lo, x_hi, x_hi, x_lo, x_lo], [y_lo, y_lo, y_hi, y_hi, y_lo]
def convert_dashes_to_underscores_in_key_names(dictionary):
    """Return a shallow copy of *dictionary* with ``-`` replaced by ``_`` in its keys.

    Only top-level keys are renamed; values (including nested mappings) are kept
    as-is. If two keys collide after renaming, the later one wins.
    """
    renamed = {}
    for key, value in dictionary.items():
        renamed[key.replace("-", "_")] = value
    return renamed
@dataclasses.dataclass
class Axes:
    """The x and y axes of a chart annotation.

    Each axis arrives as a raw mapping whose keys may contain dashes; they are
    normalised to underscores before building the ``Axis``.
    """

    x_axis: Axis
    y_axis: Axis

    def __post_init__(self):
        for attribute in ("x_axis", "y_axis"):
            raw_axis = getattr(self, attribute)
            setattr(
                self,
                attribute,
                Axis(**convert_dashes_to_underscores_in_key_names(raw_axis)),
            )
def preprocess_numerical_value(value):
    """Coerce *value* to ``float``, mapping NaN to ``0``.

    Accepts anything ``float()`` accepts (numbers, numeric strings). Note that the
    NaN replacement is the int ``0``, not ``0.0``.
    """
    number = float(value)
    if np.isnan(number):
        return 0
    return number
def preprocess_value(value, value_type: ValuesType):
    """Normalise one data-series value according to its axis' value type.

    Numerical axes go through ``preprocess_numerical_value`` (float, NaN -> 0);
    every other axis type gets ``str(value)``.
    """
    if value_type != ValuesType.numerical:
        return str(value)
    return preprocess_numerical_value(value)
@dataclasses.dataclass
class Annotation:
    """Typed view of one training image's JSON annotation.

    The constructor receives the raw JSON structures (with underscore field names);
    ``__post_init__`` converts them into the enums/dataclasses declared below and
    coerces every data-series value to the value type of its axis.
    """

    source: Source
    chart_type: ChartType
    plot_bb: PlotBoundingBox
    text: list[Text]
    axes: Axes
    data_series: list[DataPoint]

    def __post_init__(self):
        self.source = Source(self.source)
        self.chart_type = ChartType(self.chart_type)
        self.plot_bb = PlotBoundingBox(**self.plot_bb)
        self.text = [Text(**kw) for kw in self.text]
        # Axis keys in the raw JSON may use dashes; the dataclass fields use underscores.
        self.axes = Axes(**convert_dashes_to_underscores_in_key_names(self.axes))
        self.data_series = [DataPoint(**kw) for kw in self.data_series]
        # Numerical axes -> float (NaN becomes 0), categorical axes -> str.
        x_values_type = self.axes.x_axis.values_type
        y_values_type = self.axes.y_axis.values_type
        for data_point in self.data_series:
            data_point.x = preprocess_value(data_point.x, x_values_type)
            data_point.y = preprocess_value(data_point.y, y_values_type)

    @staticmethod
    def from_dict_with_dashes(kwargs):
        """Build an ``Annotation`` from a raw annotation dict.

        Dashes in the top-level keys are replaced by underscores so they match the
        dataclass field names.
        """
        return Annotation(**convert_dashes_to_underscores_in_key_names(kwargs))

    @staticmethod
    def from_image_index(image_index: int):
        """Load and parse the annotation of the *image_index*-th training image."""
        image_id = load_train_image_ids()[image_index]
        return Annotation.from_dict_with_dashes(load_image_annotation(image_id))

    def get_text_by_role(self, text_role: TextRole) -> list[Text]:
        """Return the text elements whose role equals *text_role*, in original order."""
        return [t for t in self.text if t.role == text_role]
@dataclasses.dataclass
class AnnotatedImage:
    """A training image (as loaded by ``load_image``) paired with its parsed annotation."""

    id: str
    image: np.ndarray
    annotation: Annotation

    @staticmethod
    def from_image_id(image_id: str):
        """Load both the image file and the annotation file for *image_id*."""
        pixels = load_image(image_id)
        raw_annotation = load_image_annotation(image_id)
        return AnnotatedImage(
            id=image_id,
            image=pixels,
            annotation=Annotation.from_dict_with_dashes(raw_annotation),
        )

    @staticmethod
    def from_image_index(image_index: int):
        """Load the *image_index*-th training image (position in ``load_train_image_ids()``)."""
        image_id = load_train_image_ids()[image_index]
        return AnnotatedImage.from_image_id(image_id)
def generate_annotated_images():
    """Yield an ``AnnotatedImage`` for every training image id, showing a progress bar.

    Images are yielded in ``load_train_image_ids()`` order, so the enumeration
    position of each item equals its training index.
    """
    # ``import tqdm`` does not load the ``autonotebook`` submodule, so accessing
    # ``tqdm.autonotebook`` can raise AttributeError unless something else imported
    # it first; import it explicitly instead of relying on that.
    from tqdm.autonotebook import tqdm as progress_bar

    for image_id in progress_bar(
        load_train_image_ids(), "Iterating over annotated images"
    ):
        yield AnnotatedImage.from_image_id(image_id)
@functools.lru_cache(maxsize=None)
def _list_image_ids(images_dir: str) -> tuple[str, ...]:
    """Return the sorted ids (``.jpg`` file names minus the extension) in *images_dir*.

    Cached per directory: ids are looked up once per dataset item (via
    ``AnnotatedImage.from_image_index``), and re-listing a potentially large
    directory on every lookup made an epoch quadratic in the number of images.
    """
    return tuple(
        sorted(
            file_name.removesuffix(".jpg")
            for file_name in os.listdir(images_dir)
            if file_name.endswith(".jpg")
        )
    )


def load_train_image_ids() -> list[str]:
    """Return the training image ids in sorted order.

    Ids are the ``.jpg`` file names in ``data/train/images`` without the extension;
    other files are ignored. Positions in this list are used as dataset indices (and
    are pickled at ``CONFIG.train_val_indices_path``), so the order must be
    reproducible; ``os.listdir`` returns entries in arbitrary order, hence the sort.
    NOTE: index pickles built with the previous unsorted order must be regenerated.

    When ``CONFIG.debug`` is set only the first 1000 ids are returned.
    """
    train_image_ids = _list_image_ids("data/train/images")
    return list(train_image_ids[: 1000 if CONFIG.debug else None])
def load_test_image_ids() -> list[str]:
    """Return the test image ids in sorted order.

    Ids are the ``.jpg`` file names in ``data/test/images`` without the extension;
    other files are ignored. The result is sorted because ``os.listdir`` returns
    entries in arbitrary order, and ``removesuffix`` only strips a trailing
    ``.jpg`` (``str.replace`` also removed ``.jpg`` occurring mid-name).
    """
    return sorted(
        file_name.removesuffix(".jpg")
        for file_name in os.listdir("data/test/images")
        if file_name.endswith(".jpg")
    )
def load_image_annotation(image_id: str) -> dict:
    """Read and parse ``data/train/annotations/<image_id>.json``.

    The file is closed deterministically (the previous ``json.load(open(...))``
    leaked the handle until garbage collection) and decoded as UTF-8, the JSON
    standard encoding, instead of the platform's locale default.
    """
    annotation_path = f"data/train/annotations/{image_id}.json"
    with open(annotation_path, encoding="utf-8") as annotation_file:
        return json.load(annotation_file)
def load_image(image_id: str) -> np.ndarray:
    """Read ``data/train/images/<image_id>.jpg`` into an array with ``imageio.v3.imread``.

    The file is opened in a ``with`` block so the handle is closed once the image
    has been decoded (previously it leaked until garbage collection).
    """
    with open(f"data/train/images/{image_id}.jpg", "rb") as image_file:
        return imageio.v3.imread(image_file)
@dataclasses.dataclass
class DataItem:
    """One dataset sample: image tensor, target string and training index.

    ``__post_init__`` requires the image to be 3-D ``(channel, height, width)`` with
    exactly 3 channels; height and width are not constrained.
    """

    image: torch.FloatTensor
    target_string: str
    data_index: int

    def __post_init__(self):
        channel_count = einops.parse_shape(self.image, "channel height width")["channel"]
        assert channel_count == 3, "Image is expected to have 3 channels."
def split_train_indices_by_source():
    """Partition training indices by annotation source.

    Returns ``(extracted_indices, generated_indices)``, two lists of positions in
    ``load_train_image_ids()``; every image whose source is not
    ``Source.extracted`` lands in the second list.
    """
    extracted, generated = [], []
    for index, annotated_image in enumerate(generate_annotated_images()):
        is_extracted = annotated_image.annotation.source == Source.extracted
        (extracted if is_extracted else generated).append(index)
    return extracted, generated
def get_train_val_split_indices(val_fraction=0.1, seed=42):
    """Split training indices into train and validation sets.

    The validation set (``int(n_images * val_fraction)`` items) is filled first with
    randomly ordered ``Source.extracted`` images and topped up with generated ones.
    Every remaining image, extracted or generated, goes to train.

    Args:
        val_fraction: Fraction of all training images put in the validation set.
        seed: Seed for the permutations; the same seed reproduces the same split.

    Returns:
        ``(train_indices, val_indices)`` as integer numpy arrays of positions in
        ``load_train_image_ids()``.
    """
    # A private RandomState draws exactly the same permutations as the former
    # ``np.random.seed(seed)`` + ``np.random.permutation`` calls, without
    # clobbering the process-wide NumPy RNG.
    rng = np.random.RandomState(seed)
    n_images = len(load_train_image_ids())
    val_size = int(n_images * val_fraction)
    extracted_image_indices, generated_image_indices = split_train_indices_by_source()
    # dtype=int keeps the arrays integral even when one source is empty; permuting
    # an empty list yields float64, which then turned the concatenations to floats.
    extracted_image_indices = rng.permutation(
        np.asarray(extracted_image_indices, dtype=int)
    )
    generated_image_indices = rng.permutation(
        np.asarray(generated_image_indices, dtype=int)
    )
    val_indices = extracted_image_indices[:val_size]
    n_generated_images_in_val = val_size - len(val_indices)
    val_indices = np.concatenate(
        [val_indices, generated_image_indices[:n_generated_images_in_val]]
    )
    # Extracted images that do not fit in the validation set used to be dropped from
    # both splits; they now go to train, so every image lands in exactly one split.
    train_indices = np.concatenate(
        [
            extracted_image_indices[val_size:],
            generated_image_indices[n_generated_images_in_val:],
        ]
    )
    assert len(set(train_indices) | set(val_indices)) == n_images
    assert len(val_indices) == val_size
    assert len(set(train_indices) & set(val_indices)) == 0
    return train_indices, val_indices
def to_token_str(value: str | enum.Enum) -> str:
    """Render *value* as a special token of the form ``<...>``.

    Enum members are rendered by their ``name``. A string that already starts with
    ``<`` and ends with ``>`` is returned unchanged; any other string is wrapped.

    (The previous hint ``str or enum.Enum`` evaluated to just ``str``.)
    """
    string = value.name if isinstance(value, enum.Enum) else value
    if re.fullmatch(r"<.*>", string):
        return string
    return f"<{string}>"
def get_extra_tokens() -> types.SimpleNamespace:
    """Build a namespace holding every special token used in target strings.

    Attributes: ``benetech_prompt``, ``benetech_prompt_end``, ``x_start``,
    ``y_start``, ``value_separator``, plus one attribute per ``ChartType`` and
    ``ValuesType`` member (named after the member), each mapped to its
    ``to_token_str`` rendering. A fresh namespace is built on every call.
    """
    fixed_tokens = {
        "benetech_prompt": "benetech_prompt",
        "benetech_prompt_end": "/benetech_prompt",
        "x_start": "x_start",
        "y_start": "y_start",
        "value_separator": ";",
    }
    tokens = {attr: to_token_str(raw) for attr, raw in fixed_tokens.items()}
    for member in (*ChartType, *ValuesType):
        tokens[member.name] = to_token_str(member)
    return types.SimpleNamespace(**tokens)
def convert_number_to_scientific_string(
    value: int | float, precision: int | None = None
) -> str:
    """Format *value* in scientific notation, e.g. ``1.23e+03`` for precision 2.

    Args:
        value: Number to format. (The previous hint ``int or float`` evaluated to
            just ``int``.)
        precision: Digits after the decimal point. Defaults to
            ``CONFIG.float_scientific_notation_string_precision``.

    Returns:
        The formatted string.
    """
    if precision is None:
        precision = CONFIG.float_scientific_notation_string_precision
    return f"{value:.{precision}e}"
def convert_axis_data_to_string(
    axis_data: list[str | float], values_type: ValuesType
) -> str:
    """Serialise one axis' values into a single separator-joined string.

    Numerical values are rendered with ``convert_number_to_scientific_string``;
    categorical values must already be strings and are joined as-is. The separator
    is the ``value_separator`` token from ``get_extra_tokens``.
    """
    if values_type == ValuesType.numerical:
        formatted_axis_data = [
            convert_number_to_scientific_string(value) for value in axis_data
        ]
    else:
        formatted_axis_data = axis_data
    return get_extra_tokens().value_separator.join(formatted_axis_data)
def convert_string_to_axis_data(string, values_type: ValuesType):
    """Split a separator-joined axis string back into a list of values.

    Numerical axes get each piece parsed with ``float`` after removing all spaces
    (presumably to tolerate whitespace in decoded model output — TODO confirm);
    other axes return the raw string pieces.
    """
    pieces = string.split(get_extra_tokens().value_separator)
    if values_type != ValuesType.numerical:
        return pieces
    return [float(piece.replace(" ", "")) for piece in pieces]
@dataclasses.dataclass
class BenetechOutput:
    """Chart type, axis value types and data series of one chart.

    ``to_string`` serialises an instance into a single special-token string::

        <benetech_prompt><CHART><x_start><XTYPE>X_DATA<y_start><YTYPE>Y_DATA</benetech_prompt>

    where ``X_DATA``/``Y_DATA`` are values joined by the ``<;>`` separator (numbers
    in scientific notation). ``from_string`` parses such a string back.
    """

    chart_type: ChartType
    x_values_type: ValuesType
    y_values_type: ValuesType
    x_data: list[str | float]
    y_data: list[str | float]

    def __post_init__(self):
        # Accept either enum members or their raw string values.
        self.chart_type = ChartType(self.chart_type)
        self.x_values_type = ValuesType(self.x_values_type)
        self.y_values_type = ValuesType(self.y_values_type)
        assert isinstance(self.x_data, list)
        assert isinstance(self.y_data, list)

    def get_main_characteristics(self):
        """Return ``(chart_type, x_values_type, y_values_type, len(x_data), len(y_data))``."""
        return (
            self.chart_type,
            self.x_values_type,
            self.y_values_type,
            len(self.x_data),
            len(self.y_data),
        )

    @staticmethod
    def from_annotation(annotation: Annotation):
        """Build the output describing a parsed ``Annotation``."""
        return BenetechOutput(
            chart_type=annotation.chart_type,
            x_values_type=annotation.axes.x_axis.values_type,
            y_values_type=annotation.axes.y_axis.values_type,
            x_data=[dp.x for dp in annotation.data_series],
            y_data=[dp.y for dp in annotation.data_series],
        )

    def to_string(self):
        """Serialise this instance into the target string (see the class docstring)."""
        return self.format_strings(
            chart_type=self.chart_type,
            x_values_type=self.x_values_type,
            y_values_type=self.y_values_type,
            x_data=convert_axis_data_to_string(self.x_data, self.x_values_type),
            y_data=convert_axis_data_to_string(self.y_data, self.y_values_type),
        )

    @staticmethod
    def format_strings(*, chart_type, x_values_type, y_values_type, x_data, y_data):
        """Assemble the target string from its five pieces.

        ``chart_type`` and the two values types go through ``to_token_str``;
        ``x_data``/``y_data`` are inserted verbatim, which lets
        ``get_string_pattern`` pass regex groups through this same layout.
        """
        chart_type = to_token_str(chart_type)
        x_values_type = to_token_str(x_values_type)
        y_values_type = to_token_str(y_values_type)
        token_ns = get_extra_tokens()
        return (
            f"{token_ns.benetech_prompt}{chart_type}"
            f"{token_ns.x_start}{x_values_type}{x_data}"
            f"{token_ns.y_start}{y_values_type}{y_data}"
            f"{token_ns.benetech_prompt_end}"
        )

    @staticmethod
    def get_string_pattern():
        """Return a regex matching ``to_string`` output, with one named group per field."""
        field_names = [field.name for field in dataclasses.fields(BenetechOutput)]
        pattern = BenetechOutput.format_strings(
            **{field_name: f"(?P<{field_name}>.*?)" for field_name in field_names}
        )
        return pattern

    @staticmethod
    def does_string_match_expected_pattern(string):
        """Return whether *string* can be parsed by ``from_string``."""
        try:
            BenetechOutput.from_string(string)
        except Exception:
            # Any parse/validation failure means "no match". Unlike the former bare
            # ``except:``, this lets KeyboardInterrupt and SystemExit propagate.
            return False
        return True

    @staticmethod
    def from_string(string):
        """Parse a string produced by ``to_string`` back into a ``BenetechOutput``.

        Raises:
            ValueError: If *string* does not follow the expected layout, names an
                unknown chart/value type, or contains an unparsable number.
        """
        fullmatch = re.fullmatch(BenetechOutput.get_string_pattern(), string)
        if fullmatch is None:
            # Previously this surfaced as an AttributeError on ``None.groupdict()``.
            raise ValueError(
                f"String does not match the expected BenetechOutput pattern: {string!r}"
            )
        benetech_kwargs = fullmatch.groupdict()
        # Tokens carry enum *names* (to_token_str uses ``.name``) while the lookups
        # below are by *value*; this relies on every member's name equalling its value.
        benetech_kwargs["chart_type"] = ChartType(benetech_kwargs["chart_type"])
        benetech_kwargs["x_values_type"] = ValuesType(benetech_kwargs["x_values_type"])
        benetech_kwargs["y_values_type"] = ValuesType(benetech_kwargs["y_values_type"])
        benetech_kwargs["x_data"] = convert_string_to_axis_data(
            benetech_kwargs["x_data"], benetech_kwargs["x_values_type"]
        )
        benetech_kwargs["y_data"] = convert_string_to_axis_data(
            benetech_kwargs["y_data"], benetech_kwargs["y_values_type"]
        )
        return BenetechOutput(**benetech_kwargs)
def get_annotation_ground_truth_str(annotation: Annotation) -> str:
    """Return the target string (chart type, axis types and data) for *annotation*."""
    # BenetechOutput.from_annotation builds exactly the object this function used to
    # assemble by hand; delegating keeps the two construction paths in sync.
    return BenetechOutput.from_annotation(annotation).to_string()
def get_annotation_ground_truth_str_from_image_index(image_index: int) -> str:
    """Load the *image_index*-th training annotation and return its target string."""
    annotation = Annotation.from_image_index(image_index)
    return get_annotation_ground_truth_str(annotation)
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a subset of training images.

    ``indices`` are positions in ``load_train_image_ids()``. Each item loads the
    image from disk, converts it with torchvision's ``ToTensor`` and pairs it with
    the annotation's target string.
    """

    def __init__(self, indices: list[int]):
        super().__init__()
        self.indices = indices
        self.to_tensor = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx: int) -> DataItem:
        data_index = self.indices[idx]
        sample = AnnotatedImage.from_image_index(data_index)
        return DataItem(
            image=self.to_tensor(sample.image),
            target_string=get_annotation_ground_truth_str(sample.annotation),
            data_index=data_index,
        )
def get_train_val_datasets():
    """Return ``(train_dataset, val_dataset)`` built from the train/val index split.

    The split comes from ``load_pickle_or_build_object_and_save`` with
    ``CONFIG.train_val_indices_path`` and a builder calling
    ``get_train_val_split_indices(CONFIG.val_fraction, CONFIG.seed)``; presumably
    the helper reuses the pickle when it exists — confirm in ``utils``.
    """

    def build_split():
        return get_train_val_split_indices(CONFIG.val_fraction, CONFIG.seed)

    train_indices, val_indices = load_pickle_or_build_object_and_save(
        CONFIG.train_val_indices_path, build_split
    )
    return Dataset(train_indices), Dataset(val_indices)
def get_train_dataset():
    """Return the training dataset (both datasets are built; validation is discarded)."""
    train_dataset, _ = get_train_val_datasets()
    return train_dataset
def get_val_dataset():
    """Return the validation dataset (both datasets are built; training is discarded)."""
    _, val_dataset = get_train_val_datasets()
    return val_dataset
@dataclasses.dataclass
class Batch:
    """A collated mini-batch of images, label tensors and their training indices.

    When ``CONFIG.debug`` is set, ``__post_init__`` checks that images are 4-D
    ``(batch, channel, height, width)``, labels are 2-D ``(batch, label)``, and that
    both agree with ``len(data_indices)`` on the batch size.
    """

    images: torch.FloatTensor
    labels: torch.IntTensor
    data_indices: list[int]

    def __post_init__(self):
        if not CONFIG.debug:
            return
        image_batch = einops.parse_shape(self.images, "batch channel height width")["batch"]
        label_batch = einops.parse_shape(self.labels, "batch label")["batch"]
        assert image_batch == label_batch
        assert len(self.data_indices) == image_batch
class Split(enum.Enum):
    """Dataset partition served by a dataloader (see ``build_dataloader``)."""

    train = "train"
    val = "val"
# Collate callback: receives the list of DataItems plus the Split and returns a Batch.
# build_dataloader binds ``split`` by keyword via functools.partial.
BatchCollateFunction = Callable[[list[DataItem], Split], Batch]


def build_dataloader(split: Split, batch_collate_function: BatchCollateFunction):
    """Create a ``DataLoader`` for *split*.

    The training split is shuffled and the validation split is not. Batch size and
    worker count come from ``CONFIG``, and *batch_collate_function* is called with
    ``split=split`` bound.
    """
    is_train = split == Split.train
    dataset = get_train_dataset() if is_train else get_val_dataset()
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CONFIG.batch_size,
        shuffle=is_train,
        num_workers=CONFIG.num_workers,
        collate_fn=functools.partial(batch_collate_function, split=split),
    )