"""FastAPI service exposing CaTT Arabic diacritization (tashkeel) inference.

At import time this module loads two pre-trained Tashkeel models -- an
encoder-only variant and an encoder-decoder variant -- then serves a single
POST /catt endpoint that restores diacritics on Arabic text.
"""
from pathlib import Path

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from ed_pl import TashkeelModel as TashkeelModelED
from eo_pl import TashkeelModel as TashkeelModelEO
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic

app = FastAPI()

# ---- Load CaTT models once at startup ------------------------------------
tokenizer = TashkeelTokenizer()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_seq_len = 1024

eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'

# Encoder-only model (6 transformer layers).
eo_model = TashkeelModelEO(
    tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False
)
eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
eo_model.eval().to(device)

# Encoder-decoder model (3 transformer layers).
ed_model = TashkeelModelED(
    tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False
)
ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
ed_model.eval().to(device)


class CattRequest(BaseModel):
    """Request payload for the /catt endpoint."""
    # Arabic input text; non-Arabic characters are stripped before inference.
    text: str
    # "Encoder-Only" selects eo_model; any other value falls through to
    # the encoder-decoder model (matches the original dispatch behavior).
    model_type: str


@app.post("/catt")
def infer_catt(request: CattRequest) -> dict:
    """Diacritize ``request.text`` with the selected CaTT model.

    Returns:
        ``{"result": <diacritized text>}`` -- the first (and only) entry of
        the single-element batch passed to the model.

    Raises:
        HTTPException: 500 with the underlying error message if inference
            fails for any reason.
    """
    try:
        input_text = remove_non_arabic(request.text)
        batch_size = 16
        verbose = True
        if request.model_type == 'Encoder-Only':
            output_text = eo_model.do_tashkeel_batch(
                [input_text], batch_size, verbose
            )
        else:
            # Anything other than "Encoder-Only" is routed to the
            # encoder-decoder model, preserving the original dispatch.
            output_text = ed_model.do_tashkeel_batch(
                [input_text], batch_size, verbose
            )
        return {"result": output_text[0]}
    except Exception as e:
        # Boundary handler: surface the failure as an HTTP 500 while
        # chaining the original exception for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e