# Text2img / app.py
# Source: Hugging Face Space file view (raw / history / blame).
# Uploaded by gaur3009 -- commit 5e4062b "Create app.py" (verified), 2.79 kB.
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from io import BytesIO
import os
import requests
import time
from tqdm import tqdm
# Load local Stable Diffusion XL model
model_path = "networks/TShirtDesignRedmondV2-Tshirtdesign-TshirtDesignAF.safetensors"

# Pick the device at runtime instead of hard-coding "cuda", which raises on
# hosts without a GPU. Half precision is kept for the GPU path only; on CPU the
# weights are loaded as float32 (NOTE(review): float16 inference on CPU is
# generally unsupported/unreliable -- confirm against the diffusers docs).
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    pipe = StableDiffusionXLPipeline.from_single_file(
        model_path,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_safetensors=True,
    )
    pipe = pipe.to(device)
except Exception as load_error:
    # Top-level boundary: a missing checkpoint or a failed device transfer must
    # not prevent the app from starting. `infer` catches any exception raised
    # while calling `pipe` and switches to the hosted API, so leaving `pipe`
    # as None routes every request to that fallback.
    print("Could not load the local model; only the API fallback is available.", str(load_error))
    pipe = None

# Hosted model used by the API fallback in `infer`.
repo = "artificialguybr/TshirtDesignRedmond-V2"
api_url = f"https://api-inference.huggingface.co/models/{repo}"
# Upper bound, in seconds, on a single HTTP round-trip to the Inference API.
_API_REQUEST_TIMEOUT_S = 120
# Stop waiting for a cold model (HTTP 503) after this many seconds overall.
_API_MODEL_LOAD_DEADLINE_S = 600
# Number of HTTP 500 responses that are retried before giving up.
_API_MAX_SERVER_ERRORS = 5


def _generate_via_api(prompt: str) -> "Image.Image":
    """Request an image for *prompt* from the hosted API at ``api_url``.

    Retries once per second while the API answers 503 (until
    ``_API_MODEL_LOAD_DEADLINE_S`` has elapsed) and retries up to
    ``_API_MAX_SERVER_ERRORS`` responses with status 500.

    Raises:
        RuntimeError: On any other status code, or once the retry budgets
            above are exhausted.
        requests.RequestException: If the HTTP request itself fails or
            exceeds ``_API_REQUEST_TIMEOUT_S``.
    """
    headers = {}
    # Send a bearer token only when one is configured; without the variable the
    # request is unauthenticated, exactly as before.
    token = os.environ.get("HF_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"

    payload = {
        "inputs": prompt,
        "parameters": {
            "negative_prompt": "(worst quality, low quality, lowres, bad details, watermark, text, blurry, cartoon, 3D, bad anatomy, outdated fashion, cheap look, unreal details, unwanted features)",
            "num_inference_steps": 30,
            "scheduler": "DPMSolverMultistepScheduler",
        },
    }

    error_count = 0
    deadline = time.monotonic() + _API_MODEL_LOAD_DEADLINE_S
    # The context manager closes the progress bar on every exit path.
    with tqdm(total=None, desc="Loading model") as pbar:
        while True:
            # Without a timeout a stalled connection would block this request forever.
            response = requests.post(
                api_url,
                headers=headers,
                json=payload,
                timeout=_API_REQUEST_TIMEOUT_S,
            )
            status = response.status_code
            if status == 200:
                return Image.open(BytesIO(response.content))
            if status == 503 and time.monotonic() < deadline:
                # 503 is treated as "model still loading": poll again shortly.
                time.sleep(1)
                pbar.update(1)
            elif status == 500 and error_count < _API_MAX_SERVER_ERRORS:
                time.sleep(1)
                error_count += 1
            else:
                raise RuntimeError(f"API Error: {response.status_code}")


def infer(color_prompt: str, dress_type_prompt: str, design_prompt: str, text: str) -> "Image.Image":
    """Generate a garment design image from the given prompt fragments.

    The local pipeline ``pipe`` is tried first; if calling it raises, the
    same prompt is sent to the hosted API instead.

    Args:
        color_prompt: Colour of the garment, inserted into the prompt.
        dress_type_prompt: Garment type, inserted into the prompt twice.
        design_prompt: Description of the printed design.
        text: Accepted for interface compatibility but not used.
            NOTE(review): the UI labels this "Text (Optional)", yet it never
            reaches the prompt and the API negative prompt suppresses text --
            confirm the intended behavior.

    Returns:
        The generated image.

    Raises:
        RuntimeError: If the API fallback fails (see ``_generate_via_api``).
    """
    prompt = (
        f"A single {color_prompt} colored {dress_type_prompt} featuring a bold {design_prompt} design printed on the {dress_type_prompt},"
        " hanging on a plain wall. The soft light and shadows create a striking contrast against the minimal background, evoking modern sophistication."
    )
    print("Generating image locally with prompt:", prompt)
    try:
        return pipe(prompt).images[0]
    except Exception as e:
        # Deliberately broad: any local failure (missing model, out of memory,
        # ...) should degrade to the hosted API rather than fail the request.
        print("Local generation failed. Switching to API.", str(e))

    # API fallback
    return _generate_via_api(prompt)
# Gradio UI: one textbox per positional argument of `infer`, in call order.
_prompt_boxes = [
    gr.Textbox(lines=1, placeholder="Color"),
    gr.Textbox(lines=1, placeholder="Dress Type"),
    gr.Textbox(lines=2, placeholder="Design"),
    gr.Textbox(lines=1, placeholder="Text (Optional)"),
]

iface = gr.Interface(
    fn=infer,
    inputs=_prompt_boxes,
    outputs="image",
    examples=[["Red", "T-shirt", "Minimalistic logo", "Brand Name"]],
    title="AI-Generated T-Shirt Designs",
    description="Generate custom t-shirt designs using AI!",
)

# Start serving the interface.
print("Launching Gradio interface...")
iface.launch()