# try_catt/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# from shakkala import Shakkala
from pathlib import Path
import torch

from eo_pl import TashkeelModel as TashkeelModelEO
from ed_pl import TashkeelModel as TashkeelModelED
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic

app = FastAPI()

# Global variables to store loaded models
sh_model = None
sh_graph = None
eo_model = None
ed_model = None

# class ShakkalaRequest(BaseModel):
#     text: str


class CattRequest(BaseModel):
    text: str
    model_type: str  # "Encoder-Only" or "Encoder-Decoder"
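
# Example request body for the /catt endpoint (the Arabic text here is
# illustrative only and not taken from the original source):
#   {"text": "اللغة العربية", "model_type": "Encoder-Only"}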


@app.on_event("startup")
def load_models():
    global sh_model, sh_graph, eo_model, ed_model

    # Load Shakkala model
    # sh = Shakkala(version=3)
    # sh_model, sh_graph = sh.get_model()

    # Load CaTT models
    tokenizer = TashkeelTokenizer()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_seq_len = 1024
    eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
    ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'

    # Load Encoder-Only model
    eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len,
                               n_layers=6, learnable_pos_emb=False)
    eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
    eo_model.eval().to(device)

    # Load Encoder-Decoder model
    ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len,
                               n_layers=3, learnable_pos_emb=False)
    ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
    ed_model.eval().to(device)


'''
@app.post("/shakkala")
def infer_shakkala(request: ShakkalaRequest):
    try:
        input_text = request.text
        sh = Shakkala(version=3)
        input_int = sh.prepare_input(input_text)
        logits = sh_model.predict(input_int)[0]
        predicted_harakat = sh.logits_to_text(logits)
        final_output = sh.get_final_text(input_text, predicted_harakat)
        return {"result": final_output}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
'''


@app.get("/test")
def health_check():
    return {"status": "ok"}


@app.post("/catt")
def infer_catt(request: CattRequest):
    try:
        # Strip non-Arabic characters before diacritization
        input_text = remove_non_arabic(request.text)
        batch_size = 16
        verbose = True
        # Choose the model by requested type; anything other than
        # "Encoder-Only" falls back to the Encoder-Decoder model
        if request.model_type == 'Encoder-Only':
            output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose)
        else:
            output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose)
        return {"result": output_text[0]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
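

# Minimal local entry point: a sketch assuming uvicorn is installed. The
# host/port values are illustrative (7860 is the conventional Hugging Face
# Spaces port) and are not part of the original source.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)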