1inkusFace commited on
Commit
7834fa8
·
verified ·
1 Parent(s): a4a2927

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -12,7 +12,7 @@ import asyncio # Import asyncio
12
  os.environ["SAFETENSORS_FAST_GPU"] = "1"
13
  os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
14
 
15
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
 
17
  # Use gr.State to hold the predictor. Initialize it to None.
18
  predictor_state = gr.State(None)
@@ -29,12 +29,13 @@ def init_predictor(task_type: str):
29
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
30
  model_id="Skywork/skyreels-v1-Hunyuan-i2v", # Adjust model ID as needed
31
  quant_model=True,
32
- is_offload=True,
33
  offload_config=OffloadConfig(
34
  high_cpu_memory=True,
35
  parameters_level=True,
36
  ),
37
  use_multiprocessing=False,
 
38
  )
39
  return predictor
40
  except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
@@ -74,15 +75,18 @@ async def generate_video(prompt, image_file, predictor):
74
  }
75
 
76
  try:
77
- kwargs["image"] = load_image(image=image_file.name)
 
 
 
78
  except Exception as e:
79
- return gr.Error(f"image loading error: {e}")
80
 
81
  try:
82
  output = predictor.inference(kwargs)
83
  frames = output
84
  except Exception as e:
85
- return gr.Error(f"Inference error: {e}")
86
 
87
  save_dir = "./result/i2v" # Consistent directory
88
  os.makedirs(save_dir, exist_ok=True)
@@ -92,7 +96,7 @@ async def generate_video(prompt, image_file, predictor):
92
  try:
93
  export_to_video(frames, video_out_file, fps=24)
94
  except Exception as e:
95
- return gr.Error(f"Video export error: {e}")
96
 
97
  return video_out_file, predictor # Return updated predictor
98
 
 
12
  os.environ["SAFETENSORS_FAST_GPU"] = "1"
13
  os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
14
 
15
+ # No longer needed here: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
 
17
  # Use gr.State to hold the predictor. Initialize it to None.
18
  predictor_state = gr.State(None)
 
29
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
30
  model_id="Skywork/skyreels-v1-Hunyuan-i2v", # Adjust model ID as needed
31
  quant_model=True,
32
+ is_offload=True, # Consider removing if you have enough GPU memory
33
  offload_config=OffloadConfig(
34
  high_cpu_memory=True,
35
  parameters_level=True,
36
  ),
37
  use_multiprocessing=False,
38
+ device="cuda:0" if torch.cuda.is_available() else "cpu" # Pass device to the constructor
39
  )
40
  return predictor
41
  except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
 
75
  }
76
 
77
  try:
78
+ # Load the image and move it to the correct device *before* inference
79
+ image = load_image(image=image_file.name)
80
+ # No need to manually move to device. SkyReelsVideoInfer should handle it.
81
+ kwargs["image"] = image
82
  except Exception as e:
83
+ return gr.Error(f"Image loading error: {e}")
84
 
85
  try:
86
  output = predictor.inference(kwargs)
87
  frames = output
88
  except Exception as e:
89
+ return gr.Error(f"Inference error: {e}"), None # Return None for predictor on error
90
 
91
  save_dir = "./result/i2v" # Consistent directory
92
  os.makedirs(save_dir, exist_ok=True)
 
96
  try:
97
  export_to_video(frames, video_out_file, fps=24)
98
  except Exception as e:
99
+ return gr.Error(f"Video export error: {e}"), None # Return None for predictor
100
 
101
  return video_out_file, predictor # Return updated predictor
102