Spaces:
Runtime error
Runtime error
import os.path | |
import stat | |
from collections import OrderedDict | |
from modules import shared, sd_models | |
from modules_forge.shared import controlnet_dir, supported_preprocessors | |
from typing import Dict, Tuple, List | |
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"] | |
def traverse_all_files(curr_path, model_list): | |
f_list = [ | |
(os.path.join(curr_path, entry.name), entry.stat()) | |
for entry in os.scandir(curr_path) | |
if os.path.isdir(curr_path) | |
] | |
for f_info in f_list: | |
fname, fstat = f_info | |
if os.path.splitext(fname)[1] in CN_MODEL_EXTS: | |
model_list.append(f_info) | |
elif stat.S_ISDIR(fstat.st_mode): | |
model_list = traverse_all_files(fname, model_list) | |
return model_list | |
def get_all_models(sort_by, filter_by, path): | |
res = OrderedDict() | |
fileinfos = traverse_all_files(path, []) | |
filter_by = filter_by.strip(" ") | |
if len(filter_by) != 0: | |
fileinfos = [x for x in fileinfos if filter_by.lower() | |
in os.path.basename(x[0]).lower()] | |
if sort_by == "name": | |
fileinfos = sorted(fileinfos, key=lambda x: os.path.basename(x[0])) | |
elif sort_by == "date": | |
fileinfos = sorted(fileinfos, key=lambda x: -x[1].st_mtime) | |
elif sort_by == "path name": | |
fileinfos = sorted(fileinfos) | |
for finfo in fileinfos: | |
filename = finfo[0] | |
name = os.path.splitext(os.path.basename(filename))[0] | |
# Prevent a hypothetical "None.pt" from being listed. | |
if name != "None": | |
res[name + f" [{sd_models.model_hash(filename)}]"] = filename | |
return res | |
controlnet_filename_dict = {'None': 'model.safetensors'} | |
controlnet_names = ['None'] | |
def get_preprocessor(name): | |
return supported_preprocessors.get(name, None) | |
def get_default_preprocessor(tag): | |
ps = get_filtered_preprocessor_names(tag) | |
assert len(ps) > 0 | |
return ps[0] if len(ps) == 1 else ps[1] | |
def get_sorted_preprocessors(): | |
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None'] | |
preprocessors = sorted(preprocessors, key=lambda x: str(x.sorting_priority).zfill(8) + x.name)[::-1] | |
results = OrderedDict() | |
results['None'] = supported_preprocessors['None'] | |
for p in preprocessors: | |
results[p.name] = p | |
return results | |
def get_all_controlnet_names(): | |
return controlnet_names | |
def get_controlnet_filename(controlnet_name): | |
return controlnet_filename_dict[controlnet_name] | |
def get_all_preprocessor_names(): | |
return list(get_sorted_preprocessors().keys()) | |
def get_all_preprocessor_tags(): | |
tags = [] | |
for k, p in supported_preprocessors.items(): | |
tags += p.tags | |
tags = list(set(tags)) | |
tags = sorted(tags) | |
return ['All'] + tags | |
def get_filtered_preprocessors(tag): | |
if tag == 'All': | |
return supported_preprocessors | |
return {k: v for k, v in get_sorted_preprocessors().items() if tag in v.tags or k == 'None'} | |
def get_filtered_preprocessor_names(tag): | |
return list(get_filtered_preprocessors(tag).keys()) | |
def get_filtered_controlnet_names(tag): | |
filtered_preprocessors = get_filtered_preprocessors(tag) | |
model_filename_filters = [] | |
for p in filtered_preprocessors.values(): | |
model_filename_filters += p.model_filename_filters | |
return [x for x in controlnet_names if x == 'None' or any(f.lower() in x.lower() for f in model_filename_filters)] | |
def update_controlnet_filenames(): | |
global controlnet_filename_dict, controlnet_names | |
controlnet_filename_dict = {'None': 'model.safetensors'} | |
controlnet_names = ['None'] | |
ext_dirs = (shared.opts.data.get("control_net_models_path", None), getattr(shared.cmd_opts, 'controlnet_dir', None)) | |
extra_lora_paths = (extra_lora_path for extra_lora_path in ext_dirs | |
if extra_lora_path is not None and os.path.exists(extra_lora_path)) | |
paths = [controlnet_dir, *extra_lora_paths] | |
for path in paths: | |
sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name") | |
filter_by = shared.opts.data.get("control_net_models_name_filter", "") | |
found = get_all_models(sort_by, filter_by, path) | |
controlnet_filename_dict.update(found) | |
controlnet_names = list(controlnet_filename_dict.keys()) | |
return | |
def select_control_type( | |
control_type: str, | |
) -> Tuple[List[str], List[str], str, str]: | |
global controlnet_names | |
pattern = control_type.lower() | |
all_models = list(controlnet_names) | |
if pattern == "all": | |
preprocessors = get_sorted_preprocessors().values() | |
return [ | |
[p.name for p in preprocessors], | |
all_models, | |
'none', # default option | |
"None" # default model | |
] | |
filtered_model_list = get_filtered_controlnet_names(control_type) | |
if pattern == "none": | |
filtered_model_list.append("None") | |
assert len(filtered_model_list) > 0, "'None' model should always be available." | |
if len(filtered_model_list) == 1: | |
default_model = "None" | |
else: | |
default_model = filtered_model_list[1] | |
for x in filtered_model_list: | |
if "11" in x.split("[")[0]: | |
default_model = x | |
break | |
return ( | |
get_filtered_preprocessor_names(control_type), | |
filtered_model_list, | |
get_default_preprocessor(control_type), | |
default_model | |
) | |