Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
|
|
3 |
from gradio_client import Client, handle_file
|
4 |
import torch
|
5 |
import spaces
|
6 |
-
from diffusers import
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
|
9 |
if torch.cuda.is_available():
|
@@ -20,8 +20,8 @@ def set_client_for_session(request: gr.Request):
|
|
20 |
|
21 |
# Load models
|
22 |
def load_models():
|
23 |
-
pipe =
|
24 |
-
"X-ART/LeX-
|
25 |
torch_dtype=torch.bfloat16
|
26 |
)
|
27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -57,18 +57,16 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
|
|
57 |
print(f"enhanced caption:\n{enhanced_caption}")
|
58 |
|
59 |
generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
|
60 |
-
|
61 |
image = pipe(
|
62 |
enhanced_caption,
|
63 |
height=1024,
|
64 |
width=1024,
|
65 |
-
guidance_scale=
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
generator=generator,
|
71 |
-
system_prompt="You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
|
72 |
).images[0]
|
73 |
|
74 |
print(image)
|
|
|
3 |
from gradio_client import Client, handle_file
|
4 |
import torch
|
5 |
import spaces
|
6 |
+
from diffusers import FluxPipeline
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
|
9 |
if torch.cuda.is_available():
|
|
|
20 |
|
21 |
# Load models
|
22 |
def load_models():
|
23 |
+
pipe = FluxPipeline.from_pretrained(
|
24 |
+
"X-ART/LeX-FLUX",
|
25 |
torch_dtype=torch.bfloat16
|
26 |
)
|
27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
57 |
print(f"enhanced caption:\n{enhanced_caption}")
|
58 |
|
59 |
generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
|
60 |
+
|
61 |
image = pipe(
|
62 |
enhanced_caption,
|
63 |
height=1024,
|
64 |
width=1024,
|
65 |
+
guidance_scale=3.5,
|
66 |
+
output_type="pil",
|
67 |
+
num_inference_steps=28,
|
68 |
+
max_sequence_length=512,
|
69 |
+
generator=torch.Generator("cpu").manual_seed(0)
|
|
|
|
|
70 |
).images[0]
|
71 |
|
72 |
print(image)
|