# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import base64
import itertools
import json
import os
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Any, List, Union


# from https://docs.python.org/3/howto/descriptor.html#validator-class
# For usage of hidden flag see the ModelParams class in apis/utils/model_params.py
class Validator(ABC):
    """Descriptor base class that validates a value before storing it.

    Subclasses implement ``validate`` and are expected to set ``self.default``
    in their ``__init__``; ``__get__`` falls back to it when nothing was set.
    """

    # __set_name__ is called when the validator is created as a class variable.
    # name is the name of the variable in the owner class, so here we create
    # the name for the backing variable.
    def __set_name__(self, owner, name):
        self.private_name = "_" + name

    def __get__(self, obj, objtype=None):
        """Return the stored value, or ``self.default`` if none was stored."""
        return getattr(obj, self.private_name, self.default)

    def __set__(self, obj, value):
        """Validate ``value`` and store the result on ``obj``."""
        setattr(obj, self.private_name, self.validate(value))

    @abstractmethod
    def validate(self, value):
        """Return the (possibly converted) value or raise if it is invalid."""

    def json(self):
        """Return a dict describing the validator; the base returns ``None``."""
        return None


class MultipleOf(Validator):
    """Accept values that are an exact multiple of ``multiple_of``."""

    def __init__(self, default: int, multiple_of: int, type_cast=None, hidden=False, tooltip=None):
        if type(multiple_of) is not int:
            raise ValueError(f"Expected {multiple_of!r} to be an int")
        self.multiple_of = multiple_of
        self.default = default
        # Optional callable applied to the value before the multiple check
        self.type_cast = type_cast
        # For usage of hidden flag see the ModelParams class in apis/utils/model_params.py
        # if a parameter is hidden then probe() can't expose the param
        # and the param can't be set anymore
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Cast ``value`` if a ``type_cast`` is set, then check divisibility.

        Raises:
            ValueError: if the cast raises ValueError or the value is not a
                multiple of ``self.multiple_of``.
        """
        if self.type_cast:
            try:
                value = self.type_cast(value)
            except ValueError as err:
                raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") from err
        if value % self.multiple_of != 0:
            raise ValueError(f"Expected {value!r} to be a multiple of {self.multiple_of!r}")
        return value

    def get_range_iterator(self):
        """Return an unbounded iterator over 0, multiple_of, 2 * multiple_of, ..."""
        return itertools.count(0, self.multiple_of)

    def __repr__(self) -> str:
        return f"MultipleOf({self.private_name=} {self.multiple_of=} {self.hidden=})"

    def json(self):
        return {
            "type": MultipleOf.__name__,
            "default": self.default,
            "multiple_of": self.multiple_of,
            "tooltip": self.tooltip,
        }


class OneOf(Validator):
    """Accept values that are members of a fixed set of options."""

    def __init__(self, default, options, type_cast=None, hidden=False, tooltip=None):
        self.options = set(options)
        self.default = default
        self.type_cast = type_cast  # Cast the value to this type before checking if it's in options
        self.tooltip = tooltip
        self.hidden = hidden

    def validate(self, value):
        """Cast ``value`` if a ``type_cast`` is set, then check membership.

        Raises:
            ValueError: if the cast raises ValueError or the value is not one
                of ``self.options``.
        """
        if self.type_cast:
            try:
                value = self.type_cast(value)
            except ValueError as err:
                raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") from err
        if value not in self.options:
            raise ValueError(f"Expected {value!r} to be one of {self.options!r}")
        return value

    def get_range_iterator(self):
        """Return the set of allowed options."""
        return self.options

    def __repr__(self) -> str:
        return f"OneOf({self.private_name=} {self.options=} {self.hidden=})"

    def json(self):
        return {
            "type": OneOf.__name__,
            "default": self.default,
            "values": list(self.options),
            "tooltip": self.tooltip,
        }


class HumanAttributes(Validator):
    """Accept "none", "random", or an attribute string.

    An attribute string is one label from each of emotion, race, gender and
    age group, in that order, each followed by a single separator character.
    """

    def __init__(self, default, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    # hard code the options for now
    # we extend this to init parameter as needed
    valid_attributes = {
        "emotion": ["angry", "contemptful", "disgusted", "fearful", "happy", "neutral", "sad", "surprised"],
        "race": ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"],
        "gender": ["male", "female"],
        "age group": [
            "young",
            "teen",
            "adult early twenties",
            "adult late twenties",
            "adult early thirties",
            "adult late thirties",
            "adult middle aged",
            "older adult",
        ],
    }

    def get_range_iterator(self):
        """Return an iterator over every (emotion, race, gender, age group) tuple."""
        return itertools.product(
            self.valid_attributes["emotion"],
            self.valid_attributes["race"],
            self.valid_attributes["gender"],
            self.valid_attributes["age group"],
        )

    def validate(self, value):
        """Check ``value`` case-insensitively and return it unchanged.

        Raises:
            ValueError: if a category has no label matching the start of the
                remaining attribute string.
        """
        human_attributes = value.lower()
        if human_attributes not in ["none", "random"]:
            # Custom attribute string: consume one label per category, in order
            attr_string = human_attributes
            for attr_key in ["emotion", "race", "gender", "age group"]:
                for attr_label in self.valid_attributes[attr_key]:
                    if attr_string.startswith(attr_label):
                        # The +1 also drops the one character following the label
                        attr_string = attr_string[len(attr_label) + 1 :]  # noqa: E203
                        break
                else:
                    raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}")
        return value

    def __repr__(self) -> str:
        return f"HumanAttributes({self.private_name=} {self.hidden=})"

    def json(self):
        return {
            "type": HumanAttributes.__name__,
            "default": self.default,
            "values": self.valid_attributes,
            "tooltip": self.tooltip,
        }


class Bool(Validator):
    """Accept bools, ints (non-zero is True) and the strings true/false/1/0."""

    def __init__(self, default, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Convert ``value`` to a bool.

        Raises:
            ValueError: if a string is not one of true/false/1/0 (any case).
            TypeError: if the value is neither an int, a bool nor a str.
        """
        # bool is a subclass of int, so real booleans take this branch too
        if isinstance(value, int):
            value = value != 0
        elif isinstance(value, str):
            value = value.lower()
            if value in ["true", "1"]:
                value = True
            elif value in ["false", "0"]:
                value = False
            else:
                raise ValueError(f"Expected {value!r} to be one of ['True', 'False', '1', '0']")
        elif not isinstance(value, bool):
            raise TypeError(f"Expected {value!r} to be an bool")
        return value

    def get_range_iterator(self):
        return [True, False]

    def __repr__(self) -> str:
        return f"Bool({self.private_name=} {self.default=} {self.hidden=})"

    def json(self):
        return {
            "type": bool.__name__,
            "default": self.default,
            "tooltip": self.tooltip,
        }


class Int(Validator):
    """Accept ints (or strings parsed with ``int``) within optional bounds."""

    def __init__(self, default, min=None, max=None, step=1, hidden=False, tooltip=None):
        self.min = min
        self.max = max
        self.default = default
        self.step = step
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` as an int after checking the bounds.

        Raises:
            TypeError: if the value is neither an int nor a str.
            ValueError: if a string cannot be parsed or a bound is violated.
        """
        if isinstance(value, str):
            value = int(value)
        elif not isinstance(value, int):
            raise TypeError(f"Expected {value!r} to be an int")
        if self.min is not None and value < self.min:
            raise ValueError(f"Expected {value!r} to be at least {self.min!r}")
        if self.max is not None and value > self.max:
            raise ValueError(f"Expected {value!r} to be no more than {self.max!r}")
        return value

    def get_range_iterator(self):
        """Iterate from min to max (inclusive) in ``step`` increments.

        A missing bound is replaced by the default value.
        """
        iter_min = self.min if self.min is not None else self.default
        iter_max = self.max if self.max is not None else self.default
        return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step))

    def __repr__(self) -> str:
        return f"Int({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        return {
            "type": int.__name__,
            "default": self.default,
            "min": self.min,
            "max": self.max,
            "step": self.step,
            "tooltip": self.tooltip,
        }


class Float(Validator):
    """Accept floats (or ints/strings converted with ``float``) within optional bounds."""

    def __init__(self, default=0.0, min=None, max=None, step=0.5, hidden=False, tooltip=None):
        self.min = min
        self.max = max
        self.default = default
        self.step = step
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` as a float after checking the bounds.

        Raises:
            TypeError: if the value is not a float, int or str.
            ValueError: if a string cannot be parsed or a bound is violated.
        """
        if isinstance(value, (str, int)):
            value = float(value)
        elif not isinstance(value, float):
            raise TypeError(f"Expected {value!r} to be float")
        if self.min is not None and value < self.min:
            raise ValueError(f"Expected {value!r} to be at least {self.min!r}")
        if self.max is not None and value > self.max:
            raise ValueError(f"Expected {value!r} to be no more than {self.max!r}")
        return value

    def get_range_iterator(self):
        """Iterate from min to max (inclusive) in ``step`` increments.

        A missing bound is replaced by the default value.
        """
        iter_min = self.min if self.min is not None else self.default
        iter_max = self.max if self.max is not None else self.default
        return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step))

    def __repr__(self) -> str:
        return f"Float({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        return {
            "type": float.__name__,
            "default": self.default,
            "min": self.min,
            "max": self.max,
            "step": self.step,
            "tooltip": self.tooltip,
        }


class String(Validator):
    """Accept strings with optional length bounds and an optional predicate."""

    def __init__(self, default="", min=None, max=None, predicate=None, hidden=False, tooltip=None):
        self.min = min
        self.max = max
        self.predicate = predicate
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` unchanged if it passes all checks.

        Raises:
            TypeError: if the value is not a str.
            ValueError: if the length is out of bounds or the predicate
                returns a falsy result.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        if self.min is not None and len(value) < self.min:
            raise ValueError(f"Expected {value!r} to be no smaller than {self.min!r}")
        if self.max is not None and len(value) > self.max:
            raise ValueError(f"Expected {value!r} to be no bigger than {self.max!r}")
        if self.predicate is not None and not self.predicate(value):
            raise ValueError(f"Expected {self.predicate} to be true for {value!r}")
        return value

    def get_range_iterator(self):
        return iter([self.default])

    def __repr__(self) -> str:
        return f"String({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        return {
            "type": str.__name__,
            "default": self.default,
            "tooltip": self.tooltip,
        }


class Path(Validator):
    """Accept strings naming a path that exists on the filesystem."""

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` unchanged if it is an existing path.

        Raises:
            TypeError: if the value is not a str.
            ValueError: if ``os.path.exists`` is false for the value.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        if not os.path.exists(value):
            raise ValueError(f"Expected {value!r} to be a valid path")
        return value

    def get_range_iterator(self):
        return iter([self.default])

    def __repr__(self) -> str:
        # Report the real class name (this used to be copy-pasted as "String")
        return f"Path({self.private_name=} {self.default=}, {self.hidden=})"


class InputImage(Validator):
    """Accept paths to existing files with a supported image extension."""

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    valid_formats = {
        "JPEG": ["jpeg", "jpg"],
        "JPEG2000": ["jp2"],
        "PNG": ["png"],
        "GIF": ["gif"],
        "BMP": ["bmp"],
    }
    # Reverse lookup: lower-case extension without the dot -> format name
    valid_extensions = {vi: k for k, v in valid_formats.items() for vi in v}

    def validate(self, value):
        """Return ``value`` unchanged if it is an existing image path.

        The extension is compared case-insensitively against
        ``valid_extensions``.

        Raises:
            TypeError: if the value is not a str.
            ValueError: if the extension is not supported or the path does
                not exist.
        """
        # Check the type before any string/path operation touches the value
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        # splitext returns (root, ".ext"); drop the dot to match valid_extensions keys
        ext = os.path.splitext(value)[1][1:].lower()
        if ext not in InputImage.valid_extensions:
            raise ValueError(
                f"Expected {value!r} to have one of the extensions {sorted(InputImage.valid_extensions)!r}"
            )
        if not os.path.exists(value):
            raise ValueError(f"Expected {value!r} to be a valid path")
        return value

    def get_range_iterator(self):
        return iter([self.default])

    def __repr__(self) -> str:
        # Report the real class name (this used to be copy-pasted as "String")
        return f"InputImage({self.private_name=} {self.default=} {self.hidden=})"

    def json(self):
        return {
            "type": InputImage.__name__,
            "default": self.default,
            "values": self.valid_formats,
            "tooltip": self.tooltip,
        }


class MeshFormat(Validator):
    """
    Validator class for mesh formats.
    Valid inputs are either:
    - single valid format such as "glb", "obj"
    - or a list of valid formats such as "[obj, ply, usdz]"
      (items may also be quoted, e.g. '["obj", "ply"]')
    """

    valid_formats = {"glb", "usdz", "obj", "ply"}

    def __init__(self, default="glb", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    @staticmethod
    def _parse_format_list(value: str) -> list:
        """Parse a bracketed list string into a list of items."""
        try:
            parsed = ast.literal_eval(value)
        except (SyntaxError, ValueError, TypeError):
            parsed = None
        if isinstance(parsed, list):
            return parsed
        # literal_eval rejects bare names such as [obj, ply]; split them by hand
        return [item.strip().strip("'\"") for item in value[1:-1].split(",")]

    def validate(self, value: str) -> Union[str, List[str]]:
        """Return a single format string or a list of format strings.

        Raises:
            TypeError: if the value is not a str.
            ValueError: if the value, or any item of the list, is not one of
                ``valid_formats``.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        try:
            if value.startswith("[") and value.endswith("]"):
                formats = self._parse_format_list(value)
                # isinstance guard keeps unhashable items from raising TypeError
                if not all(isinstance(fmt, str) and fmt in MeshFormat.valid_formats for fmt in formats):
                    raise ValueError(f"Each item must be one of {MeshFormat.valid_formats}")
                return formats
            elif value in MeshFormat.valid_formats:
                return value
            else:
                raise ValueError(f"Expected {value!r} to be one of {MeshFormat.valid_formats} or a list of them")
        except (SyntaxError, ValueError) as e:
            # Handle case where the input is neither a valid single format nor a list of valid formats
            raise ValueError(f"Invalid format specification: {value}. Error: {str(e)}") from e

    def __repr__(self) -> str:
        return f"MeshFormat(default={self.default}, hidden={self.hidden})"

    def json(self):
        return {
            "type": MeshFormat.__name__,
            "default": self.default,
            # A list (not a set) so the result is JSON-serializable; sorted for a stable order
            "values": sorted(self.valid_formats),
            "tooltip": self.tooltip,
        }


class JsonDict(Validator):
    """
    JSON stringified version of a python dict.
    Example: '{"ema_customization_iter.pt": "ema_customization_iter.pt"}'
    """

    def __init__(self, default="", hidden=False):
        self.default = default
        self.hidden = hidden

    def validate(self, value):
        """Parse ``value`` with ``json.loads``; a falsy value yields ``{}``.

        Raises:
            ValueError: if the value is not valid JSON.
        """
        if not value:
            return {}
        try:
            return json.loads(value)
        except json.JSONDecodeError as e:
            raise ValueError(f"Expected {value!r} to be json stringified dict. Error: {str(e)}") from e

    def __repr__(self) -> str:
        return f"Dict({self.default=} {self.hidden=})"


class BytesIOType(Validator):
    """
    Validator class for BytesIO.
    Valid inputs are either:
    - bytes
    - objects of class BytesIO
    - str which can be successfully decoded into BytesIO
    """

    def __init__(self, default=None, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value: Any) -> BytesIO:
        """Return ``value`` as a BytesIO stream.

        Raises:
            ValueError: if a str cannot be Base64-decoded.
            TypeError: if the value is not a str, bytes or BytesIO.
        """
        if isinstance(value, str):
            try:
                # Decode the Base64 string and wrap the bytes in a stream
                return BytesIO(base64.b64decode(value))
            except ValueError as e:  # binascii.Error is a ValueError subclass
                raise ValueError(f"Invalid Base64 encoded string: {e}") from e
        elif isinstance(value, bytes):
            return BytesIO(value)
        elif isinstance(value, BytesIO):
            return value
        else:
            raise TypeError(f"Expected {value!r} to be a Base64 encoded string, bytes, or BytesIO")

    def __repr__(self) -> str:
        return f"BytesIOValidator({self.default=}, {self.hidden=})"

    def json(self):
        return {
            "type": BytesIO.__name__,
            "default": self.default,
            "tooltip": self.tooltip,
        }