Simple Diffusion XS

XS Size, Excess Quality

At AiArtLab, we strive to create a free, compact (1.7B), and fast (~3 s per image) model that can be trained on consumer graphics cards.

  • We use U-Net for its high efficiency.
  • We have chosen the multilingual/multimodal encoder Mexma-SigLIP, which supports 80 languages.
  • We use the AuraDiffusion 16ch-VAE architecture, which preserves details and anatomy.
  • The model was trained for about one month on 4× A5000 GPUs, on approximately 1 million images of various resolutions and styles, including anime and realistic photos.

Model Limitations:

  • Limited concept coverage due to the small dataset.
  • The Image2Image functionality requires further training.

Acknowledgments

  • Stan — Key investor. Thank you for believing in us when others called it madness.
  • Captainsaturnus
  • Love. Death. Transformers.

Datasets

Training budget

About $1k spent so far; the total research budget is about $10k.

Donations

Please contact us if you can provide GPUs or funding for training.

DOGE: DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83

BTC: 3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN

Contacts

recoilme

Training status (in progress): wandb

result

Example

import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from tqdm.auto import tqdm
import os

def _encode_text(texts, device, dtype):
    """Tokenize ``texts`` and embed them with the module-level text encoder.

    Relies on the module-level names ``tokenizer`` and ``text_model`` (created in
    the ``__main__`` section of this file). Sequences are padded/truncated to
    512 tokens.

    Returns:
        The encoder output moved to ``device``/``dtype``. If the encoder returns
        a 2-D tensor (one vector per text), a length-1 axis is inserted at dim 1
        so the result has a sequence dimension for the UNet conditioning input.
    """
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        max_length=512,
        truncation=True,
    ).to(device)
    embeddings = text_model.encode_texts(inputs.input_ids, inputs.attention_mask)
    if embeddings.ndim == 2:
        embeddings = embeddings.unsqueeze(1)
    return embeddings.to(device, dtype=dtype)


def encode_prompt(prompt, negative_prompt, device, dtype):
    """Build text embeddings for classifier-free guidance.

    Args:
        prompt: A string, or a list/tuple of strings for a batch.
        negative_prompt: A string, a list/tuple of strings with the same length
            as a list ``prompt``, or ``None`` (treated as ``""``). A single string
            is reused for every entry of a list ``prompt`` so that both halves of
            the output have the same batch size.
        device: Device the embeddings are placed on.
        dtype: Dtype the embeddings are cast to.

    Returns:
        ``torch.cat([negative, positive], dim=0)``: the negative (unconditional)
        half comes first, matching the ``chunk(2)`` order used when applying
        guidance in ``generate_latents``.

    Raises:
        ValueError: If ``prompt`` and ``negative_prompt`` are both sequences of
            different lengths.
    """
    if negative_prompt is None:
        negative_prompt = ""

    if isinstance(prompt, (list, tuple)):
        prompt = list(prompt)
        if isinstance(negative_prompt, str):
            # Broadcast one negative prompt across the whole batch; otherwise the
            # negative half would have batch 1 and the halves would not line up.
            negative_prompt = [negative_prompt] * len(prompt)
        else:
            negative_prompt = list(negative_prompt)
            if len(negative_prompt) != len(prompt):
                raise ValueError(
                    f"negative_prompt has {len(negative_prompt)} entries but "
                    f"prompt has {len(prompt)}"
                )

    with torch.no_grad():
        positive_embeddings = _encode_text(prompt, device, dtype)
        negative_embeddings = _encode_text(negative_prompt, device, dtype)
    return torch.cat([negative_embeddings, positive_embeddings], dim=0)

def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5, generator=None):
    """Run the classifier-free-guided denoising loop and return final latents.

    Uses the module-level ``unet`` and ``scheduler`` (created in the
    ``__main__`` section of this file).

    Args:
        embeddings: Conditioning tensor laid out as ``[negative..., positive...]``
            along dim 0 (as returned by ``encode_prompt``), so its batch size
            is ``2 * B`` for ``B`` images.
        height, width: Output image size in pixels. The latent grid is
            ``height // 8`` x ``width // 8`` with 16 channels; sizes that are not
            multiples of 8 are floored.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance weight.
        generator: Optional ``torch.Generator`` for reproducible initial noise.
            It must live on the same device as ``embeddings``.

    Returns:
        Latents of shape ``(B, 16, height // 8, width // 8)``.

    Raises:
        ValueError: If the batch size of ``embeddings`` is odd.
    """
    if embeddings.shape[0] % 2 != 0:
        raise ValueError(
            "embeddings must contain a negative and a positive half "
            f"(even batch size), got batch size {embeddings.shape[0]}"
        )

    with torch.no_grad():
        device, dtype = embeddings.device, embeddings.dtype
        batch_size = embeddings.shape[0] // 2
        latent_shape = (batch_size, 16, height // 8, width // 8)
        latents = torch.randn(latent_shape, generator=generator, device=device, dtype=dtype)
        # Schedulers expect the starting noise scaled by init_noise_sigma
        # (1.0 for DDPM, larger for e.g. Euler/LMS).
        latents = latents * scheduler.init_noise_sigma

        scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(scheduler.timesteps, desc="Генерация"):
            # Duplicate the latents so one UNet call evaluates both halves.
            # `embeddings` already holds 2 * B rows in the matching
            # [negative, positive] order, so no extra repetition is needed.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def decode_latents(latents, vae, output_type="pil"):
    """Decode latents with ``vae`` and convert them to images.

    Undoes the latent normalization (``latents / scaling_factor + shift_factor``),
    decodes, and maps the decoder output from [-1, 1] to [0, 1] with clamping.

    Args:
        latents: Latent tensor, presumably ``(batch, channels, h, w)``. TODO:
            confirm for VAEs other than the one used here.
        vae: Object exposing ``config.scaling_factor``, an optional
            ``config.shift_factor`` (missing or ``None`` is treated as 0), and
            ``decode(latents)`` returning an object with a ``.sample`` tensor.
        output_type: ``"pil"`` returns a list of ``PIL.Image`` objects. Any other
            value returns a float32 NumPy array of shape
            ``(batch, height, width, channels)`` with values in [0, 1].

    Returns:
        The decoded images in the requested format.
    """
    config = vae.config
    # Some AutoencoderKL configs define no shift (shift_factor=None); adding
    # None would raise TypeError, so treat it as "no shift".
    shift_factor = getattr(config, "shift_factor", None)

    with torch.no_grad():
        latents = latents / config.scaling_factor
        if shift_factor is not None:
            latents = latents + shift_factor
        images = vae.decode(latents).sample

    images = (images / 2 + 0.5).clamp(0, 1)
    # Channels-last float32 for NumPy/PIL consumption.
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    if output_type == "pil":
        images = (images * 255).round().astype("uint8")
        images = [Image.fromarray(image) for image in images]
    return images

# Example usage:
if __name__ == "__main__":
    from pathlib import Path

    # Fall back to CPU instead of crashing on `.to("cuda")` when no GPU is
    # present; use float32 there because float16 support on CPU is limited.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    prompt = "girl"
    negative_prompt = "bad quality"

    # `encode_prompt` reads `tokenizer` and `text_model` as module-level globals.
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
    text_model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
    ).to(device, dtype=dtype).eval()

    embeddings = encode_prompt(prompt, negative_prompt, device, dtype)

    # The text encoder is not used after this point; release it so its weights
    # do not occupy GPU memory while the UNet and VAE run.
    del text_model
    if device == "cuda":
        torch.cuda.empty_cache()

    pipeid = "AiArtLab/sdxs"
    variant = "fp16"

    # `generate_latents` reads `unet` and `scheduler` as module-level globals.
    unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
    vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
    scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")

    height, width = 576, 384
    num_inference_steps = 40
    output_folder, project_name = "samples", "sdxs"
    latents = generate_latents(
        embeddings=embeddings,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
    )

    images = decode_latents(latents, vae)

    output_dir = Path(output_folder)
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, image in enumerate(images):
        image.save(output_dir / f"{project_name}_{idx}.jpg")

    print("Images generated and saved to:", output_folder)
Downloads last month
53
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support