File size: 1,165 Bytes
bf3184f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from contextlib import asynccontextmanager
from fastapi import FastAPI
from router import router
from models.english_scoring import EnglishScoringModel
from models.indonesian_scoring import IndonesianScoringModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the scoring models at startup and release them at shutdown.

    While the app is serving, the models are exposed as
    ``app.state.english_scoring_model`` and
    ``app.state.indonesian_scoring_model``, both switched to eval mode.
    Models are placed on CUDA when available, otherwise on CPU.

    Args:
        app: The application whose ``state`` receives the loaded models.

    Yields:
        Nothing; control returns to FastAPI while the app is serving.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The try covers loading as well as the yield: contextlib re-raises any
    # error (e.g. cancellation during shutdown) at the ``yield``, and without
    # ``finally`` the cleanup below would be skipped in that case.
    try:
        app.state.english_scoring_model = EnglishScoringModel.load(
            model_path="./models/all-mpnet-base-v2",
            type="biencoder",
            device=device,
        )
        app.state.english_scoring_model.eval()

        # Build tokenizer/model inline instead of binding them to locals, so
        # this generator frame holds no extra reference to them during cleanup.
        pseudolabel_path = "./models/pseudolabel"
        app.state.indonesian_scoring_model = IndonesianScoringModel(
            tokenizer=AutoTokenizer.from_pretrained(pseudolabel_path),
            model=AutoModelForSequenceClassification.from_pretrained(
                pseudolabel_path
            ),
            device=device,
        )
        app.state.indonesian_scoring_model.eval()

        yield
    finally:
        # Remove only what was actually set, so a failure while loading the
        # second model still releases the first one.
        for attr in ("english_scoring_model", "indonesian_scoring_model"):
            if hasattr(app.state, attr):
                delattr(app.state, attr)
        if device.type == "cuda":
            # Hand cached, no-longer-referenced GPU blocks back to the driver.
            torch.cuda.empty_cache()


# ASGI application object: model loading/unloading is delegated to the
# ``lifespan`` context manager, and all endpoints come from ``router``.
app: FastAPI = FastAPI(lifespan=lifespan)
app.include_router(router=router)