File size: 1,803 Bytes
2b3d156
 
 
7b2eca8
2b3d156
 
 
84f04ee
405db35
2b3d156
7b2eca8
 
 
2b3d156
 
 
 
 
623450c
5f629ed
31cc6ff
 
 
 
 
 
5f629ed
54c9e50
 
 
 
2b3d156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405db35
2b3d156
 
405db35
2b3d156
 
 
 
 
5f629ed
54c9e50
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# app/main.py

import os
import shutil
import tempfile
from pathlib import Path

import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from vit_captioning.generate import CaptionGenerator

app = FastAPI()

# Static files are served from the `vit_captioning/static` folder that sits
# next to this module (resolved from __file__, not the working directory).
static_dir = Path(__file__).parent.joinpath("vit_captioning", "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Landing page at `/`
@app.get("/", response_class=HTMLResponse)
async def landing():
    """Return the contents of ``landing.html`` from the static directory.

    The path is resolved from ``static_dir`` (the same directory mounted at
    ``/static``) instead of the current working directory, so the page loads
    no matter where the server process was started. The file is decoded as
    UTF-8 explicitly rather than with the platform-default encoding.
    """
    return (static_dir / "landing.html").read_text(encoding="utf-8")

# @app.get("/", response_class=HTMLResponse)
# def root():
#     return "<h3>✅ Hugging Face Space is alive</h3>"

@app.get("/health")
def health_check():
    """Health endpoint: always answers with the fixed payload ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload

# ✅ Captioning page at `/captioning`
@app.get("/captioning", response_class=HTMLResponse)
async def captioning():
    """Return the captioning app's ``index.html`` from the static directory.

    Resolved from ``static_dir`` (the directory mounted at ``/static``) so it
    does not depend on the current working directory, and decoded as UTF-8
    explicitly rather than with the platform-default encoding.
    """
    return (static_dir / "captioning" / "index.html").read_text(encoding="utf-8")

# ✅ Example: placeholder route reserved for a future "Project 2" page
@app.get("/project2", response_class=HTMLResponse)
async def project2():
    """Serve a fixed "coming soon" heading for Project 2."""
    placeholder = "<h1>Coming Soon: Project 2</h1>"
    return placeholder

# ✅ Caption generation endpoint for captioning app
# Keep the path consistent with your JS fetch()!
#
# The caption generator is constructed once, when this module is imported, and
# shared by every request to the generation endpoint. The checkpoint path is
# resolved relative to this module (matching the /static mount) so startup
# does not depend on the current working directory.
caption_generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path=str(
        Path(__file__).parent
        / "vit_captioning"
        / "artifacts"
        / "CLIPEncoder_40epochs_unfreeze12.pth"
    ),
    quantized=False,
)

@app.post("/generate")
async def generate(file: UploadFile = File(...)):
    """Caption an uploaded image and return the generator's result.

    The route path must stay in sync with the captioning page's ``fetch()``.

    Args:
        file: The multipart upload sent by the client.

    Returns:
        Whatever ``caption_generator.generate_caption`` returns for the
        uploaded file (serialized by FastAPI).
    """
    # The client-supplied filename is untrusted: joining it onto /tmp allowed
    # path traversal ("../x") or absolute-path overwrites, and let concurrent
    # uploads with the same name clobber each other. Keep only its extension
    # and let tempfile pick a unique name. `filename` may be None.
    suffix = Path(file.filename or "").suffix
    # delete=False + closing the handle before inference ensures the bytes are
    # flushed and the path can be reopened by the generator (also on Windows).
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as buffer:
        shutil.copyfileobj(file.file, buffer)
        temp_file = buffer.name
    try:
        captions = caption_generator.generate_caption(temp_file)
    finally:
        # Don't leave every upload behind in the temp directory.
        Path(temp_file).unlink(missing_ok=True)
    return captions

# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)