import io  # 1. Byte-stream handling (used only by the disabled BLIP path below)
from fastapi import FastAPI, File, UploadFile, Form  # 2. FastAPI imports for API endpoints and file handling
from fastapi.responses import JSONResponse  # 3. Used to return errors as JSON
# from transformers import BlipProcessor, BlipForConditionalGeneration  # 4. BLIP for image captioning
from PIL import Image  # 5. Pillow for image processing (used only by the disabled BLIP path below)
import openai  # 6. OpenAI library for DALL·E API calls
import os  # 7. OS for environment variables
import traceback  # 7b. For printing full tracebacks when the API call fails
from face_to_prompt import extract_face_prompt  # Local helper that describes the face in an image


# 8. Create the FastAPI app
app = FastAPI()

# 9. BLIP processor and model (disabled): previously loaded at startup to avoid reloading on every request
# processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# 10. Get the OpenAI API key from environment variable
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
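
# Optional sketch (not in the original flow): fail fast at startup if the key
# is missing, instead of surfacing an authentication error on the first request.
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY environment variable is not set")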

# 11. Define the /generate POST endpoint
@app.post("/generate")
async def generate(
    image: UploadFile = File(...),      # 12. The uploaded image file
    style: str = Form("chibi"),         # 13. The desired style (chibi/anime/cartoon), defaults to "chibi"
):
    # 14. (BLIP path, disabled) Load and convert the uploaded image to RGB
    # img_bytes = await image.read()
    # img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

    # 15. (BLIP path, disabled) Caption the image using BLIP
    # inputs = processor(img, return_tensors="pt")
    # out = model.generate(**inputs)
    # caption = processor.decode(out[0], skip_special_tokens=True)

    # 15b. Save the upload to a temporary file on disk so the face-analysis
    # helper can read it by path
    with open("/tmp/temp_input.jpg", "wb") as f:
        f.write(await image.read())
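
    # Hedged alternative (not in the original): a unique temp file from the
    # standard-library tempfile module avoids two concurrent requests
    # overwriting each other's /tmp/temp_input.jpg. If used, it replaces the
    # open(...) block above:
    # import tempfile
    # with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
    #     tmp.write(await image.read())
    # caption = extract_face_prompt(tmp.name)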

    # 15c. Describe the face in the uploaded image; this description replaces
    # the BLIP caption in the DALL·E prompt below
    caption = extract_face_prompt("/tmp/temp_input.jpg")
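
    # Hedged guard (assumption: extract_face_prompt returns None or "" when no
    # face is found; adjust to that helper's actual contract):
    if not caption:
        return JSONResponse(
            content={"error": "No face could be described in the uploaded image"},
            status_code=422,
        )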

    # 16. Construct the DALL·E prompt using the style and the caption
    prompt = (
        f"A set of twelve {style}-style digital stickers of {caption}, "
        "each with a different expression: laughing, angry, crying, sulking, thinking, sleepy, blowing a kiss, winking, surprised, happy, sad, and confused. "
        "Each sticker has a bold black outline and a transparent background, in a playful, close-up cartoon style."
    )

    # 17. Set the OpenAI API key
    openai.api_key = OPENAI_API_KEY
    try:
        # 18. Call DALL·E 3 to generate the image
        response = openai.images.generate(
            model="dall-e-3",
            prompt=prompt,
            n=1,
            size="1024x1024"
        )
        image_url = response.data[0].url  # 19. Get the image URL from the response
    except Exception as e:
        print("Error in /generate:", traceback.format_exc())
        # 20. Return a JSON error message if the API call fails
        return JSONResponse(content={"error": str(e)}, status_code=500)

    # 21. Return the BLIP caption, the constructed prompt, and the generated image URL
    return {"caption": caption, "prompt": prompt, "image_url": image_url}
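
# Example usage (assumes this file is saved as main.py; adjust the module name
# to match your project):
#   uvicorn main:app --reload
#   curl -X POST http://127.0.0.1:8000/generate \
#        -F "image=@selfie.jpg" \
#        -F "style=anime"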