yichenchenchen committed on
Commit
2318fdc
·
verified ·
1 Parent(s): e393842

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +5 -5
inferencer.py CHANGED
@@ -49,20 +49,20 @@ class UniPicV2Inferencer:
49
  else:
50
  transformer = SD3Transformer2DKontextModel.from_pretrained(
51
  self.model_path, subfolder="transformer",
52
- torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True
53
  )
54
 
55
  # ===== 3. Load VAE =====
56
  vae = AutoencoderKL.from_pretrained(
57
  self.model_path, subfolder="vae",
58
- torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True
59
  ).to(self.device)
60
 
61
  # ===== 4. Load Qwen2.5-VL (LMM) =====
62
  try:
63
  self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
64
  self.qwen_vl_path,
65
- torch_dtype=torch.float16,
66
  attn_implementation="flash_attention_2",
67
  device_map="auto",
68
  ).to(self.device)
@@ -70,7 +70,7 @@ class UniPicV2Inferencer:
70
  except Exception:
71
  self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
72
  self.qwen_vl_path,
73
- torch_dtype=torch.float16,
74
  attn_implementation="sdpa",
75
  device_map="auto",
76
  ).to(self.device)
@@ -87,7 +87,7 @@ class UniPicV2Inferencer:
87
  # ===== 6. Load Conditioner =====
88
  self.conditioner = StableDiffusion3Conditioner.from_pretrained(
89
  self.model_path, subfolder="conditioner",
90
- torch_dtype=torch.float16, low_cpu_mem_usage=True
91
  ).to(self.device)
92
 
93
  # ===== 7. Load Scheduler =====
 
49
  else:
50
  transformer = SD3Transformer2DKontextModel.from_pretrained(
51
  self.model_path, subfolder="transformer",
52
+ torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
53
  )
54
 
55
  # ===== 3. Load VAE =====
56
  vae = AutoencoderKL.from_pretrained(
57
  self.model_path, subfolder="vae",
58
+ torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
59
  ).to(self.device)
60
 
61
  # ===== 4. Load Qwen2.5-VL (LMM) =====
62
  try:
63
  self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
64
  self.qwen_vl_path,
65
+ torch_dtype=torch.bfloat16,
66
  attn_implementation="flash_attention_2",
67
  device_map="auto",
68
  ).to(self.device)
 
70
  except Exception:
71
  self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
72
  self.qwen_vl_path,
73
+ torch_dtype=torch.bfloat16,
74
  attn_implementation="sdpa",
75
  device_map="auto",
76
  ).to(self.device)
 
87
  # ===== 6. Load Conditioner =====
88
  self.conditioner = StableDiffusion3Conditioner.from_pretrained(
89
  self.model_path, subfolder="conditioner",
90
+ torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
91
  ).to(self.device)
92
 
93
  # ===== 7. Load Scheduler =====