bubbliiiing committed
Commit 62be1bf · Parent: a5c8285

Update Space

Files changed (1)
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
+import os
+import sys
+import time
+
+import torch
+
+current_file_path = os.path.abspath(__file__)
+project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
+for project_root in project_roots:
+    sys.path.insert(0, project_root) if project_root not in sys.path else None
+
+from cogvideox.api.api import (infer_forward_api,
+                               update_diffusion_transformer_api,
+                               update_edition_api)
+from cogvideox.ui.controller import flow_scheduler_dict
+from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope
+
+if __name__ == "__main__":
+    # Choose the UI mode
+    ui_mode = "eas"
+
+    # GPU memory mode, which can be chosen from [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
+    # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
+    #
+    # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
+    # and the transformer model has been quantized to float8, which can save more GPU memory.
+    #
+    # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
+    # resulting in slower speeds but saving a large amount of GPU memory.
+    GPU_memory_mode = "model_cpu_offload"
+    # Use torch.float16 if the GPU does not support torch.bfloat16.
+    # Some graphics cards, such as the V100 and 2080 Ti, do not support torch.bfloat16.
+    weight_dtype = torch.bfloat16
+    # Config path
+    config_path = "config/wan2.1/wan_civitai.yaml"
+
+    # Server IP
+    server_name = "0.0.0.0"
+    server_port = 7860
+
+    # The parameters below are only used when ui_mode == "modelscope"
+    model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
+    # "Inpaint" or "Control"
+    model_type = "Inpaint"
+    # Save directory for this model
+    savedir_sample = "samples"
+
+
+    if ui_mode == "modelscope":
+        demo, controller = ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
+    elif ui_mode == "eas":
+        demo, controller = ui_eas(model_name, flow_scheduler_dict, savedir_sample, config_path)
+    else:
+        demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
+
+    # Launch the Gradio app
+    app, _, _ = demo.queue(status_update_rate=1).launch(
+        server_name=server_name,
+        server_port=server_port,
+        prevent_thread_lock=True
+    )
+
+    # Launch the API
+    infer_forward_api(None, app, controller)
+    update_diffusion_transformer_api(None, app, controller)
+    update_edition_api(None, app, controller)
+
+    # Keep the Python process alive
+    while True:
+        time.sleep(5)
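
A note on weight_dtype: the comment in the diff says to fall back to torch.float16 on GPUs without bfloat16 support. A minimal sketch of automating that choice, assuming a CUDA device and using PyTorch's torch.cuda.is_bf16_supported() check (the fallback policy itself is an assumption, not part of this commit):

import torch

# Prefer bfloat16 where the hardware supports it (e.g. Ampere and newer);
# fall back to float16 on cards such as the V100 or 2080 Ti.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    weight_dtype = torch.bfloat16
else:
    weight_dtype = torch.float16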
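
Similarly, the three GPU_memory_mode options trade speed for VRAM. As an illustration only, a hypothetical helper that picks a mode from the free memory reported by torch.cuda.mem_get_info(); the thresholds are invented for this sketch and are not part of the Space:

import torch

def pick_memory_mode() -> str:
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the
    # current CUDA device. The GiB thresholds below are illustrative guesses.
    free_bytes, _ = torch.cuda.mem_get_info()
    free_gib = free_bytes / 1024**3
    if free_gib >= 24:
        return "model_cpu_offload"              # fastest of the three modes
    if free_gib >= 12:
        return "model_cpu_offload_and_qfloat8"  # float8-quantized transformer
    return "sequential_cpu_offload"             # slowest, saves the most VRAM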