Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast | |
import base64 | |
import itertools | |
import json | |
import os | |
from abc import ABC, abstractmethod | |
from io import BytesIO | |
from typing import Any, List, Union | |
# Adapted from https://docs.python.org/3/howto/descriptor.html#validator-class
# For usage of the hidden flag, see the ModelParams class in apis/utils/model_params.py
class Validator(ABC):
    """Descriptor base class that validates a value every time it is assigned.

    A subclass is placed as a class variable on an owner class. The value is
    stored on the owner *instance* under a backing attribute whose name is the
    class-variable name prefixed with an underscore. Reading the attribute
    before anything was assigned yields ``self.default``, which subclasses are
    expected to set in their ``__init__``.
    """

    def __set_name__(self, owner, name):
        # Python calls this when the validator is bound as a class variable;
        # ``name`` is the variable's name in the owner class, from which the
        # name of the backing attribute is derived.
        self.private_name = f"_{name}"

    def __get__(self, obj, objtype=None):
        """Return the stored value, or ``self.default`` when none is stored."""
        backing_attr = self.private_name
        fallback = self.default
        return getattr(obj, backing_attr, fallback)

    def __set__(self, obj, value):
        """Validate ``value`` and store whatever ``validate`` returns."""
        checked = self.validate(value)
        setattr(obj, self.private_name, checked)

    def validate(self, value):
        """Hook for subclasses; this base implementation returns ``None``."""
        return None

    def json(self):
        """Hook for subclasses; this base implementation returns ``None``."""
        return None
class MultipleOf(Validator):
    """Validator accepting only values that are a multiple of ``multiple_of``.

    Args:
        default: Value reported while nothing has been assigned.
        multiple_of: Divisor every accepted value must be divisible by; its
            type must be exactly ``int``.
        type_cast: Optional callable applied to the incoming value before the
            divisibility check.
        hidden: For usage of the hidden flag see the ModelParams class in
            apis/utils/model_params.py. If a parameter is hidden then probe()
            can't expose the param and the param can't be set anymore.
        tooltip: Optional help text reported by ``json()``.

    Raises:
        ValueError: If ``multiple_of`` is not an ``int``.
    """

    def __init__(self, default: int, multiple_of: int, type_cast=None, hidden=False, tooltip=None):
        if type(multiple_of) is not int:
            raise ValueError(f"Expected {multiple_of!r} to be an int")
        self.default = default
        self.multiple_of = multiple_of
        self.type_cast = type_cast
        self.tooltip = tooltip
        self.hidden = hidden

    def validate(self, value):
        """Return the (optionally cast) value if it divides evenly.

        Raises:
            ValueError: If the cast raises ``ValueError`` or the value is not a
                multiple of ``multiple_of``.
        """
        caster = self.type_cast
        if caster:
            try:
                value = caster(value)
            except ValueError:
                # ``value`` is still the un-cast input here.
                raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}")
        remainder = value % self.multiple_of
        if remainder != 0:
            raise ValueError(f"Expected {value!r} to be a multiple of {self.multiple_of!r}")
        return value

    def get_range_iterator(self):
        """Return an endless iterator 0, multiple_of, 2 * multiple_of, ..."""
        return itertools.count(start=0, step=self.multiple_of)

    def __repr__(self) -> str:
        return f"MultipleOf({self.private_name=} {self.multiple_of=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        description = {"type": MultipleOf.__name__}
        description["default"] = self.default
        description["multiple_of"] = self.multiple_of
        description["tooltip"] = self.tooltip
        return description
class OneOf(Validator):
    """Validator accepting only values from a fixed collection of options.

    Args:
        default: Value reported while nothing has been assigned.
        options: Iterable of hashable allowed values.
        type_cast: Optional callable applied to the incoming value before the
            membership check.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    def __init__(self, default, options, type_cast=None, hidden=False, tooltip=None):
        # Materialize once (``options`` may be a one-shot iterable) and drop
        # duplicates while keeping declaration order, so json() can report the
        # options in a stable order instead of arbitrary set order.
        ordered = list(dict.fromkeys(options))
        self._ordered_options = ordered
        self.options = set(ordered)
        self.default = default
        self.type_cast = type_cast  # Cast the value to this type before checking if it's in options
        self.tooltip = tooltip
        self.hidden = hidden

    def validate(self, value):
        """Return the (optionally cast) value if it is one of the options.

        Raises:
            ValueError: If the cast raises ``ValueError`` or the value is not
                among the options.
        """
        if self.type_cast:
            try:
                value = self.type_cast(value)
            except ValueError as exc:
                raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") from exc
        if value not in self.options:
            raise ValueError(f"Expected {value!r} to be one of {self.options!r}")
        return value

    def get_range_iterator(self):
        """Return the set of options (an iterable, not an iterator)."""
        return self.options

    def __repr__(self) -> str:
        return f"OneOf({self.private_name=} {self.options=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict with options in declaration order."""
        # Stay truthful if ``self.options`` was modified after construction:
        # keep the declared order for options still present, then append any
        # that were added later.
        current = set(self.options)
        values = [opt for opt in self._ordered_options if opt in current]
        values.extend(current.difference(values))
        return {
            "type": OneOf.__name__,
            "default": self.default,
            "values": values,
            "tooltip": self.tooltip,
        }
class HumanAttributes(Validator):
    """Validator for a human-attribute description string.

    Accepted values (case-insensitive) are ``"none"``, ``"random"``, or a
    custom string that names, in this order, an emotion, a race, a gender and
    an age group taken from ``valid_attributes``.

    Args:
        default: Value reported while nothing has been assigned.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    # hard code the options for now
    # we extend this to init parameter as needed
    valid_attributes = {
        "emotion": ["angry", "contemptful", "disgusted", "fearful", "happy", "neutral", "sad", "surprised"],
        "race": ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"],
        "gender": ["male", "female"],
        "age group": [
            "young",
            "teen",
            "adult early twenties",
            "adult late twenties",
            "adult early thirties",
            "adult late thirties",
            "adult middle aged",
            "older adult",
        ],
    }

    def __init__(self, default, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def get_range_iterator(self):
        """Return an iterator over every (emotion, race, gender, age group) tuple."""
        attrs = self.valid_attributes
        # itertools.product is already a lazy iterator; no intermediate list needed.
        return itertools.product(attrs["emotion"], attrs["race"], attrs["gender"], attrs["age group"])

    def validate(self, value):
        """Return ``value`` unchanged if it is an acceptable attribute string.

        Raises:
            TypeError: If ``value`` is not a ``str``.
            ValueError: If a custom string does not start with a known label
                for each attribute category in turn.
        """
        if not isinstance(value, str):
            # Previously this surfaced as an AttributeError from ``.lower()``.
            raise TypeError(f"Expected {value!r} to be an str")
        remaining = value.lower()
        if remaining in ("none", "random"):
            return value
        # Custom attribute string: consume one label per category, in order.
        for attr_key in ("emotion", "race", "gender", "age group"):
            label = next(
                (candidate for candidate in self.valid_attributes[attr_key] if remaining.startswith(candidate)),
                None,
            )
            if label is None:
                raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}")
            # Skip the label plus the single character that follows it. That
            # separator is not inspected, and any text left after the age
            # group is ignored.
            remaining = remaining[len(label) + 1 :]  # noqa: E203
        return value

    def __repr__(self) -> str:
        return f"HumanAttributes({self.private_name=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        return {
            "type": HumanAttributes.__name__,
            "default": self.default,
            "values": self.valid_attributes,
            "tooltip": self.tooltip,
        }
class Bool(Validator):
    """Validator that coerces ints and a few string spellings to ``bool``.

    Args:
        default: Value reported while nothing has been assigned.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    def __init__(self, default, hidden=False, tooltip=None):
        self.tooltip = tooltip
        self.hidden = hidden
        self.default = default

    def validate(self, value):
        """Convert ``value`` to a bool.

        Ints (which includes bools) map to ``value != 0``; strings are compared
        case-insensitively against "true"/"1" and "false"/"0".

        Raises:
            ValueError: For a string that is none of the accepted spellings.
            TypeError: For any value that is neither an int nor a str.
        """
        if isinstance(value, int):
            return value != 0
        if isinstance(value, str):
            lowered = value.lower()
            if lowered in ("true", "1"):
                return True
            if lowered in ("false", "0"):
                return False
            # The message reports the lower-cased text, not the raw input.
            raise ValueError(f"Expected {lowered!r} to be one of ['True', 'False', '1', '0']")
        raise TypeError(f"Expected {value!r} to be an bool")

    def get_range_iterator(self):
        """Return both boolean values as a new list."""
        return list((True, False))

    def __repr__(self) -> str:
        return f"Bool({self.private_name=} {self.default=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        return dict(type=bool.__name__, default=self.default, tooltip=self.tooltip)
class Int(Validator):
    """Validator for integers with optional inclusive bounds.

    Args:
        default: Value reported while nothing has been assigned.
        min: Smallest accepted value, or ``None`` for no lower bound.
        max: Largest accepted value, or ``None`` for no upper bound.
        step: Increment used by ``get_range_iterator`` and reported by ``json()``.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    def __init__(self, default, min=None, max=None, step=1, hidden=False, tooltip=None):
        self.default = default
        self.min, self.max = min, max
        self.step = step
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` as an int within the configured bounds.

        Strings are converted with ``int()``.

        Raises:
            TypeError: If ``value`` is neither a str nor an int.
            ValueError: If a string cannot be converted or the value lies
                outside the bounds.
        """
        if isinstance(value, str):
            value = int(value)
        elif not isinstance(value, int):
            raise TypeError(f"Expected {value!r} to be an int")
        lower, upper = self.min, self.max
        if lower is not None and value < lower:
            raise ValueError(f"Expected {value!r} to be at least {lower!r}")
        if upper is not None and value > upper:
            raise ValueError(f"Expected {value!r} to be no more than {upper!r}")
        return value

    def get_range_iterator(self):
        """Iterate from the lower to the upper bound (inclusive) in ``step`` increments.

        A missing bound is replaced by ``default``.
        """
        start = self.default if self.min is None else self.min
        stop = self.default if self.max is None else self.max

        def within_upper_bound(candidate):
            return candidate <= stop

        return itertools.takewhile(within_upper_bound, itertools.count(start, self.step))

    def __repr__(self) -> str:
        return f"Int({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        description = {"type": int.__name__, "default": self.default}
        description.update(min=self.min, max=self.max, step=self.step, tooltip=self.tooltip)
        return description
class Float(Validator):
    """Validator for floats with optional inclusive bounds.

    Args:
        default: Value reported while nothing has been assigned.
        min: Smallest accepted value, or ``None`` for no lower bound.
        max: Largest accepted value, or ``None`` for no upper bound.
        step: Increment used by ``get_range_iterator`` and reported by ``json()``.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    def __init__(self, default=0.0, min=None, max=None, step=0.5, hidden=False, tooltip=None):
        self.default = default
        self.min, self.max = min, max
        self.step = step
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` as a float within the configured bounds.

        Strings and ints are converted with ``float()``.

        Raises:
            TypeError: If ``value`` is not a str, int or float.
            ValueError: If a string cannot be converted or the value lies
                outside the bounds.
        """
        if isinstance(value, (str, int)):
            value = float(value)
        elif not isinstance(value, float):
            raise TypeError(f"Expected {value!r} to be float")
        lower, upper = self.min, self.max
        if lower is not None and value < lower:
            raise ValueError(f"Expected {value!r} to be at least {lower!r}")
        if upper is not None and value > upper:
            raise ValueError(f"Expected {value!r} to be no more than {upper!r}")
        return value

    def get_range_iterator(self):
        """Iterate from the lower to the upper bound (inclusive) in ``step`` increments.

        A missing bound is replaced by ``default``.
        """
        start = self.default if self.min is None else self.min
        stop = self.default if self.max is None else self.max

        def within_upper_bound(candidate):
            return candidate <= stop

        return itertools.takewhile(within_upper_bound, itertools.count(start, self.step))

    def __repr__(self) -> str:
        return f"Float({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        description = {"type": float.__name__, "default": self.default}
        description.update(min=self.min, max=self.max, step=self.step, tooltip=self.tooltip)
        return description
class String(Validator):
    """Validator for strings with optional length limits and a custom predicate.

    Args:
        default: Value reported while nothing has been assigned.
        min: Minimum accepted length, or ``None`` for no minimum.
        max: Maximum accepted length, or ``None`` for no maximum.
        predicate: Optional callable; the string is rejected when it returns a
            falsy result.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    def __init__(self, default="", min=None, max=None, predicate=None, hidden=False, tooltip=None):
        self.default = default
        self.min, self.max = min, max
        self.predicate = predicate
        self.tooltip = tooltip
        self.hidden = hidden

    def validate(self, value):
        """Return ``value`` unchanged if it passes every configured check.

        Raises:
            TypeError: If ``value`` is not a str.
            ValueError: If the length is out of range or the predicate fails.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        length = len(value)
        too_short = self.min is not None and length < self.min
        if too_short:
            raise ValueError(f"Expected {value!r} to be no smaller than {self.min!r}")
        too_long = self.max is not None and length > self.max
        if too_long:
            raise ValueError(f"Expected {value!r} to be no bigger than {self.max!r}")
        check = self.predicate
        if check is not None and not check(value):
            raise ValueError(f"Expected {self.predicate} to be true for {value!r}")
        return value

    def get_range_iterator(self):
        """Return an iterator yielding only the default value."""
        return iter((self.default,))

    def __repr__(self) -> str:
        return f"String({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict (length limits are not included)."""
        return dict(type=str.__name__, default=self.default, tooltip=self.tooltip)
class Path(Validator):
    """Validator for a string naming an existing filesystem path.

    Args:
        default: Value reported while nothing has been assigned.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text, stored on the validator.

    NOTE(review): unlike most sibling validators this class does not override
    ``json()``, so the stored tooltip is not reported anywhere in this class.
    Confirm whether that is intentional.
    """

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` unchanged if it is a str and the path exists.

        Raises:
            TypeError: If ``value`` is not a str.
            ValueError: If ``os.path.exists(value)`` is false.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        if not os.path.exists(value):
            raise ValueError(f"Expected {value!r} to be a valid path")
        return value

    def get_range_iterator(self):
        """Return an iterator yielding only the default value."""
        return iter([self.default])

    def __repr__(self) -> str:
        # Previously labelled "String(...)", a copy-paste slip from the String validator.
        return f"Path({self.private_name=} {self.default=}, {self.hidden=})"
class InputImage(Validator):
    """Validator for a path to an existing image file with a supported extension.

    Args:
        default: Value reported while nothing has been assigned.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text reported by ``json()``.
    """

    # Image format name -> file extensions (without the leading dot).
    valid_formats = {
        "JPEG": ["jpeg", "jpg"],
        "JPEG2000": ["jp2"],
        "PNG": ["png"],
        "GIF": ["gif"],
        "BMP": ["bmp"],
    }
    # Reverse lookup: extension -> image format name.
    valid_extensions = {ext: fmt for fmt, exts in valid_formats.items() for ext in exts}

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return ``value`` unchanged if it names an existing, supported image file.

        Raises:
            TypeError: If ``value`` is not a str.
            ValueError: If the file extension is not in ``valid_extensions`` or
                the path does not exist.
        """
        # Check the type first: the path helpers below require a string.
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        # splitext returns a (root, ext) tuple whose ext keeps its leading dot;
        # drop the dot and lower-case it to match the keys of valid_extensions.
        extension = os.path.splitext(value)[1][1:].lower()
        if extension not in InputImage.valid_extensions:
            raise ValueError(
                f"Expected {value!r} to have one of the extensions {sorted(InputImage.valid_extensions)!r}"
            )
        if not os.path.exists(value):
            raise ValueError(f"Expected {value!r} to be a valid path")
        return value

    def get_range_iterator(self):
        """Return an iterator yielding only the default value."""
        return iter([self.default])

    def __repr__(self) -> str:
        # Previously labelled "String(...)", a copy-paste slip.
        return f"InputImage({self.private_name=} {self.default=} {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        return {
            "type": InputImage.__name__,
            "default": self.default,
            "values": self.valid_formats,
            "tooltip": self.tooltip,
        }
class MeshFormat(Validator):
    """
    Validator class for mesh formats. Valid inputs are either:
    - single valid format such as "glb", "obj"
    - or a bracketed list of valid formats, with or without quotes around the
      items, such as "[obj, ply, usdz]" or '["obj", "ply"]'
    """

    valid_formats = {"glb", "usdz", "obj", "ply"}

    def __init__(self, default="glb", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    @staticmethod
    def _parse_list(value: str) -> list:
        """Turn a bracketed string into a list of items."""
        try:
            parsed = ast.literal_eval(value)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # Not a Python literal, e.g. the unquoted form "[obj, ply]".
            parsed = None
        if isinstance(parsed, list):
            return parsed
        inner = value[1:-1].strip()
        if not inner:
            return []
        # Fallback: comma-separated bare words, tolerating stray quotes.
        return [token.strip().strip("'\"") for token in inner.split(",")]

    def validate(self, value: str) -> Union[str, List[str]]:
        """Return the single format, or the list of formats, named by ``value``.

        Raises:
            TypeError: If ``value`` is not a str.
            ValueError: If ``value`` is neither a valid format nor a bracketed
                list consisting only of valid formats.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        try:
            if value.startswith("[") and value.endswith("]"):
                formats = self._parse_list(value)
                # The isinstance guard keeps unhashable items (e.g. nested
                # lists) from raising TypeError in the set lookup.
                if not all(isinstance(fmt, str) and fmt in MeshFormat.valid_formats for fmt in formats):
                    raise ValueError(f"Each item must be one of {MeshFormat.valid_formats}")
                return formats
            elif value in MeshFormat.valid_formats:
                return value
            else:
                raise ValueError(f"Expected {value!r} to be one of {MeshFormat.valid_formats} or a list of them")
        except ValueError as e:
            # Handle case where the input is neither a valid single format nor a list of valid formats
            raise ValueError(f"Invalid format specification: {value}. Error: {str(e)}") from e

    def __repr__(self) -> str:
        return f"MeshFormat(default={self.default}, hidden={self.hidden})"

    def json(self):
        """Describe this validator as a plain, JSON-serializable dict."""
        return {
            "type": MeshFormat.__name__,
            "default": self.default,
            # A set cannot be serialized by json.dumps; emit a sorted list.
            "values": sorted(self.valid_formats),
            "tooltip": self.tooltip,
        }
class JsonDict(Validator):
    """
    JSON stringified version of a python dict.
    Example: '{"ema_customization_iter.pt": "ema_customization_iter.pt"}'

    Args:
        default: Value reported while nothing has been assigned.
        hidden: Flag stored on the validator; not used inside this class.
        tooltip: Optional help text, stored on the validator.
    """

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Parse ``value`` as JSON and return the resulting dict.

        A falsy ``value`` (such as the empty string) yields an empty dict.

        Raises:
            ValueError: If ``value`` is not valid JSON, or is valid JSON that
                does not decode to a dict.
        """
        if not value:
            return {}
        try:
            parsed = json.loads(value)
        except json.JSONDecodeError as e:
            raise ValueError(f"Expected {value!r} to be json stringified dict. Error: {str(e)}") from e
        # json.loads also accepts lists, numbers, strings, ...; only a dict
        # satisfies this validator's contract.
        if not isinstance(parsed, dict):
            raise ValueError(
                f"Expected {value!r} to be json stringified dict. Error: decoded to {type(parsed).__name__}"
            )
        return parsed

    def __repr__(self) -> str:
        # Previously labelled "Dict(...)", which did not match the class name.
        return f"JsonDict({self.default=} {self.hidden=})"
class BytesIOType(Validator):
    """
    Validator class for BytesIO. Valid inputs are either:
    - bytes
    - objects of class BytesIO
    - str holding Base64 data, which is decoded into a BytesIO
    """

    def __init__(self, default=None, hidden=False, tooltip=None):
        self.tooltip = tooltip
        self.hidden = hidden
        self.default = default

    def validate(self, value: Any) -> BytesIO:
        """Return a BytesIO for ``value``.

        A BytesIO instance is returned as-is, bytes are wrapped, and a str is
        Base64-decoded first.

        Raises:
            ValueError: If a str cannot be Base64-decoded.
            TypeError: If ``value`` is none of the supported types.
        """
        if isinstance(value, BytesIO):
            return value
        if isinstance(value, bytes):
            return BytesIO(value)
        if isinstance(value, str):
            try:
                raw = base64.b64decode(value)
                return BytesIO(raw)
            except ValueError as e:
                # binascii.Error is a ValueError subclass, so this one clause
                # also covers the decoder's padding errors.
                raise ValueError(f"Invalid Base64 encoded string: {e}")
        raise TypeError(f"Expected {value!r} to be a Base64 encoded string, bytes, or BytesIO")

    def __repr__(self) -> str:
        return f"BytesIOValidator({self.default=}, {self.hidden=})"

    def json(self):
        """Describe this validator as a plain dict."""
        return dict(type=BytesIO.__name__, default=self.default, tooltip=self.tooltip)