chenlei committed on
Commit
f8de692
·
1 Parent(s): 8c19fbe
ootd/inference_ootd_dc.py CHANGED
@@ -23,6 +23,7 @@ import torch.nn as nn
23
  import torch.nn.functional as F
24
  from transformers import AutoProcessor, CLIPVisionModelWithProjection
25
  from transformers import CLIPTextModel, CLIPTokenizer
 
26
 
27
  VIT_PATH = "openai/clip-vit-large-patch14"
28
  VAE_PATH = "levihsu/OOTDiffusion"
@@ -36,23 +37,32 @@ class OOTDiffusionDC:
36
 
37
  vae = AutoencoderKL.from_pretrained(
38
  VAE_PATH,
39
- subfolder="vae",
40
  torch_dtype=torch.float16,
41
  )
42
 
43
  unet_garm = UNetGarm2DConditionModel.from_pretrained(
44
  UNET_PATH,
45
- subfolder="unet_garm",
46
  torch_dtype=torch.float16,
47
  use_safetensors=True,
48
  )
49
  unet_vton = UNetVton2DConditionModel.from_pretrained(
50
  UNET_PATH,
51
- subfolder="unet_vton",
52
  torch_dtype=torch.float16,
53
  use_safetensors=True,
54
  )
55
 
 
 
 
 
 
 
 
 
 
56
  self.pipe = OotdPipeline.from_pretrained(
57
  MODEL_PATH,
58
  unet_garm=unet_garm,
@@ -75,8 +85,8 @@ class OOTDiffusionDC:
75
  subfolder="tokenizer",
76
  )
77
  self.text_encoder = CLIPTextModel.from_pretrained(
78
- MODEL_PATH,
79
- subfolder="text_encoder",
80
  ).to(self.gpu_id)
81
 
82
 
 
23
  import torch.nn.functional as F
24
  from transformers import AutoProcessor, CLIPVisionModelWithProjection
25
  from transformers import CLIPTextModel, CLIPTokenizer
26
+ import requests
27
 
28
  VIT_PATH = "openai/clip-vit-large-patch14"
29
  VAE_PATH = "levihsu/OOTDiffusion"
 
37
 
38
  vae = AutoencoderKL.from_pretrained(
39
  VAE_PATH,
40
+ subfolder="checkpoints/ootd/vae",
41
  torch_dtype=torch.float16,
42
  )
43
 
44
  unet_garm = UNetGarm2DConditionModel.from_pretrained(
45
  UNET_PATH,
46
+ subfolder="checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm",
47
  torch_dtype=torch.float16,
48
  use_safetensors=True,
49
  )
50
  unet_vton = UNetVton2DConditionModel.from_pretrained(
51
  UNET_PATH,
52
+ subfolder="checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton",
53
  torch_dtype=torch.float16,
54
  use_safetensors=True,
55
  )
56
 
57
+ #判断文件是否存在
58
+ filePath = "/home/user/app/checkpoints/ootd/text_encoder/pytorch_model.bin"
59
+ if os.path.exists(filePath) == False:
60
+ url = "https://huggingface.co/yangjoe/pytorch_model/resolve/main/pytorch_model.bin"
61
+ response = requests.get(url)
62
+ #下载该文件
63
+ with open("/home/user/app/checkpoints/ootd/text_encoder/pytorch_model.bin", "wb") as f:
64
+ f.write(response.content)
65
+
66
  self.pipe = OotdPipeline.from_pretrained(
67
  MODEL_PATH,
68
  unet_garm=unet_garm,
 
85
  subfolder="tokenizer",
86
  )
87
  self.text_encoder = CLIPTextModel.from_pretrained(
88
+ 'yangjoe/pytorch_model',
89
+ #subfolder="text_encoder",
90
  ).to(self.gpu_id)
91
 
92
 
ootd/inference_ootd_hd.py CHANGED
@@ -54,11 +54,14 @@ class OOTDiffusionHD:
54
  torch_dtype=torch.float16,
55
  use_safetensors=True,
56
  )
57
- url = "https://huggingface.co/yangjoe/pytorch_model/resolve/main/pytorch_model.bin"
58
- response = requests.get(url)
59
- #下载该文件
60
- with open("/home/user/app/checkpoints/ootd/text_encoder/pytorch_model.bin", "wb") as f:
61
- f.write(response.content)
 
 
 
62
 
63
  self.pipe = OotdPipeline.from_pretrained(
64
  MODEL_PATH,
 
54
  torch_dtype=torch.float16,
55
  use_safetensors=True,
56
  )
57
+ #判断文件是否存在
58
+ filePath = "/home/user/app/checkpoints/ootd/text_encoder/pytorch_model.bin"
59
+ if os.path.exists(filePath) == False:
60
+ url = "https://huggingface.co/yangjoe/pytorch_model/resolve/main/pytorch_model.bin"
61
+ response = requests.get(url)
62
+ #下载该文件
63
+ with open("/home/user/app/checkpoints/ootd/text_encoder/pytorch_model.bin", "wb") as f:
64
+ f.write(response.content)
65
 
66
  self.pipe = OotdPipeline.from_pretrained(
67
  MODEL_PATH,