Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import csv | |
| import datetime | |
| import io | |
| import json | |
| import os | |
| import uuid | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any, List | |
| import fo_utils as fou | |
| import gradio as gr | |
| from gradio import utils | |
| if TYPE_CHECKING: | |
| from gradio.components import IOComponent | |
def _get_dataset_features_info(is_new, components):
    """
    Takes in a list of components and returns a dataset features info.

    Parameters:
        is_new: boolean, whether the dataset is new or not. Headers and
            features are only generated when this is True; otherwise both
            are returned empty.
        components: list of components; each one contributes a column named
            after its ``label``.
    Returns:
        infos: a dictionary of the form {"flagged": {"features": {...}}}
            describing the dataset features.
        file_preview_types: dictionary mapping of gradio component classes
            (gr.Audio, gr.Image) to the appropriate preview type string.
        headers: list of header strings; ends with "flag" when is_new.
    """
    infos = {"flagged": {"features": {}}}
    features = infos["flagged"]["features"]
    # File previews for certain input and output types
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
    headers = []

    # Generate the headers and dataset_infos
    if is_new:
        for component in components:
            headers.append(component.label)
            features[component.label] = {
                "dtype": "string",
                "_type": "Value",
            }
            # First matching preview type (dict insertion order), if any.
            preview_type = next(
                (
                    _type
                    for _component, _type in file_preview_types.items()
                    if isinstance(component, _component)
                ),
                None,
            )
            if preview_type is not None:
                # `label` may be None: build the extra column name None-safely
                # so the header and the feature key agree instead of the
                # header concatenation raising TypeError.
                file_column = (component.label or "") + " file"
                headers.append(file_column)
                features[file_column] = {"_type": preview_type}

        headers.append("flag")
        features["flag"] = {
            "dtype": "string",
            "_type": "Value",
        }

    return infos, file_preview_types, headers
class FlaggingCallback(ABC):
    """
    An abstract class for defining the methods that any FlaggingCallback should have.

    Both `setup()` and `flag()` are declared with `@abstractmethod`, so a
    subclass that does not override both cannot be instantiated.
    """

    @abstractmethod
    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """
        This method should be overridden and ensure that everything is set up correctly for flag().
        This method gets called once at the beginning of the Interface.launch() method.
        Parameters:
            components: Set of components that will provide flagged data.
            flagging_dir: A string or Path, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """
        pass

    @abstractmethod
    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
        This gets called every time the <flag> button is pressed.
        Parameters:
            flag_data: The data to be flagged.
            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
            flag_index (optional): The index of the sample that is being flagged.
            username (optional): The username of the user that is flagging the data, if logged in.
        Returns:
            (int) The total number of samples that have been flagged.
        """
        pass
class SimpleCSVLogger(FlaggingCallback):
    """
    A simplified implementation of the FlaggingCallback abstract class
    provided for illustrative purposes. Each flagged sample (both the input and output data)
    is logged to a CSV file on the machine running the gradio app.
    No header row is written, so every row of log.csv is one flagged sample.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=SimpleCSVLogger())
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """
        Store the components and ensure `flagging_dir` exists.
        Parameters:
            components: components whose values will be flagged, in the same
                order as the `flag_data` later passed to flag().
            flagging_dir: directory that will hold log.csv (created if missing).
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Append one row for `flag_data` to `<flagging_dir>/log.csv`.
        `flag_option`, `flag_index` and `username` are accepted for interface
        compatibility but are not recorded by this logger.
        Returns:
            (int) The number of rows in log.csv after appending, i.e. the total
            number of samples that have been flagged so far.
        """
        save_dir = Path(self.flagging_dir)
        log_filepath = save_dir / "log.csv"

        # Each component turns its sample into a value for the CSV row.
        # save_dir is handed to deserialize (presumably where file-like values
        # are written — confirm against gradio); files go directly into
        # flagging_dir, not per-component subdirectories.
        csv_data = [
            component.deserialize(sample, save_dir, None)
            for component, sample in zip(self.components, flag_data)
        ]

        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            csv.writer(csvfile).writerow(utils.sanitize_list_for_csv(csv_data))

        # Count CSV rows, not physical lines (quoted fields may contain
        # newlines). There is no header row, so every row is a flagged sample.
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile))

        # TODO: when flag_option == "Bad", upload the flagged image to CVAT via
        # fou.upload_image_to_cvat(...) (previously sketched as commented-out code).
        return line_count
class FlagMethod:
    """
    Callable that binds a flagging callback to a single flag button option.

    Calling an instance forwards the positional arguments, as a list, to the
    callback's flag() together with the bound option. The callback's return
    value is discarded, so the call itself returns None.
    """

    def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
        self.flag_option = flag_option
        self.flagging_callback = flagging_callback
        # Gives the callable a function-like name; presumably read by gradio
        # where a function's __name__ is expected — TODO confirm.
        self.__name__ = "Flag"

    def __call__(self, *flag_data):
        samples = [*flag_data]
        callback = self.flagging_callback
        callback.flag(samples, flag_option=self.flag_option)