File size: 1,767 Bytes
c42d629
 
 
 
 
c7327d4
 
d646931
c7327d4
 
d646931
c7327d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
# Point the Hugging Face caches at the persistent /workspace volume and make
# sure the directory exists. This deliberately runs before `utils` is imported
# below, presumably because the Hugging Face libraries it pulls in read these
# variables at import time — keep it above that import.
_HF_CACHE_DIR = '/workspace/cache/hf'
os.environ.update({
    'TRANSFORMERS_CACHE': _HF_CACHE_DIR,
    'HF_HOME': _HF_CACHE_DIR,
})
os.makedirs(_HF_CACHE_DIR, exist_ok=True)

import threading
from io import BytesIO

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, StreamingResponse
from PIL import Image

from utils import generate_sticker

app = FastAPI()

# Serialises calls into generate_sticker. The original endpoint ran the model
# directly on the event loop, which implicitly allowed only one generation at a
# time. Now that the work runs in a worker thread, this lock keeps that
# one-at-a-time behaviour. NOTE(review): whatever model state generate_sticker
# uses is not known to be thread-safe — confirm before relaxing this.
_GENERATE_LOCK = threading.Lock()


def _decode_upload(data: bytes) -> Image.Image:
    """Decode raw upload bytes into a fully loaded PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image. This
            covers empty, corrupt or truncated uploads, and images above
            Pillow's decompression-bomb limit.
    """
    try:
        img = Image.open(BytesIO(data))
        # Image.open only parses the header. load() forces the full decode, so
        # corrupt or truncated files are rejected here as a client error
        # instead of failing later inside generate_sticker with a 500.
        img.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc
    return img


def _render_png(image_pil: Image.Image, prompt: str) -> bytes:
    """Run generate_sticker and encode its result as PNG bytes (blocking)."""
    with _GENERATE_LOCK:
        result_img = generate_sticker(image_pil, prompt)
    buf = BytesIO()
    result_img.save(buf, format="PNG")
    return buf.getvalue()


@app.post("/generate")
async def generate(image: UploadFile = File(...), prompt: str = Form(...)):
    """Turn an uploaded photo plus a text prompt into a PNG sticker.

    Form fields:
        image: the source photo, in any format Pillow can decode.
        prompt: text forwarded unchanged to generate_sticker.

    Returns:
        An ``image/png`` response containing the generated sticker.

    Raises:
        HTTPException: 400 if ``image`` is not a decodable image.
    """
    image_pil = _decode_upload(await image.read())
    # generate_sticker is synchronous: it is called without await and its
    # result is used directly. Run it in the thread pool so the event loop
    # stays free to serve other requests while the model runs.
    png_bytes = await run_in_threadpool(_render_png, image_pil, prompt)
    # The PNG is already fully in memory, so a plain Response is simpler than
    # StreamingResponse and lets the server send an exact Content-Length.
    return Response(content=png_bytes, media_type="image/png")

# Run the server with: uvicorn app:app --host 0.0.0.0 --port 8000 (this module has no __main__ entry point)



# import gradio as gr
# from utils import generate_sticker

# def predict(image, prompt):
#     result_img = generate_sticker(image, prompt)
#     return result_img  # Should be PIL Image or np.array or filepath

# with gr.Blocks() as demo:
#     gr.Markdown("# 🦄 AI Sticker Generator (Stable Diffusion + IP-Adapter)")
#     with gr.Row():
#         image_input = gr.Image(type="pil", label="Upload your photo")
#         prompt_input = gr.Textbox(
#             label="Prompt (style or mood for emoji)",
#             value="cartoon emoji, white outline, clean background",
#         )
#     output_image = gr.Image(label="Sticker Output")
#     run_btn = gr.Button("Generate Sticker")
#     run_btn.click(
#         predict,
#         inputs=[image_input, prompt_input],
#         outputs=output_image
#     )

# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", share=True)