import gradio as gr

from absl import flags
from absl import app
from ml_collections import config_flags
import os

import ml_collections
import torch
from torch import multiprocessing as mp
import torch.nn as nn
import accelerate
import utils
import tempfile
from absl import logging
import builtins
import einops
import math
import numpy as np
import time
from PIL import Image
import random

from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
from tools.clip_score import ClipSocre
import libs.autoencoder
from libs.clip import FrozenCLIPEmbedder
from libs.t5 import T5Embedder

def unpreprocess(x):
    # Map images from the model's [-1, 1] range back to [0, 1] for display.
    x = 0.5 * (x + 1.)
    x.clamp_(0., 1.)
    return x


def batch_decode(_z, decode, batch_size=10):
    """
    The VAE decoder requires a large amount of GPU memory. To run the interpolation
    model on GPUs with 24 GB of memory or less, use this function to reduce the VAE's
    memory usage by splitting the input tensor into smaller chunks and decoding them
    one chunk at a time.
    """
    num_samples = _z.size(0)
    decoded_batches = []

    for i in range(0, num_samples, batch_size):
        batch = _z[i:i + batch_size]
        decoded_batch = decode(batch)
        decoded_batches.append(decoded_batch)

    image_unprocessed = torch.cat(decoded_batches, dim=0)
    return image_unprocessed
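# Hypothetical usage (the names here are assumptions, not part of this file): if
# `autoencoder` is the VAE loaded via libs.autoencoder and `_z` holds the sampled
# latents for every interpolation step, the batch can be decoded chunk by chunk as:
#   samples = batch_decode(_z, autoencoder.decode, batch_size=4)
#   samples = unpreprocess(samples)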

def get_caption(llm, text_model, prompt_dict, batch_size):
    if batch_size == 3:
        # Two prompts plus one blank caption.
        assert len(prompt_dict) == 2
        _batch_con = list(prompt_dict.values()) + [' ']
    elif batch_size == 4:
        # Three prompts plus one blank caption.
        assert len(prompt_dict) == 3
        _batch_con = list(prompt_dict.values()) + [' ']
    elif batch_size >= 5:
        # Two prompts at the ends; the blank captions in between are placeholders
        # for the interpolated samples.
        assert len(prompt_dict) == 2
        _batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size - 2) + [prompt_dict['prompt_2']]

    if llm == "clip":
        _latent, _latent_and_others = text_model.encode(_batch_con)
        _con = _latent_and_others['token_embedding'].detach()
    elif llm == "t5":
        _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
        _con = (_latent_and_others['token_embedding'] * 10.0).detach()
    else:
        raise NotImplementedError

    _con_mask = _latent_and_others['token_mask'].detach()
    _batch_token = _latent_and_others['tokens'].detach()
    _batch_caption = _batch_con
    return (_con, _con_mask, _batch_token, _batch_caption)
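# Hypothetical usage (variable names are assumptions): for a 5-image interpolation
# between two prompts with the CLIP text encoder, one might call
#   _con, _con_mask, _tokens, _captions = get_caption(
#       'clip', clip_text_model,
#       {'prompt_1': 'a red car', 'prompt_2': 'a blue car'}, batch_size=5)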

import spaces
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
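
# NOTE: this file does not show how the generation model is built. The two lines below
# are a placeholder following the standard diffusers setup for the checkpoint named in
# `model_repo_id`; the actual demo presumably constructs the CrossFlow model, the VAE
# from libs.autoencoder, and a CLIP/T5 text encoder instead.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)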

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU
def infer(
    prompt1,
    prompt2,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)
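
    # NOTE: the rest of the original sampling code is not included in this file. The
    # block below is a minimal placeholder that generates a single image from the
    # first prompt with the diffusers pipeline assumed above; the real demo would run
    # CrossFlow's flow-matching sampler over the interpolated text latents instead.
    image = pipe(
        prompt=prompt1,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    # The Gradio event handler below expects both the image and the (possibly
    # randomized) seed.
    return image, seed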

examples = [
    ["A dog cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # CrossFlow")
        gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")

        with gr.Row():
            prompt1 = gr.Text(
                label="Prompt_1",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt for the first image",
                container=False,
            )

        with gr.Row():
            prompt2 = gr.Text(
                label="Prompt_2",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt for the second image",
                container=False,
            )

        with gr.Row():
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )

        gr.Examples(examples=examples, inputs=[prompt1, prompt2])

    gr.on(
        triggers=[run_button.click, prompt1.submit, prompt2.submit],
        fn=infer,
        inputs=[
            prompt1,
            prompt2,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()