wifix199 commited on
Commit
63cbc5c
·
verified ·
1 Parent(s): 69e0906

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -6,12 +6,16 @@ model_id = "SG161222/RealVisXL_V4.0"
6
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
7
  pipe.to("cpu") # Use "cuda" if GPU is available
8
 
9
- def generate_image(prompt):
10
- image = pipe(prompt).images[0]
 
 
 
11
  return image
 
12
  def chatbot(prompt):
13
  # Generate the image based on the user's input
14
- image = generate_image(prompt)
15
  return image
16
 
17
  def get_aug_embed(self, text_embeds, image):
 
6
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
7
  pipe.to("cpu") # Use "cuda" if GPU is available
8
 
9
+ unet = pipe.unet
10
+
11
+ def generate_image(prompt, unet):
12
+ added_cond_kwargs = {"text_embeds": pipe.get_text_embedding(prompt)}
13
+ image = unet(prompt, **added_cond_kwargs).images[0]
14
  return image
15
+
16
  def chatbot(prompt):
17
  # Generate the image based on the user's input
18
+ image = generate_image(prompt, unet)
19
  return image
20
 
21
  def get_aug_embed(self, text_embeds, image):