|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
|
|
from pathlib import Path |
|
import torch |
|
from eo_pl import TashkeelModel as TashkeelModelEO |
|
from ed_pl import TashkeelModel as TashkeelModelED |
|
from tashkeel_tokenizer import TashkeelTokenizer |
|
from utils import remove_non_arabic |
|
|
|
app = FastAPI()


# Shared character-level tokenizer; both models below are built around it.
tokenizer = TashkeelTokenizer()

# Prefer GPU when available; checkpoints are remapped onto this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Maximum sequence length both models were configured/trained with.
max_seq_len = 1024

# Checkpoint files resolved relative to this source file, so the service
# works regardless of the current working directory.
eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'

ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'


# Encoder-only (EO) model: 6 transformer layers, fixed (non-learnable)
# positional embeddings to match the checkpoint.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — fine for trusted local checkpoints, but confirm the files are
# trusted before deploying.
eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len,

                           n_layers=6, learnable_pos_emb=False)

eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))

# eval() disables dropout/batch-norm updates; move weights to the device.
eo_model.eval().to(device)


# Encoder-decoder (ED) model: 3 layers, otherwise configured like the EO one.
ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len,

                           n_layers=3, learnable_pos_emb=False)

ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))

ed_model.eval().to(device)
|
|
|
|
|
class CattRequest(BaseModel):
    """Request body for the POST /catt diacritization endpoint."""

    # Input text; non-Arabic characters are stripped server-side before inference.
    text: str

    # 'Encoder-Only' selects the EO model; any other value falls through to
    # the encoder-decoder (ED) model.
    model_type: str
|
|
|
|
|
|
|
|
|
|
|
@app.post("/catt")
def infer_catt(request: CattRequest):
    """Diacritize (add tashkeel to) Arabic text with the requested model.

    Request body:
        text: raw input; non-Arabic characters are stripped before inference.
        model_type: 'Encoder-Only' routes to the EO model; any other value
            falls back to the encoder-decoder (ED) model (original behavior
            is preserved — no 400 on unknown values).

    Returns:
        {"result": <diacritized text>} on success.

    Raises:
        HTTPException: status 500 with the underlying error message if
            preprocessing or inference fails.
    """
    batch_size = 16
    # NOTE(review): verbose=True makes the model print progress on every
    # request; kept to preserve existing behavior, but consider disabling
    # in production.
    verbose = True

    # Select once instead of duplicating the identical call in both branches.
    model = eo_model if request.model_type == 'Encoder-Only' else ed_model

    try:
        # Models are trained on Arabic-only input, so strip everything else.
        input_text = remove_non_arabic(request.text)
        outputs = model.do_tashkeel_batch([input_text], batch_size, verbose)
    except Exception as e:
        # Boundary handler: surface any failure as a 500 with its message,
        # chaining the original exception for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Single input in, single output out: return the first (only) element.
    return {"result": outputs[0]}
|
|
|
|