Spaces:
Configuration error
Configuration error
File size: 2,489 Bytes
8866644 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import json
import torch
import folder_paths
from .conf import dit_conf
from .loader import load_dit
class DitCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
"model": (list(dit_conf.keys()),),
"image_size": ([256, 512],),
# "num_classes": ("INT", {"default": 1000, "min": 0,}),
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_checkpoint"
CATEGORY = "ExtraModels/DiT"
TITLE = "DitCheckpointLoader"
def load_checkpoint(self, ckpt_name, model, image_size):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
model_conf = dit_conf[model]
model_conf["unet_config"]["input_size"] = image_size // 8
# model_conf["unet_config"]["num_classes"] = num_classes
dit = load_dit(
model_path = ckpt_path,
model_conf = model_conf,
)
return (dit,)
# todo: this needs frontend code to display properly
def get_label_data(label_file="labels/imagenet1000.json"):
label_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
label_file,
)
label_data = {0: "None"}
with open(label_path, "r") as f:
label_data = json.loads(f.read())
return label_data
label_data = get_label_data()
class DiTCondLabelSelect:
@classmethod
def INPUT_TYPES(s):
global label_data
return {
"required": {
"model" : ("MODEL",),
"label_name": (list(label_data.values()),),
}
}
RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("class",)
FUNCTION = "cond_label"
CATEGORY = "ExtraModels/DiT"
TITLE = "DiTCondLabelSelect"
def cond_label(self, model, label_name):
global label_data
class_labels = [int(k) for k,v in label_data.items() if v == label_name]
y = torch.tensor([[class_labels[0]]]).to(torch.int)
return ([[y, {}]], )
class DiTCondLabelEmpty:
@classmethod
def INPUT_TYPES(s):
global label_data
return {
"required": {
"model" : ("MODEL",),
}
}
RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("empty",)
FUNCTION = "cond_empty"
CATEGORY = "ExtraModels/DiT"
TITLE = "DiTCondLabelEmpty"
def cond_empty(self, model):
# [ID of last class + 1] == [num_classes]
y_null = model.model.model_config.unet_config["num_classes"]
y = torch.tensor([[y_null]]).to(torch.int)
return ([[y, {}]], )
NODE_CLASS_MAPPINGS = {
"DitCheckpointLoader" : DitCheckpointLoader,
"DiTCondLabelSelect" : DiTCondLabelSelect,
"DiTCondLabelEmpty" : DiTCondLabelEmpty,
}
|