File size: 2,489 Bytes
8866644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import json
import torch
import folder_paths

from .conf import dit_conf
from .loader import load_dit

class DitCheckpointLoader:
	"""ComfyUI node that loads a DiT checkpoint with one of the ``dit_conf`` presets."""

	@classmethod
	def INPUT_TYPES(s):
		return {
			"required": {
				"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
				"model": (list(dit_conf.keys()),),
				"image_size": ([256, 512],),
				# "num_classes": ("INT", {"default": 1000, "min": 0,}),
			}
		}
	RETURN_TYPES = ("MODEL",)
	RETURN_NAMES = ("model",)
	FUNCTION = "load_checkpoint"
	CATEGORY = "ExtraModels/DiT"
	TITLE = "DitCheckpointLoader"

	def load_checkpoint(self, ckpt_name, model, image_size):
		"""Load ``ckpt_name`` using the ``dit_conf[model]`` preset.

		Args:
			ckpt_name: File name inside ComfyUI's "checkpoints" folder.
			model: Key of the preset in ``dit_conf``.
			image_size: Output image size in pixels (256 or 512 in the UI).

		Returns:
			A 1-tuple holding whatever ``load_dit`` returns.
		"""
		ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
		base_conf = dit_conf[model]
		# Build a per-call config instead of writing into the shared preset:
		# mutating dit_conf[model] in place would leak this call's input_size
		# into every other load (and into any model still referencing it).
		# The // 8 is presumably the VAE latent downscale factor — confirm.
		model_conf = {
			**base_conf,
			"unet_config": {
				**base_conf["unet_config"],
				"input_size": image_size // 8,
			},
		}
		# model_conf["unet_config"]["num_classes"] = num_classes
		dit = load_dit(
			model_path = ckpt_path,
			model_conf = model_conf,
		)
		return (dit,)

# TODO: the label list needs frontend code to display properly.
def get_label_data(label_file="labels/imagenet1000.json"):
	"""Load the class-id -> label-name mapping from a JSON file.

	Args:
		label_file: Path of the JSON file, relative to this module's directory.
			An absolute path is used unchanged (``os.path.join`` semantics).

	Returns:
		The decoded JSON object. Within this module it is used as a dict whose
		keys are class ids as strings (callers convert them with ``int``) and
		whose values are label names.
	"""
	label_path = os.path.join(
		os.path.dirname(os.path.realpath(__file__)),
		label_file,
	)
	# JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
	with open(label_path, "r", encoding="utf-8") as f:
		return json.load(f)
# Class-id -> label table used by DiTCondLabelSelect; read once at module import.
label_data = get_label_data(label_file="labels/imagenet1000.json")

class DiTCondLabelSelect:
	"""ComfyUI node that turns a class label name into DiT class-id conditioning."""

	@classmethod
	def INPUT_TYPES(s):
		# label_data is only read here, so no ``global`` declaration is needed.
		return {
			"required": {
				"model" : ("MODEL",),
				"label_name": (list(label_data.values()),),
			}
		}

	RETURN_TYPES = ("CONDITIONING",)
	RETURN_NAMES = ("class",)
	FUNCTION = "cond_label"
	CATEGORY = "ExtraModels/DiT"
	TITLE = "DiTCondLabelSelect"

	def cond_label(self, model, label_name):
		"""Return conditioning holding the class id whose label is ``label_name``.

		Args:
			model: Not used by this method; kept for the node's input signature.
			label_name: One of the values of the module-level ``label_data``.

		Returns:
			A 1-tuple containing ``[[y, {}]]``, where ``y`` is a 1x1 ``torch.int``
			tensor holding the class id.

		Raises:
			ValueError: if no entry of ``label_data`` has this label.
		"""
		# First match in label_data's order wins; if two ids share the same
		# label text, the later one cannot be selected by name.
		class_id = next(
			(int(k) for k, v in label_data.items() if v == label_name),
			None,
		)
		if class_id is None:
			raise ValueError(f"Unknown DiT class label: {label_name!r}")
		y = torch.tensor([[class_id]]).to(torch.int)
		return ([[y, {}]], )

class DiTCondLabelEmpty:
	"""ComfyUI node producing the "no class" (null label) conditioning for DiT."""

	RETURN_TYPES = ("CONDITIONING",)
	RETURN_NAMES = ("empty",)
	FUNCTION = "cond_empty"
	CATEGORY = "ExtraModels/DiT"
	TITLE = "DiTCondLabelEmpty"

	@classmethod
	def INPUT_TYPES(s):
		required_inputs = {"model": ("MODEL",)}
		return {"required": required_inputs}

	def cond_empty(self, model):
		"""Return conditioning holding the model's null class id.

		Args:
			model: Object exposing ``model.model.model_config.unet_config``, a
				mapping with a ``"num_classes"`` entry.

		Returns:
			A 1-tuple containing ``[[y, {}]]``, where ``y`` is a 1x1 ``torch.int``
			tensor equal to ``num_classes``.
		"""
		unet_config = model.model.model_config.unet_config
		# The null class id is one past the last real class id, i.e. num_classes.
		null_class_id = unet_config["num_classes"]
		empty_cond = torch.tensor([[null_class_id]]).to(torch.int)
		return ([[empty_cond, {}]],)

# Node type name -> implementing class, presumably collected by ComfyUI's
# custom-node loader (confirm against the package __init__).
NODE_CLASS_MAPPINGS = dict(
	DitCheckpointLoader=DitCheckpointLoader,
	DiTCondLabelSelect=DiTCondLabelSelect,
	DiTCondLabelEmpty=DiTCondLabelEmpty,
)