aiqcamp commited on
Commit
9aac7f0
·
verified ·
1 Parent(s): 9d60901

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -43,8 +43,15 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
- # 텍스트 인코더의 dtype 불일치를 막기 위해 float16으로 변환
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
 
 
 
 
 
 
 
48
 
49
  def can_expand(source_width, source_height, target_width, target_height, alignment):
50
  """Checks if the image can be expanded based on the alignment."""
 
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
+ # 텍스트 인코더를 float16으로 강제 변환
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
+ # 추가: text_projection의 forward를 오버라이딩하여 입력이 float16이 아니면 half로 캐스팅
49
+ original_text_projection_forward = pipe.text_encoder.text_projection.forward
50
+ def fixed_text_projection_forward(x):
51
+ if x.dtype != torch.float16:
52
+ x = x.half()
53
+ return original_text_projection_forward(x)
54
+ pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
55
 
56
  def can_expand(source_width, source_height, target_width, target_height, alignment):
57
  """Checks if the image can be expanded based on the alignment."""