PengWeixuanSZU committed
Commit 49b23e6 · verified · Parent: bc4fdd8

Update app.py

Files changed (1): app.py +8 -5
app.py CHANGED
@@ -28,7 +28,7 @@ import subprocess
 import spaces
 from huggingface_hub import snapshot_download
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 def download_model():
@@ -50,7 +50,6 @@ def get_prompt(file:str):
     a=f.readlines()
     return a #a[0]:positive prompt, a[1] negative prompt
 
-@spaces.GPU(duration=120)
 def init_pipe():
     def unwarp_model(state_dict):
         new_state_dict = {}
@@ -73,7 +72,7 @@ def init_pipe():
     )
 
     text_encoder = T5EncoderModel.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="text_encoder", torch_dtype=torch.float16)
-    vae = AutoencoderKLCogVideoX.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16).to(device)
+    vae = AutoencoderKLCogVideoX.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16)
     tokenizer = T5Tokenizer.from_pretrained(f"./cogvideox-5b-{key}/tokenizer", torch_dtype=torch.float16)
 
 
@@ -104,8 +103,8 @@ def init_pipe():
     transformer.load_state_dict(transformer_state_dict, strict=True)
     controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)
 
-    transformer = transformer.half().to(device)
-    controlnet_transformer = controlnet_transformer.half().to(device)
+    transformer = transformer.half()
+    controlnet_transformer = controlnet_transformer.half()
 
     vae = vae.eval()
     text_encoder = text_encoder.eval()
@@ -134,6 +133,10 @@ def inference(source_images,
               h, w, random_seed)->List[PIL.Image.Image]:
     torch.manual_seed(random_seed)
 
+    pipe.vae.to(DEVICE)
+    pipe.transformer.to(DEVICE)
+    pipe.controlnet_transformer.to(DEVICE)
+
     source_pixel_values = source_images/127.5 - 1.0
     source_pixel_values = source_pixel_values.to(torch.float16).to("cuda:0")
     if target_images is not None:
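
The commit moves device placement out of model loading: init_pipe now keeps every module on the CPU in float16, and the weights are shipped to DEVICE only at the top of inference. This matches the usual pattern for Hugging Face ZeroGPU Spaces, where a CUDA device is attached only while a @spaces.GPU-decorated function runs, so moving tensors to "cuda" at import time is unreliable. A minimal sketch of that pattern follows; model and predict are illustrative placeholders, not names from app.py:

import torch
import spaces  # Hugging Face Spaces helper; provides the spaces.GPU decorator

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Startup: build the model on CPU in half precision (no GPU attached yet).
# torch.nn.Linear stands in for the real transformer/vae modules.
model = torch.nn.Linear(16, 16).half().eval()

@spaces.GPU(duration=120)  # a GPU is attached only while this call runs
def predict(x: torch.Tensor) -> torch.Tensor:
    # Lazy placement, mirroring the pipe.*.to(DEVICE) calls added to inference().
    model.to(DEVICE)
    with torch.no_grad():
        return model(x.to(DEVICE, dtype=torch.float16)).cpu()

A likely side benefit of dropping @spaces.GPU(duration=120) from init_pipe is that downloading and state-dict loading no longer run inside a GPU allocation, which ZeroGPU limits per call; only inference should need one.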