File size: 2,789 Bytes
5e4062b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from io import BytesIO
import os
import requests
import time
from tqdm import tqdm

# Load local Stable Diffusion XL model from a single-file checkpoint.
model_path = "networks/TShirtDesignRedmondV2-Tshirtdesign-TshirtDesignAF.safetensors"

# Fix: pick device/dtype at runtime instead of unconditionally calling
# .to("cuda"), which raises on CPU-only machines. float16 is poorly
# supported on CPU, so fall back to float32 there.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionXLPipeline.from_single_file(
    model_path,
    torch_dtype=dtype,
    use_safetensors=True,
)
pipe = pipe.to(device)

# Hugging Face Inference API endpoint used as a fallback when local
# generation fails (see infer()).
repo = "artificialguybr/TshirtDesignRedmond-V2"
api_url = f"https://api-inference.huggingface.co/models/{repo}"

def infer(color_prompt, dress_type_prompt, design_prompt, text):
    """Generate a t-shirt design image.

    Tries the local SDXL pipeline first; on any failure falls back to the
    Hugging Face Inference API for the same model.

    Args:
        color_prompt: Garment color, e.g. "Red".
        dress_type_prompt: Garment type, e.g. "T-shirt".
        design_prompt: Description of the printed design.
        text: Optional text to appear on the design ("" to omit).

    Returns:
        A PIL.Image.Image with the generated design.

    Raises:
        Exception: if the API fallback exhausts its retries or returns an
            unrecoverable status code.
    """
    prompt = (
        f"A single {color_prompt} colored {dress_type_prompt} featuring a bold {design_prompt} design printed on the {dress_type_prompt},"
        " hanging on a plain wall. The soft light and shadows create a striking contrast against the minimal background, evoking modern sophistication."
    )
    # Fix: the `text` input was previously accepted but silently ignored.
    if text:
        prompt += f' The design includes the text "{text}".'

    print("Generating image locally with prompt:", prompt)
    try:
        image = pipe(prompt).images[0]
        return image
    except Exception as e:
        print("Local generation failed. Switching to API.", str(e))

        # API fallback. The HF Inference API generally requires a Bearer
        # token; read it from the environment instead of sending no auth.
        token = os.environ.get("HF_TOKEN", "")
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        payload = {
            "inputs": prompt,
            "parameters": {
                "negative_prompt": "(worst quality, low quality, lowres, bad details, watermark, text, blurry, cartoon, 3D, bad anatomy, outdated fashion, cheap look, unreal details, unwanted features)",
                "num_inference_steps": 30,
                "scheduler": "DPMSolverMultistepScheduler"
            },
        }

        error_count = 0  # consecutive HTTP 500s tolerated before giving up
        wait_count = 0   # HTTP 503 "model loading" polls performed so far
        max_waits = 300  # ~5 minutes; previously this loop could spin forever
        pbar = tqdm(total=None, desc="Loading model")
        try:
            while True:
                # timeout added so a hung connection cannot block forever
                response = requests.post(api_url, headers=headers, json=payload, timeout=60)
                if response.status_code == 200:
                    return Image.open(BytesIO(response.content))
                elif response.status_code == 503 and wait_count < max_waits:
                    # 503 means the hosted model is still loading; poll again.
                    time.sleep(1)
                    wait_count += 1
                    pbar.update(1)
                elif response.status_code == 500 and error_count < 5:
                    time.sleep(1)
                    error_count += 1
                else:
                    raise Exception(f"API Error: {response.status_code}")
        finally:
            pbar.close()  # fix: progress bar was never closed

# Gradio Interface: four text inputs (color, garment, design, optional text)
# wired to infer(), producing a single image output.
input_widgets = [
    gr.Textbox(lines=1, placeholder="Color"),
    gr.Textbox(lines=1, placeholder="Dress Type"),
    gr.Textbox(lines=2, placeholder="Design"),
    gr.Textbox(lines=1, placeholder="Text (Optional)"),
]

iface = gr.Interface(
    fn=infer,
    inputs=input_widgets,
    outputs="image",
    title="AI-Generated T-Shirt Designs",
    description="Generate custom t-shirt designs using AI!",
    examples=[["Red", "T-shirt", "Minimalistic logo", "Brand Name"]],
)

print("Launching Gradio interface...")
iface.launch()