File size: 2,030 Bytes
72c3233
 
 
c6513d5
72c3233
 
d646931
 
 
 
72c3233
d646931
feabe43
 
f782ae9
d646931
72c3233
d646931
 
 
72c3233
 
 
 
 
779803e
72c3233
 
feabe43
f782ae9
 
 
feabe43
 
 
d646931
 
 
 
 
72c3233
 
 
 
 
 
d646931
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

import os
# Redirect the Hugging Face cache to a writable directory (works in Spaces and
# Docker). This runs before torch/diffusers are imported below so the setting
# is already in place when those libraries initialise.
_HF_CACHE_ROOT = "/workspace/.cache/huggingface"
os.environ["HF_HOME"] = _HF_CACHE_ROOT
os.makedirs(_HF_CACHE_ROOT, exist_ok=True)

import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image


# --- Model / adapter locations ---
MODEL_ID = "runwayml/stable-diffusion-v1-5"  # Can swap for a local path / other SD 1.5 checkpoint
# NOTE(review): the runwayml repo appears to have been removed from the Hub; if
# downloads fail, "stable-diffusion-v1-5/stable-diffusion-v1-5" is the usual
# mirror — confirm before switching (it changes the cache key).
MODEL_CACHE = "/workspace/.cache/huggingface"
ADAPTER_DIR = os.path.join(MODEL_CACHE, "ip_adapter")
ADAPTER_WEIGHT_NAME = "ip-adapter_sd15.bin"
ADAPTER_PATH = os.path.join(ADAPTER_DIR, ADAPTER_WEIGHT_NAME)
DEVICE = "cpu"  # pipeline is loaded in float32 below

# The IP-Adapter weights are not downloaded by this module. Fail fast with a
# clear message before spending time on the (large) base-model download.
if not os.path.isfile(ADAPTER_PATH):
    raise FileNotFoundError(
        f"IP-Adapter weights not found at {ADAPTER_PATH!r}; place "
        f"{ADAPTER_WEIGHT_NAME} (from the h94/IP-Adapter repo, models/) there."
    )

# Load the model ONCE at import time so every request reuses the same pipeline.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    cache_dir=MODEL_CACHE,
    # safety_checker=None,  # Disable for demo/testing; enable in prod
).to(DEVICE)

# The source location is passed positionally: load_ip_adapter's first
# parameter is named `pretrained_model_name_or_path_or_dict`, so a
# `pretrained_model_name_or_path=` keyword would not bind to it.
pipe.load_ip_adapter(
    ADAPTER_DIR,
    subfolder=".",  # the weights file sits directly in ADAPTER_DIR
    weight_name=ADAPTER_WEIGHT_NAME,
    # Loading straight from the Hub instead: "h94/IP-Adapter", subfolder="models"
)
# NOTE(review): with the default image_encoder_folder, diffusers also loads a
# CLIP image encoder from ADAPTER_DIR/image_encoder — confirm it is present.

def generate_sticker(
    input_image,
    prompt,
    *,
    strength=0.65,
    guidance_scale=7.5,
    num_inference_steps=30,
    pipeline=None,
):
    """Turn a user photo into a sticker/emoji-style portrait.

    Args:
        input_image: PIL image supplied by the user (any mode; converted to RGB).
        prompt: Text prompt describing the desired sticker style.
        strength: img2img denoising strength in [0, 1]; lower values keep more
            of the original photo.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        pipeline: Pipeline to run; defaults to the module-level ``pipe``. It
            must have an IP-Adapter loaded, because the user photo is also
            passed as ``ip_adapter_image``.

    Returns:
        The first generated image (PIL).

    Raises:
        ValueError: if ``input_image`` is None (e.g. nothing was uploaded).
    """
    if input_image is None:
        raise ValueError("generate_sticker: input_image is required")
    if pipeline is None:
        pipeline = pipe

    rgb_image = input_image.convert("RGB")
    # img2img runs on a fixed 512x512 canvas.
    init_image = rgb_image.resize((512, 512))

    result = pipeline(
        prompt=prompt,
        image=init_image,
        # With an IP-Adapter loaded the pipeline needs an adapter image on every
        # call; the un-stretched photo conditions generation on the subject.
        ip_adapter_image=rgb_image,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    return result.images[0]