aiqcamp commited on
Commit
9aac7f0
ยท
verified ยท
1 Parent(s): 9d60901

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -43,8 +43,15 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
- # ํ…์ŠคํŠธ ์ธ์ฝ”๋”์˜ dtype ๋ถˆ์ผ์น˜๋ฅผ ๋ง‰๊ธฐ ์œ„ํ•ด float16์œผ๋กœ ๋ณ€ํ™˜
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
 
 
 
 
 
 
 
48
 
49
  def can_expand(source_width, source_height, target_width, target_height, alignment):
50
  """Checks if the image can be expanded based on the alignment."""
 
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
+ # ํ…์ŠคํŠธ ์ธ์ฝ”๋”๋ฅผ float16์œผ๋กœ ๊ฐ•์ œ ๋ณ€ํ™˜
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
+ # ์ถ”๊ฐ€: text_projection์˜ forward๋ฅผ ์˜ค๋ฒ„๋ผ์ด๋”ฉํ•˜์—ฌ ์ž…๋ ฅ์ด float16์ด ์•„๋‹ˆ๋ฉด half๋กœ ์บ์ŠคํŒ…
49
+ original_text_projection_forward = pipe.text_encoder.text_projection.forward
50
+ def fixed_text_projection_forward(x):
51
+ if x.dtype != torch.float16:
52
+ x = x.half()
53
+ return original_text_projection_forward(x)
54
+ pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
55
 
56
  def can_expand(source_width, source_height, target_width, target_height, alignment):
57
  """Checks if the image can be expanded based on the alignment."""