dangthr committed
Commit f59198d · verified · 1 Parent(s): 21a43df

Update app.py

Files changed (1)
  1. app.py +352 -230
app.py CHANGED
@@ -1,3 +1,6 @@
 import torch
 import numpy as np
 import random
@@ -11,6 +14,16 @@ from PIL import Image
 from huggingface_hub import hf_hub_download
 import shutil

 from inference import (
 create_ltx_video_pipeline,
 create_latent_upsampler,
@@ -23,6 +36,186 @@ from inference import (
 from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
 from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

 config_file_path = "configs/ltxv-13b-0.9.7-distilled.yaml"
 with open(config_file_path, "r") as file:
 PIPELINE_CONFIG_YAML = yaml.safe_load(file)
@@ -30,95 +223,66 @@ with open(config_file_path, "r") as file:
 LTX_REPO = "Lightricks/LTX-Video"
 MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
 MAX_NUM_FRAMES = 257
-
 FPS = 30.0

- # --- Global variables for loaded models ---
 pipeline_instance = None
 latent_upsampler_instance = None
 models_dir = "downloaded_models_gradio_cpu_init"
 Path(models_dir).mkdir(parents=True, exist_ok=True)
-
- # Create the output directory
- output_dir = "output"
 Path(output_dir).mkdir(parents=True, exist_ok=True)

- print("Downloading models (if not present)...")
- distilled_model_actual_path = hf_hub_download(
- repo_id=LTX_REPO,
- filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
- local_dir=models_dir,
- local_dir_use_symlinks=False
- )
- PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
- print(f"Distilled model path: {distilled_model_actual_path}")
-
- SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
- spatial_upscaler_actual_path = hf_hub_download(
- repo_id=LTX_REPO,
- filename=SPATIAL_UPSCALER_FILENAME,
- local_dir=models_dir,
- local_dir_use_symlinks=False
- )
- PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
- print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
-
- print("Creating LTX Video pipeline on CPU...")
- pipeline_instance = create_ltx_video_pipeline(
- ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
- precision=PIPELINE_CONFIG_YAML["precision"],
- text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
- sampler=PIPELINE_CONFIG_YAML["sampler"],
- device="cpu",
- enhance_prompt=False,
- prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
- prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
- )
- print("LTX Video pipeline created on CPU.")

- if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
- print("Creating latent upsampler on CPU...")
- latent_upsampler_instance = create_latent_upsampler(
- PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
- device="cpu"
 )
- print("Latent upsampler created on CPU.")

- target_inference_device = "cuda"
- print(f"Target inference device: {target_inference_device}")
- pipeline_instance.to(target_inference_device)
- if latent_upsampler_instance:
- latent_upsampler_instance.to(target_inference_device)

- # --- Helper function for dimension calculation ---
- MIN_DIM_SLIDER = 256
- TARGET_FIXED_SIDE = 768

- def calculate_new_dimensions(orig_w, orig_h):
- """
- Calculates new dimensions for height and width sliders based on original media dimensions.
- Ensures one side is TARGET_FIXED_SIDE, the other is scaled proportionally,
- both are multiples of 32, and within [MIN_DIM_SLIDER, MAX_IMAGE_SIZE].
- """
- if orig_w == 0 or orig_h == 0:
- return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)
-
- if orig_w >= orig_h:
- new_h = TARGET_FIXED_SIDE
- aspect_ratio = orig_w / orig_h
- new_w_ideal = new_h * aspect_ratio
- new_w = round(new_w_ideal / 32) * 32
- new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
- new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
- else:
- new_w = TARGET_FIXED_SIDE
- aspect_ratio = orig_h / orig_w
- new_h_ideal = new_w * aspect_ratio
- new_h = round(new_h_ideal / 32) * 32
- new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
- new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))

- return int(new_h), int(new_w)

 def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
 input_image_filepath=None, input_video_filepath=None,
@@ -126,6 +290,11 @@ def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry
 duration_ui=2.0, ui_frames_to_use=9,
 seed_ui=42, randomize_seed=True, ui_guidance_scale=None, improve_texture_flag=True):

 if randomize_seed:
 seed_ui = random.randint(0, 2**32 - 1)
 seed_everething(int(seed_ui))
@@ -154,25 +323,14 @@ def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry
 padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)

 call_kwargs = {
- "prompt": prompt,
- "negative_prompt": negative_prompt,
- "height": height_padded,
- "width": width_padded,
- "num_frames": num_frames_padded,
- "frame_rate": int(FPS),
 "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
- "output_type": "pt",
- "conditioning_items": None,
- "media_items": None,
- "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
- "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
- "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
- "image_cond_noise_scale": 0.15,
- "is_video": True,
- "vae_per_channel_normalize": True,
- "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
- "offload_to_cpu": False,
- "enhance_prompt": False,
 }

 stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")
@@ -189,184 +347,148 @@ def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry

 if mode == "image-to-video" and input_image_filepath:
 try:
- media_tensor = load_image_to_tensor_with_resize_and_crop(
- input_image_filepath, actual_height, actual_width
- )
 media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
 call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
 except Exception as e:
- print(f"Error loading image {input_image_filepath}: {e}")
 raise RuntimeError(f"Could not load image: {e}")
 elif mode == "video-to-video" and input_video_filepath:
 try:
 call_kwargs["media_items"] = load_media_file(
- media_path=input_video_filepath,
- height=actual_height,
- width=actual_width,
- max_frames=int(ui_frames_to_use),
- padding=padding_values
 ).to(target_inference_device)
 except Exception as e:
- print(f"Error loading video {input_video_filepath}: {e}")
 raise RuntimeError(f"Could not load video: {e}")

- print(f"Moving models to {target_inference_device} for inference (if not already there)...")
-
- active_latent_upsampler = None
- if improve_texture_flag and latent_upsampler_instance:
- active_latent_upsampler = latent_upsampler_instance
-
 result_images_tensor = None
 if improve_texture_flag:
 if not active_latent_upsampler:
- raise RuntimeError("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")

 multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)

- first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
- first_pass_args["guidance_scale"] = float(ui_guidance_scale)
- first_pass_args.pop("num_inference_steps", None)
-
- second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
- second_pass_args["guidance_scale"] = float(ui_guidance_scale)
- second_pass_args.pop("num_inference_steps", None)
-
- multi_scale_call_kwargs = call_kwargs.copy()
- multi_scale_call_kwargs.update({
- "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
- "first_pass": first_pass_args,
- "second_pass": second_pass_args,
- })

- print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
 result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
 else:
- single_pass_call_kwargs = call_kwargs.copy()
- first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
-
- single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
- single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale)
- single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
- single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
- single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
-
- single_pass_call_kwargs.pop("num_inference_steps", None)
- single_pass_call_kwargs.pop("first_pass", None)
- single_pass_call_kwargs.pop("second_pass", None)
- single_pass_call_kwargs.pop("downscale_factor", None)
-
- print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
 result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images

 if result_images_tensor is None:
- raise RuntimeError("Generation failed.")

 pad_left, pad_right, pad_top, pad_bottom = padding_values
 slice_h_end = -pad_bottom if pad_bottom > 0 else None
 slice_w_end = -pad_right if pad_right > 0 else None

- result_images_tensor = result_images_tensor[
- :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
- ]
-
- video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
- video_np = np.clip(video_np, 0, 1)
- video_np = (video_np * 255).astype(np.uint8)

- # Generate a timestamped filename and save it to the output directory
 timestamp = random.randint(10000, 99999)
- output_video_path = os.path.join(output_dir, f"output_{timestamp}.mp4")

 try:
- with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
- for frame_idx in range(video_np.shape[0]):
- video_writer.append_data(video_np[frame_idx])
- if frame_idx % 10 == 0:
- print(f"Saving frame {frame_idx + 1}/{video_np.shape[0]}")
- except Exception as e:
- print(f"Error saving video with macro_block_size=1: {e}")
- try:
- with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8) as video_writer:
- for frame_idx in range(video_np.shape[0]):
- video_writer.append_data(video_np[frame_idx])
- if frame_idx % 10 == 0:
- print(f"Saving frame {frame_idx + 1}/{video_np.shape[0]} (fallback)")
- except Exception as e2:
- print(f"Fallback video saving error: {e2}")
- raise RuntimeError(f"Failed to save video: {e2}")

- print(f"Video saved successfully to: {output_video_path}")
 return output_video_path, seed_ui

- def main():
302
- parser = argparse.ArgumentParser(description="LTX Video Generation from Command Line")
303
- parser.add_argument("--prompt", required=True, help="Text prompt for video generation")
304
- parser.add_argument("--negative-prompt", default="worst quality, inconsistent motion, blurry, jittery, distorted",
305
- help="Negative prompt")
306
- parser.add_argument("--mode", choices=["text-to-video", "image-to-video", "video-to-video"],
307
- default="text-to-video", help="Generation mode")
308
- parser.add_argument("--input-image", help="Input image path for image-to-video mode")
309
- parser.add_argument("--input-video", help="Input video path for video-to-video mode")
310
- parser.add_argument("--duration", type=float, default=2.0, help="Video duration in seconds (0.3-8.5)")
311
- parser.add_argument("--height", type=int, default=512, help="Video height (must be divisible by 32)")
312
- parser.add_argument("--width", type=int, default=704, help="Video width (must be divisible by 32)")
313
- parser.add_argument("--seed", type=int, default=42, help="Random seed")
314
- parser.add_argument("--randomize-seed", action="store_true", help="Use random seed")
315
- parser.add_argument("--guidance-scale", type=float, help="Guidance scale for generation")
316
- parser.add_argument("--no-improve-texture", action="store_true", help="Disable texture improvement (faster)")
317
- parser.add_argument("--frames-to-use", type=int, default=9, help="Frames to use from input video (for video-to-video)")
318
-
319
- args = parser.parse_args()
320
-
321
- # Validate parameters
322
- if args.mode == "image-to-video" and not args.input_image:
323
- print("Error: --input-image is required for image-to-video mode")
324
- return
325
-
326
- if args.mode == "video-to-video" and not args.input_video:
327
- print("Error: --input-video is required for video-to-video mode")
328
- return
329
-
330
- # Ensure dimensions are divisible by 32
331
- args.height = ((args.height - 1) // 32 + 1) * 32
332
- args.width = ((args.width - 1) // 32 + 1) * 32
333
-
334
- print(f"Starting video generation...")
335
- print(f"Prompt: {args.prompt}")
336
- print(f"Mode: {args.mode}")
337
- print(f"Duration: {args.duration}s")
338
- print(f"Resolution: {args.width}x{args.height}")
339
- print(f"Output directory: {os.path.abspath(output_dir)}")
340
 
341
  try:
342
  output_path, used_seed = generate(
343
- prompt=args.prompt,
344
- negative_prompt=args.negative_prompt,
345
- input_image_filepath=args.input_image,
346
- input_video_filepath=args.input_video,
347
- height_ui=args.height,
348
- width_ui=args.width,
349
- mode=args.mode,
350
- duration_ui=args.duration,
351
- ui_frames_to_use=args.frames_to_use,
352
- seed_ui=args.seed,
353
- randomize_seed=args.randomize_seed,
354
- ui_guidance_scale=args.guidance_scale,
355
- improve_texture_flag=not args.no_improve_texture
356
  )
357
-
358
- print(f"\n✅ Video generation completed!")
359
- print(f"📁 Output saved to: {output_path}")
360
- print(f"🎲 Used seed: {used_seed}")
361
- print(f"📂 Full path: {os.path.abspath(output_path)}")
362
 
363
  except Exception as e:
364
- print(f"❌ Error during generation: {e}")
365
- raise
 
366
 
 
 
 
367
  if __name__ == "__main__":
368
- if os.path.exists(models_dir) and os.path.isdir(models_dir):
369
- print(f"Model directory: {Path(models_dir).resolve()}")
370
 
371
- print(f"Output directory: {Path(output_dir).resolve()}")
372
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # ==============================================================================
+ # Unified entry point and dependencies
+ # ==============================================================================
 import torch
 import numpy as np
 import random

 from huggingface_hub import hf_hub_download
 import shutil

+ # Dependencies required for listener mode
+ import asyncio
+ import websockets
+ import subprocess
+ import json
+ import logging
+ import sys
+ import urllib.parse
+ import requests
+
 from inference import (
 create_ltx_video_pipeline,
 create_latent_upsampler,

 from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
 from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

+ # ==============================================================================
+ # Logging configuration
+ # ==============================================================================
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # ==============================================================================
+ # Listener-mode functions (originally remote_client.py)
+ # ==============================================================================
+
+ # Global variables used to share state in listener mode
+ global_websocket = None
+ global_machine_id = None
+ global_card_id = None
+ global_machine_secret = None
+ global_server_url = None
+
+ async def upload_file_to_server(file_path, card_id, machine_secret, machine_id):
+ """Upload a file to the designated endpoint on the server."""
+ try:
+ if not os.path.exists(file_path):
+ logger.error(f"[Uploader] File not found: {file_path}")
+ return False
+
+ upload_url = f"{global_server_url}/terminal/{card_id}/machine-upload?secret={urllib.parse.quote(machine_secret)}"
+ files = {'file': (os.path.basename(file_path), open(file_path, 'rb'), 'application/octet-stream')}
+ data = {'machine_id': machine_id}
+
+ logger.info(f"[Uploader] Uploading {os.path.basename(file_path)} to {upload_url}...")
+ response = requests.post(upload_url, files=files, data=data, timeout=120)
+
+ if response.status_code == 200:
+ result = response.json()
+ if result and result.get("success"):
+ logger.info(f"[Uploader] Upload successful: {file_path}")
+ return True
+ else:
+ logger.error(f"[Uploader] Upload failed: {result.get('error', 'Unknown error')}")
+ return False
+ else:
+ logger.error(f"[Uploader] Upload failed with status code {response.status_code}: {response.text}")
+ return False
+
+ except Exception as e:
+ logger.error(f"[Uploader] An exception occurred during upload: {e}")
+ return False
+
+ async def watch_directory_for_uploads(dir_to_watch, card_id, secret, get_machine_id_func):
+ """
+ Watch the specified directory for new files and upload them automatically.
+ """
+ processed_files = set()
+ logger.info(f"[Watcher] Starting to watch directory: {dir_to_watch}")
+
+ # Initial scan: treat pre-existing files as already processed
+ if os.path.isdir(dir_to_watch):
+ processed_files.update(os.listdir(dir_to_watch))
+ logger.info(f"[Watcher] Initial scan: {len(processed_files)} existing files ignored.")
+
+ while True:
+ await asyncio.sleep(5) # Check every 5 seconds
+ try:
+ if not os.path.isdir(dir_to_watch):
+ continue
+
+ current_files = set(os.listdir(dir_to_watch))
+ new_files = current_files - processed_files
+
+ if new_files:
+ machine_id = get_machine_id_func()
+ if not machine_id:
+ logger.warning("[Watcher] Machine ID not available, skipping upload cycle.")
+ continue
+
+ logger.info(f"[Watcher] Detected {len(new_files)} new file(s): {', '.join(new_files)}")
+ for filename in new_files:
+ file_path = os.path.join(dir_to_watch, filename)
+ # Wait for the file write to finish (simple check)
+ await asyncio.sleep(2)
+
+ success = await upload_file_to_server(file_path, card_id, secret, machine_id)
+ if success:
+ logger.info(f"[Watcher] Successfully uploaded {filename}. Marking as processed.")
+ processed_files.add(filename)
+ else:
+ logger.warning(f"[Watcher] Failed to upload {filename}. Will retry on next cycle.")
+
+ # Sync the processed set, dropping files that have been deleted
+ processed_files.intersection_update(current_files)
+
+ except Exception as e:
+ logger.error(f"[Watcher] Error in file watching loop: {e}")
+
+
+ async def start_listener_mode(card_id, machine_secret, watch_dir):
+ """
+ Main function that starts listener mode.
+ """
+ global global_websocket, global_machine_id, global_card_id, global_machine_secret, global_server_url
+
+ global_card_id = card_id
+ global_machine_secret = machine_secret
+
+ server_hostname = "remote-terminal-worker.nianxi4563.workers.dev" # or your own server hostname
+ global_server_url = f"https://{server_hostname}"
+ encoded_secret = urllib.parse.quote(machine_secret)
+ uri = f"wss://{server_hostname}/terminal/{card_id}?secret={encoded_secret}"
+
+ # Start the file watcher
+ def get_machine_id(): return global_machine_id
+ watcher_task = asyncio.create_task(watch_directory_for_uploads(watch_dir, card_id, machine_secret, get_machine_id))
+
+ while True: # Auto-reconnect loop
+ try:
+ logger.info(f"[Listener] Attempting to connect to {uri}")
+ async with websockets.connect(uri, ping_interval=20, ping_timeout=60) as websocket:
+ global_websocket = websocket
+ logger.info("[Listener] Connected to WebSocket server.")
+
+ # Loop until a machine_id is received
+ while global_machine_id is None:
+ try:
+ response = await asyncio.wait_for(websocket.recv(), timeout=10.0)
+ data = json.loads(response)
+ if data.get("type") == "connected" and "machine_id" in data:
+ global_machine_id = data["machine_id"]
+ logger.info(f"[Listener] Assigned machine ID: {global_machine_id}")
+ break
+ except asyncio.TimeoutError:
+ logger.debug("[Listener] Waiting for machine ID...")
+ except Exception as e:
+ logger.error(f"[Listener] Error receiving machine ID: {e}")
+ await asyncio.sleep(5) # Wait, then retry
+ break # break inner loop to reconnect
+
+ if not global_machine_id:
+ continue # continue outer loop to reconnect
+
+ # Main message-handling loop
+ while True:
+ message = await websocket.recv()
+ data = json.loads(message)
+ logger.debug(f"[Listener] Received message: {data}")
+
+ if data.get("type") == "command":
+ command = data["command"]
+ logger.info(f"[Listener] Received command: {command}")
+
+ # Execute the command in a new process via subprocess,
+ # so the listener keeps working while inference runs in the background
+ try:
+ # Wrap the command as `python app.py ...`
+ full_command = f"python app.py {command}"
+ logger.info(f"Executing subprocess: {full_command}")
+ subprocess.run(full_command, shell=True, check=True)
+ logger.info("Subprocess finished successfully.")
+ # The result file will be uploaded automatically by the watcher
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Command execution failed with return code {e.returncode}")
+ error_output = e.stderr if e.stderr else e.stdout
+ if global_websocket:
+ await global_websocket.send(json.dumps({
+ "type": "error", "data": f"Command failed: {error_output}", "machine_id": global_machine_id
+ }))
+ except Exception as e:
+ logger.error(f"Failed to run command: {e}")
+
+ except websockets.exceptions.ConnectionClosed as e:
+ logger.warning(f"[Listener] WebSocket closed: code={e.code}, reason={e.reason}. Reconnecting in 10 seconds...")
+ except Exception as e:
+ logger.error(f"[Listener] Connection failed: {e}. Reconnecting in 10 seconds...")
+
+ global_websocket = None
+ global_machine_id = None
+ await asyncio.sleep(10)
+
+
+ # ==============================================================================
+ # Inference-mode functions (original app.py)
+ # ==============================================================================
 config_file_path = "configs/ltxv-13b-0.9.7-distilled.yaml"
 with open(config_file_path, "r") as file:
 PIPELINE_CONFIG_YAML = yaml.safe_load(file)

 LTX_REPO = "Lightricks/LTX-Video"
 MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
 MAX_NUM_FRAMES = 257

 FPS = 30.0

+ # Global variables to cache the loaded models
 pipeline_instance = None
 latent_upsampler_instance = None
 models_dir = "downloaded_models_gradio_cpu_init"
 Path(models_dir).mkdir(parents=True, exist_ok=True)
+ output_dir = "output" # Output directory shared by all modes
 Path(output_dir).mkdir(parents=True, exist_ok=True)

+ def initialize_models():
+ """Load and initialize all AI models (if not already loaded)."""
+ global pipeline_instance, latent_upsampler_instance

+ if pipeline_instance is not None:
+ logger.info("Models already initialized.")
+ return
+
+ logger.info("Initializing models for the first time...")
+ logger.info("Downloading models (if not present)...")
+ distilled_model_actual_path = hf_hub_download(
+ repo_id=LTX_REPO, filename=PIPELINE_CONFIG_YAML["checkpoint_path"], local_dir=models_dir, local_dir_use_symlinks=False
 )
+ PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
+ logger.info(f"Distilled model path: {distilled_model_actual_path}")

+ SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
+ spatial_upscaler_actual_path = hf_hub_download(
+ repo_id=LTX_REPO, filename=SPATIAL_UPSCALER_FILENAME, local_dir=models_dir, local_dir_use_symlinks=False
+ )
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
+ logger.info(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
+
+ logger.info("Creating LTX Video pipeline on CPU...")
+ pipeline_instance = create_ltx_video_pipeline(
+ ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
+ precision=PIPELINE_CONFIG_YAML["precision"],
+ text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
+ sampler=PIPELINE_CONFIG_YAML["sampler"],
+ device="cpu",
+ enhance_prompt=False,
+ prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
+ prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
+ )
+ logger.info("LTX Video pipeline created on CPU.")

+ if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
+ logger.info("Creating latent upsampler on CPU...")
+ latent_upsampler_instance = create_latent_upsampler(
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"], device="cpu"
+ )
+ logger.info("Latent upsampler created on CPU.")

+ target_inference_device = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"Moving models to target inference device: {target_inference_device}")
+ pipeline_instance.to(target_inference_device)
+ if latent_upsampler_instance:
+ latent_upsampler_instance.to(target_inference_device)
+ logger.info("Model initialization complete.")


 def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
 input_image_filepath=None, input_video_filepath=None,

 duration_ui=2.0, ui_frames_to_use=9,
 seed_ui=42, randomize_seed=True, ui_guidance_scale=None, improve_texture_flag=True):

+ # Make sure models are loaded
+ initialize_models()
+
+ target_inference_device = "cuda" if torch.cuda.is_available() else "cpu"
+
 if randomize_seed:
 seed_ui = random.randint(0, 2**32 - 1)
 seed_everething(int(seed_ui))

 padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)

 call_kwargs = {
+ "prompt": prompt, "negative_prompt": negative_prompt, "height": height_padded, "width": width_padded,
+ "num_frames": num_frames_padded, "frame_rate": int(FPS),
 "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
+ "output_type": "pt", "conditioning_items": None, "media_items": None,
+ "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"], "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
+ "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"], "image_cond_noise_scale": 0.15,
+ "is_video": True, "vae_per_channel_normalize": True, "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
+ "offload_to_cpu": False, "enhance_prompt": False,
 }

 stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")


 if mode == "image-to-video" and input_image_filepath:
 try:
+ media_tensor = load_image_to_tensor_with_resize_and_crop(input_image_filepath, actual_height, actual_width)
 media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
 call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
 except Exception as e:
+ logger.error(f"Error loading image {input_image_filepath}: {e}")
 raise RuntimeError(f"Could not load image: {e}")
 elif mode == "video-to-video" and input_video_filepath:
 try:
 call_kwargs["media_items"] = load_media_file(
+ media_path=input_video_filepath, height=actual_height, width=actual_width,
+ max_frames=int(ui_frames_to_use), padding=padding_values
 ).to(target_inference_device)
 except Exception as e:
+ logger.error(f"Error loading video {input_video_filepath}: {e}")
 raise RuntimeError(f"Could not load video: {e}")

+ active_latent_upsampler = latent_upsampler_instance if improve_texture_flag and latent_upsampler_instance else None
 result_images_tensor = None
+
 if improve_texture_flag:
 if not active_latent_upsampler:
+ raise RuntimeError("Spatial upscaler model not loaded or improve_texture not selected.")

 multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
+ first_pass_args = {**PIPELINE_CONFIG_YAML.get("first_pass", {}), "guidance_scale": float(ui_guidance_scale)}
+ second_pass_args = {**PIPELINE_CONFIG_YAML.get("second_pass", {}), "guidance_scale": float(ui_guidance_scale)}

+ multi_scale_call_kwargs = {
+ **call_kwargs, "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
+ "first_pass": first_pass_args, "second_pass": second_pass_args
+ }

+ logger.info(f"Calling multi-scale pipeline on {target_inference_device}")
 result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
 else:
+ single_pass_call_kwargs = {**call_kwargs, **PIPELINE_CONFIG_YAML.get("first_pass", {}), "guidance_scale": float(ui_guidance_scale)}
+ logger.info(f"Calling base pipeline on {target_inference_device}")
 result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images

 if result_images_tensor is None:
+ raise RuntimeError("Generation failed, result tensor is None.")

 pad_left, pad_right, pad_top, pad_bottom = padding_values
 slice_h_end = -pad_bottom if pad_bottom > 0 else None
 slice_w_end = -pad_right if pad_right > 0 else None

+ result_images_tensor = result_images_tensor[:, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end]

+ video_np = (result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8)
+
+ # Use a random number so file names rarely collide
 timestamp = random.randint(10000, 99999)
+ output_video_path = os.path.join(output_dir, f"output_{timestamp}_{seed_ui}.mp4")

 try:
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as writer:
+ for frame in video_np:
+ writer.append_data(frame)
+ except Exception:
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264') as writer:
+ for frame in video_np:
+ writer.append_data(frame)

+ logger.info(f"Video saved successfully to: {output_video_path}")
 return output_video_path, seed_ui

+ def run_inference(args):
+ """Process command-line arguments and run AI inference."""
+ logger.info(f"Starting single-run inference...")
+ logger.info(f"Prompt: {args.prompt}")
+ logger.info(f"Mode: {args.mode}")
+ logger.info(f"Duration: {args.duration}s")
+ logger.info(f"Resolution: {args.width}x{args.height}")
+ logger.info(f"Output directory: {os.path.abspath(output_dir)}")

 try:
 output_path, used_seed = generate(
+ prompt=args.prompt, negative_prompt=args.negative_prompt,
+ input_image_filepath=args.input_image, input_video_filepath=args.input_video,
+ height_ui=args.height, width_ui=args.width, mode=args.mode,
+ duration_ui=args.duration, ui_frames_to_use=args.frames_to_use,
+ seed_ui=args.seed, randomize_seed=args.randomize_seed,
+ ui_guidance_scale=args.guidance_scale, improve_texture_flag=not args.no_improve_texture
 )
+ logger.info(f"\n✅ Video generation completed!")
+ logger.info(f"📁 Output saved to: {output_path}")
+ logger.info(f"🎲 Used seed: {used_seed}")

 except Exception as e:
+ logger.error(f"❌ Error during generation: {e}", exc_info=True)
+ sys.exit(1)
+

+ # ==============================================================================
+ # Main entry point and argument parsing
+ # ==============================================================================
 if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="LTX Video Generation and Server Client")

+ # --- Mode selection ---
+ group = parser.add_argument_group('Run mode')
+ group.add_argument("--listen", action="store_true", help="Run in listener mode: connect to the server and wait for commands.")
+
+ # --- Listener-mode arguments ---
+ listener_group = parser.add_argument_group('Listener-mode arguments (require --listen)')
+ listener_group.add_argument("--card-id", help="Card ID used to authenticate with the server.")
+ listener_group.add_argument("--secret", help="Machine Secret used to authenticate with the server.")
+ listener_group.add_argument("--watch-dir", default=output_dir, help=f"Directory to watch for new files to auto-upload (default: {output_dir})")
+
+ # --- Inference-mode arguments ---
+ inference_group = parser.add_argument_group('Inference-mode arguments (default mode)')
+ inference_group.add_argument("--prompt", help="Text prompt for video generation.")
+ inference_group.add_argument("--negative-prompt", default="worst quality, inconsistent motion, blurry, jittery, distorted", help="Negative prompt.")
+ inference_group.add_argument("--mode", choices=["text-to-video", "image-to-video", "video-to-video"], default="text-to-video", help="Generation mode.")
+ inference_group.add_argument("--input-image", help="Input image path (for image-to-video mode).")
+ inference_group.add_argument("--input-video", help="Input video path (for video-to-video mode).")
+ inference_group.add_argument("--duration", type=float, default=2.0, help="Video duration in seconds (0.3-8.5).")
+ inference_group.add_argument("--height", type=int, default=512, help="Video height (will be rounded to a multiple of 32).")
+ inference_group.add_argument("--width", type=int, default=704, help="Video width (will be rounded to a multiple of 32).")
+ inference_group.add_argument("--seed", type=int, default=42, help="Random seed.")
+ inference_group.add_argument("--randomize-seed", action="store_true", help="Use a random seed.")
+ inference_group.add_argument("--guidance-scale", type=float, help="Guidance scale.")
+ inference_group.add_argument("--no-improve-texture", action="store_true", help="Disable texture improvement (faster, but quality may be lower).")
+ inference_group.add_argument("--frames-to-use", type=int, default=9, help="How many frames to use from the input video (for video-to-video).")
+
+ args = parser.parse_args()
+
+ # Dispatch based on the selected mode
+ if args.listen:
+ if not args.card_id or not args.secret:
+ parser.error("--card-id and --secret are required in --listen mode.")
+ logger.info(f"Starting listener mode... Card ID: {args.card_id}, Watch Dir: {args.watch_dir}")
+ try:
+ asyncio.run(start_listener_mode(args.card_id, args.secret, args.watch_dir))
+ except KeyboardInterrupt:
+ logger.info("Listener mode stopped.")
+ else:
+ if not args.prompt:
+ parser.error("--prompt is required in inference mode.")
+
+ # Ensure dimensions are multiples of 32
+ args.height = ((args.height - 1) // 32 + 1) * 32
+ args.width = ((args.width - 1) // 32 + 1) * 32
+
+ run_inference(args)
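
A minimal usage sketch for the two modes defined by the argument parser above; the prompt, card ID, and secret values are placeholders, and the listener invocation blocks until interrupted.

# Hypothetical driver showing both invocation paths of the updated app.py.
# All argument values below are placeholders, not values from the repository.
import subprocess

# Default (inference) mode: generate a short clip from a text prompt.
subprocess.run([
    "python", "app.py",
    "--prompt", "a red fox running through snow",  # placeholder prompt
    "--duration", "2.0",
    "--height", "512", "--width", "704",
    "--randomize-seed",
], check=True)

# Listener mode: connects to the server, waits for commands, and auto-uploads
# new files from the watch directory; runs until interrupted.
subprocess.run([
    "python", "app.py", "--listen",
    "--card-id", "YOUR_CARD_ID",        # placeholder
    "--secret", "YOUR_MACHINE_SECRET",  # placeholder
    "--watch-dir", "output",
], check=True)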