|
import copy |
|
import glob |
|
import inspect |
|
import json |
|
import os |
|
import random |
|
import sys |
|
import re |
|
from typing import Dict, List, Any, Callable, Tuple, TextIO |
|
from argparse import ArgumentParser |
|
import black |
|
|
|
|
|
from comfyui_to_python_utils import ( |
|
import_custom_nodes, |
|
find_path, |
|
add_comfyui_directory_to_sys_path, |
|
add_extra_model_paths, |
|
get_value_at_index |
|
) |
|
|
|
add_comfyui_directory_to_sys_path() |
|
from nodes import NODE_CLASS_MAPPINGS |
|
|
|
class FileHandler: |
|
"""Handles reading and writing files. |
|
|
|
This class provides methods to read JSON data from an input file and write code to an output file. |
|
""" |
|
|
|
@staticmethod |
|
def read_json_file(file_path: str | TextIO, encoding: str = "utf-8") -> dict: |
|
""" |
|
Reads a JSON file and returns its contents as a dictionary. |
|
|
|
        Args:
            file_path (str | TextIO): The path to the JSON file, or an open
                file-like object.
            encoding (str): The text encoding used to open the file. Defaults to "utf-8".

        Returns:
            dict: The contents of the JSON file as a dictionary.

        Raises:
            FileNotFoundError: If the file path does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
|
|
|
if hasattr(file_path, "read"): |
|
return json.load(file_path) |
|
        with open(file_path, "r", encoding=encoding) as file:
|
data = json.load(file) |
|
return data |
|
|
|
@staticmethod |
|
def write_code_to_file(file_path: str | TextIO, code: str) -> None: |
|
"""Write the specified code to a Python file. |
|
|
|
        Args:
            file_path (str | TextIO): The path to the Python file, or an open
                file-like object.
            code (str): The code to write to the file.
|
|
|
Returns: |
|
None |
|
""" |
|
if isinstance(file_path, str): |
|
|
|
directory = os.path.dirname(file_path) |
|
|
|
|
|
            if directory:
                os.makedirs(directory, exist_ok=True)
|
|
|
|
|
with open(file_path, "w", encoding="utf-8") as file: |
|
file.write(code) |
|
else: |
|
file_path.write(code) |
|
|
|
|
|
class LoadOrderDeterminer: |
|
"""Determine the load order of each key in the provided dictionary. |
|
|
|
This class places the nodes without node dependencies first, then ensures that any node whose |
|
result is used in another node will be added to the list in the order it should be executed. |
|
|
|
Attributes: |
|
data (Dict): The dictionary for which to determine the load order. |
|
node_class_mappings (Dict): Mappings of node classes. |
|
""" |
|
|
|
def __init__(self, data: Dict, node_class_mappings: Dict): |
|
"""Initialize the LoadOrderDeterminer with the given data and node class mappings. |
|
|
|
Args: |
|
data (Dict): The dictionary for which to determine the load order. |
|
node_class_mappings (Dict): Mappings of node classes. |
|
""" |
|
self.data = data |
|
self.node_class_mappings = node_class_mappings |
|
self.visited = {} |
|
self.load_order = [] |
|
self.is_special_function = False |
|
self.is_loader_function = False |
|
|
|
def determine_load_order(self) -> List[Tuple[str, Dict, bool, bool]]: |
|
"""Determine the load order for the given data. |
|
|
|
Returns: |
|
List[Tuple[str, Dict, bool, bool]]: A list of tuples representing the load order. |
|
""" |
|
self._load_special_functions_first() |
|
self.is_special_function = False |
|
self.is_loader_function = False |
|
for key in self.data: |
|
if key not in self.visited: |
|
self._dfs(key) |
|
return self.load_order |
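    # For illustration (hypothetical two-node graph): given
    #   {"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "sd.safetensors"}},
    #    "2": {"class_type": "KSampler", "inputs": {"model": ["1", 0], "steps": 20}}}
    # determine_load_order() yields roughly
    #   [("1", {...}, True, True), ("2", {...}, False, False)]
    # i.e. the loader first (special and loader flags set), then its dependents.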
|
|
|
def _dfs(self, key: str) -> None: |
|
"""Depth-First Search function to determine the load order. |
|
|
|
Args: |
|
key (str): The key from which to start the DFS. |
|
|
|
Returns: |
|
None |
|
""" |
|
|
|
self.visited[key] = True |
|
inputs = self.data[key]["inputs"] |
|
|
|
for input_key, val in inputs.items(): |
|
|
|
|
|
if isinstance(val, list) and val[0] not in self.visited: |
|
self._dfs(val[0]) |
|
|
|
self.load_order.append((key, self.data[key], self.is_special_function, self.is_loader_function)) |
|
|
|
def _load_special_functions_first(self) -> None: |
|
"""Load functions without dependencies, loaders, and encoders first. |
|
|
|
Returns: |
|
None |
|
""" |
|
|
|
for key in self.data: |
|
class_def = self.node_class_mappings[self.data[key]["class_type"]]() |
|
|
|
self.is_special_function = ( |
|
class_def.CATEGORY == "loaders" |
|
|
|
or not any( |
|
isinstance(val, list) for val in self.data[key]["inputs"].values() |
|
) |
|
) and class_def.CATEGORY != "FramerComfy" |
|
|
|
self.is_loader_function = class_def.CATEGORY == "loaders" |
|
|
|
if self.is_special_function: |
|
|
|
if key not in self.visited: |
|
self._dfs(key) |
|
|
|
|
|
class CodeGenerator: |
|
"""Generates Python code for a workflow based on the load order. |
|
|
|
Attributes: |
|
node_class_mappings (Dict): Mappings of node classes. |
|
base_node_class_mappings (Dict): Base mappings of node classes. |
|
""" |
|
|
|
    def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict, workflow_models: List | None = None):
|
"""Initialize the CodeGenerator with given node class mappings. |
|
|
|
Args: |
|
node_class_mappings (Dict): Mappings of node classes. |
|
base_node_class_mappings (Dict): Base mappings of node classes. |
|
workflow_models (List): List of models to download from huggingface. |
|
""" |
|
self.node_class_mappings = node_class_mappings |
|
self.base_node_class_mappings = base_node_class_mappings |
|
self.workflow_models = workflow_models or [] |
|
self.input_nodes = {} |
|
self.output_nodes = [] |
|
|
|
    def collect_framer_nodes(self, load_order: List) -> None:
|
"""Collect FramerComfy input and output nodes from the load order. |
|
|
|
Args: |
|
load_order (List): List of tuples containing node information. |
|
|
|
        Returns:
            None. Results are stored in self.input_nodes and self.output_nodes.
        """
|
for idx, data, _, _ in load_order: |
|
class_type = data["class_type"] |
|
if class_type.startswith("FramerComfy") and "Input" in class_type: |
|
|
|
                class_def = self.node_class_mappings[class_type]
                param_name = data.get("inputs", {}).get("name", f"param_{idx}")

                default_value = None

                param_type = class_def.__name__.replace("FramerComfyInput", "").lower()
|
self.input_nodes[param_name] = { |
|
"var_name": f"{self.clean_variable_name(class_type)}_{idx}", |
|
"default": default_value, |
|
"type": param_type |
|
} |
|
elif class_type.startswith("FramerComfy") and "Save" in class_type: |
|
var_name = f"{self.clean_variable_name(class_type)}_{idx}" |
|
self.output_nodes.append({ |
|
"var_name": var_name, |
|
"type": class_type |
|
}) |
|
|
|
def generate_function_signature(self) -> str: |
|
"""Generate the function signature based on collected input nodes. |
|
|
|
Returns: |
|
str: The function signature string. |
|
""" |
|
params = [] |
|
for param_name, info in self.input_nodes.items(): |
|
default = f"={info['default']}" if info['default'] is not None else "" |
|
params.append(f"{param_name}{default}") |
|
|
|
return f"@spaces.GPU\ndef run_workflow({', '.join(params)}) -> Tuple[Any, ...]:" |
|
|
|
def generate_workflow( |
|
self, |
|
load_order: List, |
|
) -> str: |
|
"""Generate the execution code based on the load order. |
|
|
|
Args: |
|
load_order (List): A list of tuples representing the load order. |
|
|
|
Returns: |
|
str: Generated execution code as a string. |
|
""" |
|
|
|
        import_statements, executed_variables = {"NODE_CLASS_MAPPINGS"}, {}
|
loader_code, loader_execution_code, special_functions_code, main_code = [], [], [], [] |
|
|
|
|
|
initialized_objects = {} |
|
|
|
|
|
self.collect_framer_nodes(load_order) |
|
|
|
|
|
input_node_mapping = {} |
|
for param_name, info in self.input_nodes.items(): |
|
for idx, data, _, _ in load_order: |
|
if f"{self.clean_variable_name(data['class_type'])}_{idx}" == info['var_name']: |
|
input_node_mapping[idx] = param_name |
|
break |
|
|
|
custom_nodes = False |
|
|
|
for idx, data, is_special_function, is_loader_function in load_order: |
|
|
|
if data["class_type"].startswith("FramerComfy") and "Input" in data["class_type"]: |
|
continue |
|
|
|
|
|
inputs, class_type = data["inputs"], data["class_type"] |
|
input_types = self.node_class_mappings[class_type].INPUT_TYPES() |
|
class_def = self.node_class_mappings[class_type]() |
|
|
|
|
|
            missing_required_variable = any(
                required not in inputs for required in input_types.get("required", {})
            )
            if missing_required_variable:
                continue
|
|
|
|
|
if class_type not in initialized_objects: |
|
|
|
if class_type == "PreviewImage": |
|
continue |
|
|
|
class_type, import_statement, class_code = self.get_class_info( |
|
class_type |
|
) |
|
initialized_objects[class_type] = self.clean_variable_name(class_type) |
|
                if class_type in self.base_node_class_mappings:
                    import_statements.add(import_statement)
                else:
                    custom_nodes = True
|
special_functions_code.append(class_code) |
|
|
|
|
|
class_def_params = self.get_function_parameters( |
|
getattr(class_def, class_def.FUNCTION) |
|
) |
|
no_params = class_def_params is None |
|
|
|
|
|
inputs = { |
|
key: value |
|
for key, value in inputs.items() |
|
if no_params or key in class_def_params |
|
} |
|
|
|
            if "unique_id" in input_types.get("hidden", {}) or (
                class_def_params is not None and "unique_id" in class_def_params
            ):
                inputs["unique_id"] = random.randint(1, 2**64)
|
|
|
|
|
executed_variables[idx] = f"{self.clean_variable_name(class_type)}_{idx}" |
|
inputs = self.update_inputs(inputs, executed_variables, input_node_mapping) |
|
|
|
if is_loader_function: |
|
loader_execution_code.append( |
|
self.create_function_call_code( |
|
initialized_objects[class_type], |
|
class_def.FUNCTION, |
|
executed_variables[idx], |
|
True, |
|
**inputs, |
|
) |
|
) |
|
elif is_special_function: |
|
special_functions_code.append( |
|
self.create_function_call_code( |
|
initialized_objects[class_type], |
|
class_def.FUNCTION, |
|
executed_variables[idx], |
|
is_special_function, |
|
**inputs, |
|
) |
|
) |
|
else: |
|
main_code.append( |
|
self.create_function_call_code( |
|
initialized_objects[class_type], |
|
class_def.FUNCTION, |
|
executed_variables[idx], |
|
is_special_function, |
|
**inputs, |
|
) |
|
) |
|
|
|
return self.assemble_python_code( |
|
import_statements, |
|
loader_code, |
|
loader_execution_code, |
|
special_functions_code, |
|
main_code, |
|
custom_nodes, |
|
) |
|
|
|
def create_function_call_code( |
|
self, |
|
obj_name: str, |
|
func: str, |
|
variable_name: str, |
|
is_special_function: bool, |
|
**kwargs, |
|
) -> str: |
|
"""Generate Python code for a function call. |
|
|
|
Args: |
|
obj_name (str): The name of the initialized object. |
|
func (str): The function to be called. |
|
variable_name (str): The name of the variable that the function result should be assigned to. |
|
is_special_function (bool): Determines the code indentation. |
|
**kwargs: The keyword arguments for the function. |
|
|
|
Returns: |
|
str: The generated Python code. |
|
""" |
|
args = ", ".join(self.format_arg(key, value) for key, value in kwargs.items()) |
|
|
|
|
|
code = f"{variable_name} = {obj_name}.{func}({args})\n" |
|
|
|
|
|
|
|
if not is_special_function: |
|
code = f"\t{code}" |
|
|
|
return code |
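    # For example (hypothetical names):
    #   create_function_call_code("ksampler", "sample", "ksampler_3", False, steps=20)
    # returns "\tksampler_3 = ksampler.sample(steps=20)\n", tab-indented because the
    # call runs inside the generated run_workflow body rather than at module level.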
|
|
|
    def format_arg(self, key: str, value: Any) -> str:
|
"""Formats arguments based on key and value. |
|
|
|
Args: |
|
key (str): Argument key. |
|
            value (Any): Argument value.
|
|
|
Returns: |
|
str: Formatted argument as a string. |
|
""" |
|
        if key in ("noise_seed", "seed"):
|
return f"{key}=random.randint(1, 2**64)" |
|
elif isinstance(value, str): |
|
|
|
if value in self.input_nodes: |
|
return f"{key}={value}" |
|
value = value.replace("\n", "\\n").replace('"', "'") |
|
return f'{key}="{value}"' |
|
elif isinstance(value, dict) and "variable_name" in value: |
|
return f'{key}={value["variable_name"]}' |
|
return f"{key}={value}" |
|
|
|
def assemble_python_code( |
|
self, |
|
import_statements: set, |
|
loader_code: List[str], |
|
loader_execution_code: List[str], |
|
special_functions_code: List[str], |
|
main_code: List[str], |
|
custom_nodes=False, |
|
) -> str: |
|
"""Generates the final code string. |
|
|
|
Args: |
|
import_statements (set): A set of unique import statements. |
|
loader_code (List[str]): A list of loader functions code strings. |
|
loader_execution_code (List[str]): A list of loader function execution code strings. |
|
special_functions_code (List[str]): A list of special functions code strings. |
|
main_code (List[str]): A list of code strings. |
|
custom_nodes (bool): Whether to include custom nodes in the code. |
|
|
|
Returns: |
|
str: Generated final code as a string. |
|
""" |
|
|
|
func_strings = [] |
|
for func in [ |
|
get_value_at_index, |
|
find_path, |
|
add_comfyui_directory_to_sys_path, |
|
add_extra_model_paths |
|
]: |
|
func_strings.append(f"\n{inspect.getsource(func)}") |
|
|
|
|
|
model_download_code = [] |
|
if self.workflow_models: |
|
for model in self.workflow_models: |
|
model_download_code.append( |
|
f'hf_hub_download(repo_id="{model["repo_id"]}", filename="{model["file_name"]}", local_dir="{model["model_local_path"]}", token=hf_token)' |
|
) |
|
model_download_code = ["\n# Download required models from huggingface"] + ['hf_token = os.environ.get("HF_TOKEN")'] + model_download_code + [""] |
|
|
|
|
|
static_imports = ( |
|
[ |
|
"import os", |
|
"import random", |
|
"import sys", |
|
"from typing import Sequence, Mapping, Any, Union, Tuple", |
|
"import torch", |
|
"from PIL import Image", |
|
"import spaces", |
|
"import gradio as gr", |
|
"from huggingface_hub import hf_hub_download", |
|
"from comfy import model_management", |
|
] |
|
+ model_download_code |
|
+ func_strings |
|
+ ["\n\nadd_comfyui_directory_to_sys_path()\nadd_extra_model_paths()\n"] |
|
) |
|
|
|
|
|
        if custom_nodes:
            static_imports.append(f"\n{inspect.getsource(import_custom_nodes)}\n")
            custom_nodes_call = "import_custom_nodes()\n"
        else:
            custom_nodes_call = ""

        imports_code = [
            f"from nodes import {', '.join(sorted(import_statements))}\n{custom_nodes_call}"
        ]
|
|
|
loader_init = "\n".join(loader_code) + "\n" if loader_code else "" |
|
|
|
|
|
loader_execution = "\n".join(loader_execution_code) + "\n" if loader_execution_code else "" |
|
|
|
|
|
model_vars = [] |
|
if loader_execution_code: |
|
for line in loader_execution_code: |
|
if "=" in line: |
|
var_name = line.split("=")[0].strip() |
|
model_vars.append(var_name) |
|
|
|
|
|
if model_vars: |
|
model_management_code = f"\nmodel_loaders = [{', '.join(model_vars)}]\n\n" |
|
model_management_code += "model_management.load_models_gpu([\n" |
|
model_management_code += " loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders\n" |
|
model_management_code += "])\n\n" |
|
else: |
|
model_management_code = "" |
|
|
|
|
|
initialization_code = "\n".join(special_functions_code) + "\n" if special_functions_code else "" |
|
|
|
|
|
workflow_code = [] |
|
|
|
|
|
        workflow_code.extend(
            ["\twith torch.inference_mode():", "\t\t" + "\n\t\t".join(main_code)]
        )
|
|
|
|
|
if self.output_nodes: |
|
|
|
output_vars = [] |
|
for output in self.output_nodes: |
|
var_name = output["var_name"] |
|
if "SaveImage" in output["type"]: |
|
|
|
path_var = f"{var_name}_path" |
|
workflow_code.append(f"\t{path_var} = \"output/\" + {var_name}['ui']['images'][0]['filename']") |
|
output_vars.append(path_var) |
|
else: |
|
output_vars.append(var_name) |
|
|
|
|
|
if len(output_vars) == 1: |
|
workflow_code.append(f"\treturn {output_vars[0]}") |
|
else: |
|
workflow_code.append(f"\treturn ({', '.join(output_vars)})") |
|
else: |
|
workflow_code.append("\treturn None") |
|
|
|
|
|
        input_components = []
        output_components = []
|
|
|
|
|
output_declarations = [] |
|
for output in self.output_nodes: |
|
var_name = output["var_name"] |
|
output_name = var_name.replace("framercomfysave", "").replace("node_", "") |
|
if "SaveImage" in output["type"]: |
|
output_declarations.append( |
|
f"{output_name}_output = gr.Image(label=\"{output.get('label', 'Generated ' + output_name.title())}\")" |
|
) |
|
output_components.append(f"{output_name}_output") |
|
|
|
|
|
gradio_code = [ |
|
"# Create Gradio interface", |
|
*output_declarations, |
|
"", |
|
"with gr.Blocks() as app:", |
|
"\twith gr.Row():", |
|
"\t\twith gr.Column():" |
|
] |
|
|
|
|
|
input_declarations = [] |
|
for param_name, info in self.input_nodes.items(): |
|
if info["type"] == "stringnode": |
|
input_declarations.append( |
|
f"\t\t\t{param_name}_input = gr.Textbox(" |
|
f"label=\"{info.get('label', param_name.title())}\", " |
|
f"value=\"{info['default']}\" if \"{info['default']}\" else None, " |
|
f"placeholder=f\"Enter {param_name} here...\")" |
|
) |
|
input_components.append(f"{param_name}_input") |
|
elif info["type"] == "image": |
|
input_declarations.append( |
|
f"\t\t\t{param_name}_input = gr.Image(" |
|
f"label=\"{info.get('label', param_name.title())}\", " |
|
f"type=\"filepath\")" |
|
) |
|
input_components.append(f"{param_name}_input") |
|
elif info["type"] == "float" or info["type"] == "integer": |
|
min_val = f", minimum={info['min']}" if info.get('min') is not None else "" |
|
max_val = f", maximum={info['max']}" if info.get('max') is not None else "" |
|
step = f", step={info['step']}" if info.get('step') is not None else "" |
|
input_declarations.append( |
|
f"\t\t\t{param_name}_input = gr.{'Number' if info['type'] == 'float' else 'Slider'}(" |
|
f"label=\"{info.get('label', param_name.title())}\"" |
|
f"{min_val}{max_val}{step}, " |
|
f"value={info['default'] if info['default'] is not None else 'None'})" |
|
) |
|
input_components.append(f"{param_name}_input") |
|
|
|
gradio_code.extend([ |
|
*input_declarations, |
|
"\t\t\tgenerate_btn = gr.Button(\"Generate\")", |
|
"\t\twith gr.Column():", |
|
*[f"\t\t\t{comp}.render()" for comp in output_components], |
|
"\tgenerate_btn.click(", |
|
"\t\tfn=run_workflow,", |
|
f"\t\tinputs=[{', '.join(input_components)}],", |
|
f"\t\toutputs=[{', '.join(output_components)}]", |
|
"\t)", |
|
"", |
|
'if __name__ == "__main__":', |
|
"\tapp.launch(share=True)" |
|
]) |
|
|
|
|
|
final_code = "\n".join( |
|
static_imports |
|
|
|
+ imports_code |
|
+ ["", initialization_code, loader_init, loader_execution, model_management_code, |
|
self.generate_function_signature(), "\n".join(workflow_code), |
|
"", "\n".join(gradio_code)] |
|
) |
|
|
|
|
|
final_code = black.format_str(final_code, mode=black.Mode()) |
|
|
|
return final_code |
|
|
|
def get_class_info(self, class_type: str) -> Tuple[str, str, str]: |
|
"""Generates and returns necessary information about class type. |
|
|
|
Args: |
|
class_type (str): Class type. |
|
|
|
Returns: |
|
Tuple[str, str, str]: Updated class type, import statement string, class initialization code. |
|
""" |
|
import_statement = class_type |
|
variable_name = self.clean_variable_name(class_type) |
|
if class_type in self.base_node_class_mappings.keys(): |
|
class_code = f"{variable_name} = {class_type.strip()}()" |
|
else: |
|
class_code = f'{variable_name} = NODE_CLASS_MAPPINGS["{class_type}"]()' |
|
|
|
return class_type, import_statement, class_code |
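    # For example, a base node yields ("KSampler", "KSampler", "ksampler = KSampler()"),
    # while a custom node resolves through the registry, e.g. (hypothetical name)
    # 'mycustomnode = NODE_CLASS_MAPPINGS["MyCustomNode"]()'.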
|
|
|
@staticmethod |
|
def clean_variable_name(class_type: str) -> str: |
|
""" |
|
Remove any characters from variable name that could cause errors running the Python script. |
|
|
|
Args: |
|
class_type (str): Class type. |
|
|
|
Returns: |
|
str: Cleaned variable name with no special characters or spaces |
|
""" |
|
|
|
clean_name = class_type.lower().strip().replace("-", "_").replace(" ", "_") |
|
|
|
|
|
clean_name = re.sub(r"[^a-z0-9_]", "", clean_name) |
|
|
|
|
|
        if clean_name and clean_name[0].isdigit():
|
clean_name = "_" + clean_name |
|
|
|
return clean_name |
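    # Examples:
    #   clean_variable_name("KSampler (Advanced)") -> "ksampler_advanced"
    #   clean_variable_name("3DLoader")            -> "_3dloader"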
|
|
|
    def get_function_parameters(self, func: Callable) -> List | None:
|
"""Get the names of a function's parameters. |
|
|
|
Args: |
|
func (Callable): The function whose parameters we want to inspect. |
|
|
|
        Returns:
            List | None: The names of the function's parameters, or None if the
                function accepts arbitrary keyword arguments (**kwargs).
|
""" |
|
        signature = inspect.signature(func)
        parameter_names = list(signature.parameters)
|
catch_all = any( |
|
param.kind == inspect.Parameter.VAR_KEYWORD |
|
for param in signature.parameters.values() |
|
) |
|
        return parameter_names if not catch_all else None
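    # For example, a node function declared as (self, model, seed, steps=20) yields
    # ["model", "seed", "steps"], while one declared with **kwargs yields None,
    # meaning every workflow input is passed through unchanged.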
|
|
|
def update_inputs(self, inputs: Dict, executed_variables: Dict, input_node_mapping: Dict) -> Dict: |
|
"""Update inputs based on the executed variables and input node mapping. |
|
|
|
Args: |
|
inputs (Dict): Inputs dictionary to update. |
|
executed_variables (Dict): Dictionary storing executed variable names. |
|
input_node_mapping (Dict): Mapping of input node IDs to parameter names. |
|
|
|
Returns: |
|
Dict: Updated inputs dictionary. |
|
""" |
|
for key in inputs.keys(): |
|
if isinstance(inputs[key], list): |
|
node_id = inputs[key][0] |
|
if node_id in input_node_mapping: |
|
|
|
inputs[key] = input_node_mapping[node_id] |
|
elif node_id in executed_variables: |
|
inputs[key] = { |
|
"variable_name": f"get_value_at_index({executed_variables[node_id]}, {inputs[key][1]})" |
|
} |
|
return inputs |
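    # For illustration (hypothetical graph): with input_node_mapping {"5": "prompt"} and
    # executed_variables {"3": "ksampler_3"}, the inputs
    #   {"text": ["5", 0], "latent": ["3", 0]}
    # become
    #   {"text": "prompt", "latent": {"variable_name": "get_value_at_index(ksampler_3, 0)"}}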
|
|
|
|
|
class ComfyUItoPython: |
|
"""Main workflow to generate Python code from a workflow_api.json file. |
|
|
|
    Attributes:
        workflow (str): The workflow's JSON, if provided directly.
        input_file (str): Path to the input JSON file.
        output_file (str | TextIO): Path to the output Python file or a file-like object.
        node_class_mappings (Dict): Mappings of node classes.
        base_node_class_mappings (Dict): Base mappings of node classes.
        workflow_models (List): Models to download from Hugging Face.
        needs_init_custom_nodes (bool): Whether custom nodes must be initialized.
|
""" |
|
|
|
def __init__( |
|
self, |
|
workflow: str = "", |
|
input_file: str = "", |
|
workflow_models: str | List = "", |
|
output_file: str | TextIO = "", |
|
node_class_mappings: Dict = NODE_CLASS_MAPPINGS, |
|
needs_init_custom_nodes: bool = False, |
|
): |
|
"""Initialize the ComfyUItoPython class with the given parameters. Exactly one of workflow or input_file must be specified. |
|
Args: |
|
workflow (str): The workflow's JSON. |
|
input_file (str): Path to the input JSON file. |
|
workflow_models (str | List): JSON string or list containing models to download. |
|
output_file (str | TextIO): Path to the output file or a file-like object. |
|
node_class_mappings (Dict): Mappings of node classes. Defaults to NODE_CLASS_MAPPINGS. |
|
needs_init_custom_nodes (bool): Whether to initialize custom nodes. Defaults to False. |
|
""" |
|
        if input_file and workflow:
            raise ValueError("Provide either input_file or workflow, not both")
        elif not input_file and not workflow:
            raise ValueError("Either input_file or workflow must be provided")

        if not output_file:
            raise ValueError("output_file must be provided")
|
|
|
self.workflow = workflow |
|
self.input_file = input_file |
|
self.output_file = output_file |
|
self.node_class_mappings = node_class_mappings |
|
self.needs_init_custom_nodes = needs_init_custom_nodes |
|
|
|
|
|
if isinstance(workflow_models, str): |
|
self.workflow_models = json.loads(workflow_models) if workflow_models else [] |
|
else: |
|
self.workflow_models = workflow_models or [] |
|
|
|
self.base_node_class_mappings = copy.deepcopy(self.node_class_mappings) |
|
|
|
|
|
    def execute(self) -> str:
|
"""Execute the main workflow to generate Python code. |
|
|
|
        Returns:
            str: The generated Python code.
        """
|
|
|
if self.needs_init_custom_nodes: |
|
import_custom_nodes() |
|
else: |
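            # Without custom-node initialization, leave the base mapping empty so
            # every node in the generated script resolves via NODE_CLASS_MAPPINGS.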
|
|
|
self.base_node_class_mappings = {} |
|
|
|
|
|
if self.input_file: |
|
data = FileHandler.read_json_file(self.input_file) |
|
else: |
|
data = json.loads(self.workflow) |
|
|
|
|
|
load_order_determiner = LoadOrderDeterminer(data, self.node_class_mappings) |
|
load_order = load_order_determiner.determine_load_order() |
|
|
|
|
|
code_generator = CodeGenerator(self.node_class_mappings, self.base_node_class_mappings, self.workflow_models) |
|
|
|
generated_code = code_generator.generate_workflow(load_order) |
|
|
|
|
|
FileHandler.write_code_to_file(self.output_file, generated_code) |
|
return generated_code |
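

# Example usage (a minimal sketch; the file names below are hypothetical):
#
#     ComfyUItoPython(
#         input_file="workflow_api.json",
#         output_file="generated_app.py",
#         needs_init_custom_nodes=True,
#     ).execute()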
|
|