"""FastAPI application entry point.

Loads the English and Indonesian scoring models once at startup, stores them
on ``app.state``, and removes those references again on shutdown.
"""

from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from models.english_scoring import EnglishScoringModel
from models.indonesian_scoring import IndonesianScoringModel
from router import router

# Local model directories; these are relative paths, so they resolve against
# the process working directory at startup.
ENGLISH_MODEL_PATH = "./models/all-mpnet-base-v2"
INDONESIAN_MODEL_PATH = "./models/pseudolabel"


def _load_english_model(device: torch.device) -> EnglishScoringModel:
    """Load the English bi-encoder scoring model and put it in eval mode."""
    model = EnglishScoringModel.load(
        model_path=ENGLISH_MODEL_PATH, type="biencoder", device=device
    )
    model.eval()
    return model


def _load_indonesian_model(device: torch.device) -> IndonesianScoringModel:
    """Wrap a local sequence-classification checkpoint as the Indonesian scorer."""
    tokenizer = AutoTokenizer.from_pretrained(INDONESIAN_MODEL_PATH)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        INDONESIAN_MODEL_PATH
    )
    scoring_model = IndonesianScoringModel(
        model=classifier, tokenizer=tokenizer, device=device
    )
    scoring_model.eval()
    return scoring_model


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage scoring-model lifetime for the FastAPI application.

    On startup, pick CUDA when available (otherwise CPU). Then load both
    scoring models and attach them to ``app.state`` as
    ``english_scoring_model`` and ``indonesian_scoring_model``. On shutdown,
    delete those attributes.

    Args:
        app: The FastAPI application whose ``state`` receives the models.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    app.state.english_scoring_model = _load_english_model(device)
    app.state.indonesian_scoring_model = _load_indonesian_model(device)
    try:
        yield
    finally:
        # Run cleanup even if an exception is thrown into the lifespan at
        # ``yield``; otherwise the model references would be left on app.state.
        del app.state.english_scoring_model
        del app.state.indonesian_scoring_model


app = FastAPI(lifespan=lifespan)
app.include_router(router)