File size: 1,995 Bytes
3b3a783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import io
import logging
import os
import secrets
from typing import Optional

import torch
import torchvision.transforms as transforms
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field

from tld.diffusion import DiffusionTransformer, LTDConfig

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` and `to_pil` are not referenced anywhere else in this
# file, and `device` is not passed to DiffusionTransformer below — confirm the
# model selects its own device.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
to_pil = transforms.ToPILImage()

# The model is constructed once at import time (default LTDConfig) and the
# same instance serves every request handled by this module.
ltdconfig = LTDConfig()
diffusion_transformer = DiffusionTransformer(ltdconfig)

app = FastAPI()

# Pulls the token out of an `Authorization: Bearer ...` header.
# NOTE(review): no "token" route is defined in this file; tokenUrl is only
# advertised in the generated OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def validate_token(token: str = Depends(oauth2_scheme)):
    """Check the client's bearer token against the ``API_TOKEN`` env variable.

    Intended for use as a FastAPI dependency; ``oauth2_scheme`` extracts the
    token from the ``Authorization: Bearer ...`` header.

    Args:
        token: Bearer token supplied by the client.

    Returns:
        The validated token, so dependants annotated ``str`` receive one.

    Raises:
        HTTPException: 401 when ``API_TOKEN`` is unset or empty, when the
            supplied token is empty, or when it does not match.
    """
    expected = os.getenv("API_TOKEN")
    # Fail closed: with no configured token (or an empty one) nothing may
    # authenticate, in particular an empty bearer value must never match "".
    # compare_digest avoids content-based short-circuiting, so response timing
    # does not reveal how many leading characters of a guess were correct.
    # Comparing bytes keeps it valid for non-ASCII input (str args must be ASCII).
    if (
        not expected
        or not token
        or not secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return token


class ImageRequest(BaseModel):
    """JSON body for ``POST /generate-image/``.

    Every field is forwarded unchanged to
    ``DiffusionTransformer.generate_image_from_text``.
    """

    prompt: str
    # NOTE(review): guidance scales are frequently fractional; confirm whether
    # this should accept floats rather than ints.
    class_guidance: Optional[int] = 6
    seed: Optional[int] = 11
    # Non-positive counts/sizes are rejected here with a 422 instead of reaching
    # the model and failing there as a 500.
    # NOTE(review): there is still no upper bound, so a client can request an
    # arbitrarily large batch or resolution — consider capping both.
    num_imgs: Optional[int] = Field(1, ge=1)
    img_size: Optional[int] = Field(32, ge=1)


@app.get("/")
def read_root():
    """Return a fixed greeting for the service root."""
    greeting = {"message": "Welcome to Image Generator"}
    return greeting


@app.post("/generate-image/")
async def generate_image(request: ImageRequest, token: str = Depends(validate_token)):
    """Generate an image from a text prompt and return it as a JPEG.

    \f
    Authentication is enforced by the ``validate_token`` dependency before
    this body runs.

    Args:
        request: Prompt and generation parameters, passed straight through to
            ``diffusion_transformer.generate_image_from_text``.
        token: Value provided by the ``validate_token`` dependency; not used
            in the body.

    Returns:
        A ``StreamingResponse`` whose body is the encoded ``image/jpeg``.

    Raises:
        HTTPException: 500 if generation or JPEG encoding fails. The original
            error is logged server-side and chained, not echoed to the client.
    """
    # NOTE(review): generate_image_from_text runs synchronously inside an
    # ``async def``, so it blocks the event loop (and every other request,
    # including "/") for the whole generation. Offloading it, e.g. with
    # fastapi.concurrency.run_in_threadpool, would fix that but allows
    # concurrent model calls — confirm the model and its seeding are safe to
    # run concurrently before changing this.
    try:
        img = diffusion_transformer.generate_image_from_text(
            prompt=request.prompt,
            class_guidance=request.class_guidance,
            seed=request.seed,
            num_imgs=request.num_imgs,
            img_size=request.img_size,
        )
        # Encode the returned image (presumably PIL, given ``.save``) into an
        # in-memory JPEG for the HTTP response.
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="JPEG")
    except Exception as exc:
        # Top-level boundary: keep the full traceback for operators, but do not
        # expose internal exception text to API clients.
        logging.getLogger(__name__).exception("Image generation failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Image generation failed",
        ) from exc

    # Rewind so the response streams the buffer from its first byte.
    img_byte_arr.seek(0)
    return StreamingResponse(img_byte_arr, media_type="image/jpeg")


# TODO: add a build job that tests the API and deploys it as a Docker image (possibly on Azure).