Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -10,9 +10,9 @@ from diffusers import EulerDiscreteScheduler
 import gradio as gr
 
 # Download the model files
-
+ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
 
-# Function to load models
+# Function to load models
 def load_models():
     text_encoder = ChatGLMModel.from_pretrained(
         os.path.join(ckpt_dir, 'text_encoder'),
@@ -29,21 +29,20 @@ def load_models():
         unet=unet,
         scheduler=scheduler,
         force_zeros_for_empty_prompt=False
-    )
+    )
 
 # Create a global variable to hold the pipeline
 pipe = load_models()
 
-@spaces.GPU(duration=200)
 def generate_image(prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, progress=gr.Progress(track_tqdm=True)):
     if use_random_seed:
         seed = random.randint(0, 2**32 - 1)
     else:
         seed = int(seed)  # Ensure seed is an integer
 
-    # Move the model to the
+    # Move the model to the CPU for inference and clear unnecessary variables
     with torch.no_grad():
-        generator = torch.Generator(
+        generator = torch.Generator().manual_seed(seed)
         result = pipe(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -58,8 +57,6 @@ def generate_image(prompt, negative_prompt, height, width, num_inference_steps,
 
     return image, seed
 
-
-
 # Gradio interface
 iface = gr.Interface(
     fn=generate_image,
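For context on the first hunk: `snapshot_download` from `huggingface_hub` fetches the full Kwai-Kolors/Kolors repository (or reuses the local cache) and returns its local path, which `load_models()` then indexes by subfolder. A minimal sketch of that pattern; the `text_encoder` subfolder comes from the diff, the rest is illustrative:

import os
from huggingface_hub import snapshot_download

# Download the model repo once at import time (cached after the first run);
# snapshot_download returns the local directory containing all model files.
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")

# Components are then loaded from subfolders of that directory, e.g.:
text_encoder_dir = os.path.join(ckpt_dir, "text_encoder")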
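The second hunk also removes the `@spaces.GPU(duration=200)` decorator from `generate_image`. For reference, that decorator is the `spaces` package's ZeroGPU hook: on ZeroGPU hardware, CUDA is only available inside functions wrapped this way, so dropping it changes where (and whether) the function can reach a GPU. A sketch of the removed pattern, with the body elided:

import spaces

@spaces.GPU(duration=200)  # request a ZeroGPU slot, here for up to 200 s per call
def generate_image(prompt, negative_prompt, *args, **kwargs):
    ...  # inference runs while the GPU allocation is held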
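The substantive fix is the generator line: the previous code (shown truncated in the diff view above) is replaced by `torch.Generator().manual_seed(seed)`, a seeded CPU generator handed to the pipeline. A minimal sketch of the seed handling as the function now implements it, assuming a diffusers-style `pipe` that accepts a `generator` keyword; the `make_generator` helper is hypothetical:

import random
import torch

def make_generator(use_random_seed, seed):
    # Draw a fresh 32-bit seed on demand, otherwise honor the caller's value.
    if use_random_seed:
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed)  # ensure the seed is an integer
    # A CPU generator is fine even when the pipeline runs on GPU: diffusers
    # uses it to draw the initial latents, keeping runs reproducible.
    generator = torch.Generator().manual_seed(seed)
    return generator, seed

# Hypothetical usage mirroring the diff:
# generator, seed = make_generator(use_random_seed, seed)
# result = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=generator)

Returning the seed alongside the image, as the app's `return image, seed` does, lets the UI show which seed produced a result so it can be reused for an identical run.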