import os

# Point the Hugging Face cache at persistent storage before any HF libraries are imported.
os.environ["HF_HOME"] = "/data/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface"  # legacy variable, kept for older transformers versions
os.makedirs("/data/huggingface/hub", exist_ok=True)
# os.makedirs("/data/huggingface/clip_vision_model", exist_ok=True)

import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection  # only needed if the commented-out manual embedding path below is re-enabled
from PIL import Image


# --- Old download/path setup (kept for reference) ---
# MODEL_ID = "runwayml/stable-diffusion-v1-5"  # Can swap for a custom path if using IP-Adapter
# ADAPTER_PATH = "/workspace/.cache/huggingface/ip_adapter/ip-adapter_sd15.bin"
# ADAPTER_DIR = "/workspace/.cache/huggingface/ip_adapter"
# DEVICE = "cpu"
# MODEL_CACHE = "/workspace/.cache/huggingface"

# ---- SETTINGS ----
MODEL_ID = "runwayml/stable-diffusion-v1-5"
IPADAPTER_REPO = "h94/IP-Adapter"
IPADAPTER_WEIGHT_NAME = "ip-adapter_sd15.bin"
DEVICE = "cpu"  # Change to "cuda" if you have GPU
CACHE_DIR = os.environ.get("HF_HOME", "/data/huggingface")

# (Optional) Download IP-Adapter weights and patch pipeline if desired
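# A minimal optional sketch for pre-fetching the adapter weights at startup
# (assumes huggingface_hub is installed; otherwise load_ip_adapter below downloads
# them lazily on first use):
# from huggingface_hub import hf_hub_download
# hf_hub_download(
#     repo_id=IPADAPTER_REPO,
#     filename=IPADAPTER_WEIGHT_NAME,
#     subfolder="models",
#     cache_dir=CACHE_DIR,
# )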

# Load the model ONCE at startup, not per request!
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    cache_dir=CACHE_DIR,
    # safety_checker=None,  # Disable for demo/testing; enable in prod
).to(DEVICE)
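
# Optional memory saver for CPU or small-GPU hosts; a sketch only, safe to omit
# (enable_attention_slicing is a standard diffusers pipeline method):
# pipe.enable_attention_slicing()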

pipe.load_ip_adapter(
    pretrained_model_name_or_path_or_dict=IPADAPTER_REPO,
    subfolder="models",
    weight_name=IPADAPTER_WEIGHT_NAME
)
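
# Optionally tune how strongly the reference image conditions generation.
# Hedged sketch: set_ip_adapter_scale is available in recent diffusers releases
# (0.0 ignores the adapter, 1.0 applies it at full strength).
# pipe.set_ip_adapter_scale(0.6)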

# # Load vision encoder and processor for IP-Adapter embedding
# vision_encoder = CLIPVisionModelWithProjection.from_pretrained(
#     "h94/IP-Adapter",             # repo_id (main IP-Adapter repo)
#     subfolder="clip_vision_model",# subfolder within the repo!
#     cache_dir=CACHE_DIR
# )

# image_processor = CLIPImageProcessor.from_pretrained(
#     "h94/IP-Adapter",
#     subfolder="clip_vision_model",
#     cache_dir=CACHE_DIR
# )

def generate_sticker(input_image: Image.Image, style: str):
    """
    Given a user image and a style name, generates a sticker/emoji-style portrait sheet.
    """
    # Old per-request model load (superseded by the module-level pipe above):
    # pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    #     MODEL_ID,
    #     torch_dtype=torch.float32,
    #     cache_dir=MODEL_CACHE,
    #     safety_checker=None,  # Disable for demo/testing
    # ).to(DEVICE)

    # Preprocess the reference image (convert to RGB and resize to the SD 1.5 resolution)
    face_img = input_image.convert("RGB").resize((512, 512))
    # inputs = image_processor(images=face_img, return_tensors="pt").to(DEVICE)
    # with torch.no_grad():
    #     image_embeds = vision_encoder(**inputs).image_embeds

    # Prepare the init image for the img2img pipeline (same preprocessing as the reference image)
    init_image = input_image.convert("RGB").resize((512, 512))

    # Alternatively, the reference image can be supplied as precomputed embeddings
    # (see pipe.prepare_ip_adapter_image_embeds, which takes additional arguments such as
    # device and num_images_per_prompt); here we pass the PIL image via ip_adapter_image instead.

    prompt = (f"A set of twelve {style}-style digital stickers"
               "each with a different expression: laughing, angry, crying, sulking, thinking, sleepy, blowing a kiss, winking, surprised, happy, sad, and confused. "
                "Each sticker has a bold black outline and a transparent background, in a playful, close-up cartoon style."
    )
    # Run inference (moderate strength: lower values preserve more of the input image and identity)
    result = pipe(
        prompt=prompt,
        image=init_image,
        # image_embeds=image_embeds,
        ip_adapter_image=face_img,
        strength=0.6,
        guidance_scale=8,
        num_inference_steps=40
    )
    # Return the generated image (as PIL)
    return result.images[0]
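

if __name__ == "__main__":
    # Minimal local smoke test; the file names below are placeholders, not part of the service.
    src = Image.open("face.jpg")  # hypothetical input portrait
    sticker_sheet = generate_sticker(src, style="cartoon")
    sticker_sheet.save("sticker_sheet.png")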