marvin90's picture
Update app.py
dd8b265 verified
# @spaces.GPU
import spaces
import warnings
warnings.filterwarnings("ignore", message="Can't initialize NVML")
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
# === Base configuration ===
# Base Stable Diffusion checkpoint that the LoRA weights are applied on top of.
# NOTE(review): the runwayml repo was withdrawn from the Hub in 2024; the
# community mirror is "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm
# this id still resolves before relying on it.
BASE_MODEL = "runwayml/stable-diffusion-v1-5"
# LoRA repository that generar_pixelart loads by default.
DEFAULT_LORA = "marvin90/pixelartmaster-lite-faces"
# Pipelines already built, keyed by LoRA repo id (filled lazily by load_model).
MODEL_CACHE = {}
# === Model loading (ZeroGPU or CPU) ===
@spaces.GPU
def load_model(repo_id):
    """Return the pipeline for ``repo_id``, building and caching it on first use.

    The ``BASE_MODEL`` Stable Diffusion pipeline is loaded in float16 on CUDA
    when ``torch.cuda.is_available()`` is true, otherwise in float32 on CPU.
    The LoRA weights at ``repo_id`` are then applied on top. The pipeline is
    stored in ``MODEL_CACHE``, so later calls with the same ``repo_id`` return
    the same object.

    Applying the LoRA is best-effort. On any exception a warning is printed,
    and the plain base pipeline is cached under ``repo_id`` instead. Because it
    is cached, the LoRA load is not retried on later calls.

    Args:
        repo_id: Location of the LoRA weights passed to ``load_lora_weights``
            (in this app, the Hub repo id ``DEFAULT_LORA``).

    Returns:
        The cached ``StableDiffusionPipeline`` for ``repo_id``.
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    dtype = torch.float16 if has_cuda else torch.float32

    # Fast path: this LoRA has already been built.
    if repo_id in MODEL_CACHE:
        return MODEL_CACHE[repo_id]

    print(f"🌐 Cargando modelo en {device}: {repo_id}")
    pipeline = StableDiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=dtype)
    pipeline = pipeline.to(device)

    # Best-effort LoRA: if it fails, keep serving the base model.
    try:
        pipeline.load_lora_weights(repo_id)
        print("✅ LoRA cargado correctamente")
    except Exception as exc:
        print(f"⚠️ No se pudo cargar LoRA: {exc}")

    MODEL_CACHE[repo_id] = pipeline
    return pipeline
# === K-means post-processing ===
def apply_kmeans_pil(image, k=12):
    """Reduce an image to at most ``k`` colours with K-means (pixel-art palette).

    Args:
        image: A PIL image of any mode, or an array-like of shape (H, W),
            (H, W, 3) or (H, W, 4). PIL images are converted to RGB, which
            drops alpha for RGBA input. Array input keeps only the first
            3 channels.
        k: Maximum number of colours in the output palette. It is clamped to
            the number of pixels, so very small images do not make KMeans fail.

    Returns:
        A new RGB ``PIL.Image`` of the same size. Every pixel is replaced by
        the rounded centre of its colour cluster.

    Raises:
        ValueError: If ``k`` is smaller than 1 or the image has no pixels.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    if isinstance(image, Image.Image):
        # Normalises L/P/LA/RGBA/... to exactly 3 channels. For RGBA this
        # simply drops alpha, which is what the previous code did by hand.
        image = image.convert("RGB")
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        # Single-channel array: replicate it into R, G and B.
        pixels = np.stack([pixels] * 3, axis=-1)
    pixels = pixels[:, :, :3]

    height, width = pixels.shape[:2]
    flat_pixels = pixels.reshape(-1, 3)
    if flat_pixels.shape[0] == 0:
        raise ValueError("image has no pixels")

    # KMeans requires n_samples >= n_clusters.
    n_clusters = min(k, flat_pixels.shape[0])
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    labels = kmeans.fit_predict(flat_pixels)

    # Round (astype alone truncates) and clip while values are still floats,
    # so the clip actually guards the uint8 range.
    palette = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)
    clustered_pixels = palette[labels].reshape(height, width, 3)
    return Image.fromarray(clustered_pixels)
# === Main generation ===
@spaces.GPU
def generar_pixelart(prompt):
    """Generate a pixel-art image from ``prompt``.

    Steps:
      1. Stable Diffusion, with the ``DEFAULT_LORA`` LoRA, produces a base image.
      2. The image is downsampled to 128x128 with nearest neighbour.
      3. The colours are reduced to a 12-colour K-means palette.
      4. The result is upscaled to 512x512 with nearest neighbour, so each
         "pixel" becomes a crisp 4x4 block.

    Args:
        prompt: Text prompt from the Gradio textbox.

    Returns:
        The final 512x512 ``PIL.Image``.

    Raises:
        gr.Error: If the prompt is empty or whitespace only. The UI then shows
            a message instead of running diffusion on an empty prompt.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Escribe un prompt para generar la imagen.")

    pipe = load_model(DEFAULT_LORA)
    negative_prompt = (
        "low quality, bad anatomy, bad hands, text, error, "
        "missing fingers, extra digit, fewer digits, "
        "worst quality, blurry, smudged, grainy, noise, "
        "jpeg artifacts, signature, watermark, username"
    )

    # Generate the base image.
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=30,
        guidance_scale=7.5
    ).images[0]

    # Quantise at the low resolution. A nearest-neighbour upscale to 512x512
    # only repeats each pixel 16 times, so K-means on the upscaled image has
    # the same clustering objective but does 16x more work inside this
    # GPU-decorated call. Upscaling after quantising gives the same blocky look.
    small = image.resize((128, 128), Image.NEAREST)
    quantized = apply_kmeans_pil(small, k=12)
    return quantized.resize((512, 512), Image.NEAREST)
# === Gradio interface ===
with gr.Blocks() as demo:
    # Title and description (the user-facing text stays in Spanish).
    gr.Markdown("## 🎮 PixelArtMaster Lite Faces")
    gr.Markdown(
        "Generador ligero de retratos futuristas en pixel art desde texto, "
        "con reducción de colores usando K-means y escalado tipo pixelado. "
        "Modelo LoRA entrenado por [AmericaPixelGames](https://americapixelgames.com)."
    )

    # Input box, result image and trigger button, in display order.
    prompt = gr.Textbox(label="Prompt de entrada")
    output = gr.Image(label="Resultado estilo Pixel Art")
    btn = gr.Button("🎨 Generar")
    # Clicking the button feeds the textbox into generar_pixelart and shows
    # the returned image.
    btn.click(fn=generar_pixelart, inputs=[prompt], outputs=output)

    # Footer.
    gr.Markdown("---")
    gr.Markdown("Desarrollado por **[AmericaPixelGames](https://americapixelgames.com)** 🎮")
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with debug=True.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)