File size: 2,789 Bytes
5e4062b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from io import BytesIO
import os
import requests
import time
from tqdm import tqdm
# Load the local Stable Diffusion XL checkpoint used by infer().
# If it cannot be loaded (no CUDA device, missing/invalid file, ...) ``pipe`` is
# left as None so the app still starts and infer() uses the hosted API instead.
# NOTE(review): if this .safetensors file is a LoRA rather than a full SDXL
# checkpoint, from_single_file will fail; in that case load a base SDXL model
# and call pipe.load_lora_weights(model_path) instead — confirm.
model_path = "networks/TShirtDesignRedmondV2-Tshirtdesign-TshirtDesignAF.safetensors"
pipe = None
if torch.cuda.is_available():
    try:
        pipe = StableDiffusionXLPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float16,  # half precision; only sensible on GPU
            use_safetensors=True,
        ).to("cuda")
    except Exception as e:  # deliberate boundary: any load failure -> API-only mode
        print("Failed to load local model; using API fallback only.", str(e))
else:
    print("CUDA is not available; local generation disabled, using API fallback only.")

# Hosted model queried through the Hugging Face Inference API when local
# generation is unavailable or fails.
repo = "artificialguybr/TshirtDesignRedmond-V2"
api_url = f"https://api-inference.huggingface.co/models/{repo}"
# Terms for the negative prompt. They are shared by the local pipeline and the API
# call so both generation paths are steered the same way.
_NEGATIVE_TERMS = (
    "worst quality", "low quality", "lowres", "bad details", "watermark", "text",
    "blurry", "cartoon", "3D", "bad anatomy", "outdated fashion", "cheap look",
    "unreal details", "unwanted features",
)


def _build_prompt(color_prompt, dress_type_prompt, design_prompt, text):
    """Return the positive prompt describing one garment with a printed design.

    ``text`` is optional. When it is non-blank, it is appended as an extra
    sentence asking for that lettering in the design.
    """
    prompt = (
        f"A single {color_prompt} colored {dress_type_prompt} featuring a bold {design_prompt} design printed on the {dress_type_prompt},"
        " hanging on a plain wall. The soft light and shadows create a striking contrast against the minimal background, evoking modern sophistication."
    )
    lettering = (text or "").strip()
    if lettering:
        prompt += f' The design includes the text "{lettering}".'
    return prompt


def _build_negative_prompt(allow_text):
    """Return the parenthesised negative prompt.

    When ``allow_text`` is true, the "text" term is dropped so the negative
    prompt does not suppress lettering the user explicitly asked for.
    """
    terms = [term for term in _NEGATIVE_TERMS if not (allow_text and term == "text")]
    return "(" + ", ".join(terms) + ")"


def _generate_via_api(prompt, negative_prompt, max_wait_s=300, max_server_errors=5, timeout_s=120):
    """POST the prompt to the Hugging Face Inference API at ``api_url``; return the image.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        max_wait_s: Total seconds to keep retrying while the API answers 503
            (the progress bar labels this "Loading model").
        max_server_errors: Number of HTTP 500 responses to retry before giving up.
        timeout_s: Per-request timeout passed to ``requests.post``.

    Returns:
        A ``PIL.Image.Image`` decoded from the 200 response body.

    Raises:
        RuntimeError: on any other status, or once the retry budgets are exhausted.
        requests.RequestException: on network errors or request timeout.
    """
    headers = {}
    # Send a token when one is configured; anonymous calls may be rate-limited or refused.
    token = os.environ.get("HF_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    payload = {
        "inputs": prompt,
        "parameters": {
            "negative_prompt": negative_prompt,
            "num_inference_steps": 30,
            "scheduler": "DPMSolverMultistepScheduler",
        },
    }
    server_errors = 0
    deadline = time.monotonic() + max_wait_s
    # The context manager guarantees the progress bar is closed on every exit path.
    with tqdm(total=None, desc="Loading model") as pbar:
        while True:
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout_s)
            status = response.status_code
            if status == 200:
                return Image.open(BytesIO(response.content))
            if status == 503 and time.monotonic() < deadline:
                time.sleep(1)
                pbar.update(1)
            elif status == 500 and server_errors < max_server_errors:
                time.sleep(1)
                server_errors += 1
            else:
                raise RuntimeError(f"API Error: {status} {response.text[:200]}")


def infer(color_prompt, dress_type_prompt, design_prompt, text):
    """Generate one garment design image for the Gradio UI.

    Uses the local ``pipe`` when it was loaded. If ``pipe`` is None or raises,
    the Hugging Face Inference API at ``api_url`` is used instead.

    Args:
        color_prompt: Garment colour, e.g. "Red".
        dress_type_prompt: Garment type, e.g. "T-shirt".
        design_prompt: Description of the printed design.
        text: Optional lettering for the design; blank or None means no lettering.

    Returns:
        The first generated image.

    Raises:
        RuntimeError: if the API fallback gets a non-recoverable response.
        requests.RequestException: on network failure or timeout in the API fallback.
    """
    prompt = _build_prompt(color_prompt, dress_type_prompt, design_prompt, text)
    negative_prompt = _build_negative_prompt(allow_text=bool((text or "").strip()))
    if pipe is not None:
        print("Generating image locally with prompt:", prompt)
        try:
            return pipe(prompt, negative_prompt=negative_prompt).images[0]
        except Exception as e:  # deliberate boundary: any local failure -> API fallback
            print("Local generation failed. Switching to API.", str(e))
    else:
        print("Local model not loaded. Switching to API.")
    return _generate_via_api(prompt, negative_prompt)
# Gradio UI. The four text boxes map positionally onto
# infer(color_prompt, dress_type_prompt, design_prompt, text), and the image
# that infer() returns is displayed.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(lines=1, label="Color", placeholder="e.g. Red"),
        gr.Textbox(lines=1, label="Dress Type", placeholder="e.g. T-shirt"),
        gr.Textbox(lines=2, label="Design", placeholder="e.g. Minimalistic logo"),
        gr.Textbox(lines=1, label="Text (Optional)", placeholder="Leave blank for no lettering"),
    ],
    outputs=gr.Image(label="Generated Design"),
    title="AI-Generated T-Shirt Designs",
    description="Generate custom t-shirt designs using AI!",
    examples=[["Red", "T-shirt", "Minimalistic logo", "Brand Name"]],
)
print("Launching Gradio interface...")
iface.launch()