Spaces:
Configuration error
Configuration error
import copy
import json
import os

import folder_paths
import torch

from .conf import dit_conf
from .loader import load_dit
class DitCheckpointLoader:
    """ComfyUI node that loads a DiT checkpoint file and returns it as a MODEL."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: checkpoint file, DiT variant and image size.

        Must be a classmethod so it can be called on the class itself
        (``DitCheckpointLoader.INPUT_TYPES()``) without an instance.
        """
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model": (list(dit_conf.keys()),),
                "image_size": ([256, 512],),
                # "num_classes": ("INT", {"default": 1000, "min": 0,}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_checkpoint"
    CATEGORY = "ExtraModels/DiT"
    TITLE = "DitCheckpointLoader"

    def load_checkpoint(self, ckpt_name, model, image_size):
        """Load ``ckpt_name`` using the ``dit_conf`` entry selected by ``model``.

        Args:
            ckpt_name: File name from the "checkpoints" folder list.
            model: Key into ``dit_conf`` selecting the DiT variant.
            image_size: Pixel size (256 or 512); ``image_size // 8`` is written
                into the config as ``unet_config["input_size"]``.

        Returns:
            A 1-tuple holding whatever ``load_dit`` returns.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        # Work on a private copy. dit_conf is shared module state, so writing
        # input_size into it in place would leak this call's image size into
        # every later load, and into any earlier model that may still hold a
        # reference to the same dict.
        model_conf = copy.deepcopy(dit_conf[model])
        # Presumably the latent side length, assuming an 8x-downsampling VAE.
        # TODO confirm.
        model_conf["unet_config"]["input_size"] = image_size // 8
        # model_conf["unet_config"]["num_classes"] = num_classes
        dit = load_dit(
            model_path=ckpt_path,
            model_conf=model_conf,
        )
        return (dit,)
# todo: this needs frontend code to display properly | |
def get_label_data(label_file="labels/imagenet1000.json"):
    """Load a JSON class-label table stored alongside this module.

    Args:
        label_file: Path to the JSON file, resolved relative to the directory
            containing this source file. Per ``os.path.join`` semantics, an
            absolute path is used as-is.

    Returns:
        The object parsed from the file. Presumably a dict mapping
        class-index strings to label names, since its keys are passed to
        ``int()`` elsewhere in this module. TODO confirm the schema.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    label_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        label_file,
    )
    # Decode explicitly as UTF-8. JSON text is UTF-8 by spec, while the
    # platform default (e.g. cp1252 on Windows) would mis-decode non-ASCII
    # label names.
    with open(label_path, "r", encoding="utf-8") as f:
        return json.load(f)
# Label table, loaded once when the module is imported. DiTCondLabelSelect
# reads it for its dropdown choices and for label-to-ID lookup.
label_data = get_label_data(label_file="labels/imagenet1000.json")
class DiTCondLabelSelect:
    """ComfyUI node that converts a class-label name into DiT class conditioning."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare inputs; the label choices are the values of ``label_data``.

        Must be a classmethod so it can be called on the class itself.
        """
        return {
            "required": {
                "model": ("MODEL",),
                "label_name": (list(label_data.values()),),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("class",)
    FUNCTION = "cond_label"
    CATEGORY = "ExtraModels/DiT"
    TITLE = "DiTCondLabelSelect"

    def cond_label(self, model, label_name):
        """Return conditioning holding the class ID whose label is ``label_name``.

        Args:
            model: Unused here; accepted to match the declared node inputs.
            label_name: One of the values in the module-level ``label_data``.

        Returns:
            ``([[y, {}]],)`` where ``y`` is a 1x1 ``torch.int`` tensor holding
            the first ``label_data`` key (converted with ``int``) whose value
            equals ``label_name``.

        Raises:
            ValueError: If no entry in ``label_data`` has that label.
        """
        # Reading a module global needs no `global` statement. Take the first
        # matching key, as the original list-then-[0] lookup did.
        class_id = next(
            (int(key) for key, name in label_data.items() if name == label_name),
            None,
        )
        if class_id is None:
            raise ValueError(f"Unknown DiT class label: {label_name!r}")
        y = torch.tensor([[class_id]], dtype=torch.int)
        return ([[y, {}]],)
class DiTCondLabelEmpty:
    """ComfyUI node producing the "empty" DiT class conditioning.

    The empty class is the ID one past the last real class, i.e. ``num_classes``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single MODEL input; must be callable on the class itself."""
        return {
            "required": {
                "model": ("MODEL",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("empty",)
    FUNCTION = "cond_empty"
    CATEGORY = "ExtraModels/DiT"
    TITLE = "DiTCondLabelEmpty"

    def cond_empty(self, model):
        """Return conditioning holding the null class ID for ``model``.

        Args:
            model: Object exposing ``model.model.model_config.unet_config``
                as a mapping with a ``"num_classes"`` entry.

        Returns:
            ``([[y, {}]],)`` where ``y`` is a 1x1 ``torch.int`` tensor holding
            ``num_classes``.
        """
        # [ID of last class + 1] == [num_classes]
        y_null = model.model.model_config.unet_config["num_classes"]
        y = torch.tensor([[y_null]], dtype=torch.int)
        return ([[y, {}]],)
# Node type name -> node class, following the ComfyUI custom-node
# registration convention. The keys are the names nodes are registered under.
NODE_CLASS_MAPPINGS = {
    "DitCheckpointLoader": DitCheckpointLoader,
    "DiTCondLabelSelect": DiTCondLabelSelect,
    "DiTCondLabelEmpty": DiTCondLabelEmpty,
}