from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
#from shakkala import Shakkala
from pathlib import Path

import torch

from eo_pl import TashkeelModel as TashkeelModelEO
from ed_pl import TashkeelModel as TashkeelModelED
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic

app = FastAPI()

# Global variables to store loaded models
sh_model = None
sh_graph = None
eo_model = None
ed_model = None


#class ShakkalaRequest(BaseModel):
#    text: str


class CattRequest(BaseModel):
    text: str
    model_type: str  # "Encoder-Only" or "Encoder-Decoder"


@app.on_event("startup")
def load_models():
    global sh_model, sh_graph, eo_model, ed_model

    # Load Shakkala model
    # sh = Shakkala(version=3)
    # sh_model, sh_graph = sh.get_model()

    # Load CaTT models
    tokenizer = TashkeelTokenizer()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_seq_len = 1024
    eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
    ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'

    # Load Encoder-Only model
    eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
    eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
    eo_model.eval().to(device)

    # Load Encoder-Decoder model
    ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
    ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
    ed_model.eval().to(device)


'''@app.post("/shakkala")
def infer_shakkala(request: ShakkalaRequest):
    try:
        input_text = request.text
        sh = Shakkala(version=3)
        input_int = sh.prepare_input(input_text)
        logits = sh_model.predict(input_int)[0]
        predicted_harakat = sh.logits_to_text(logits)
        final_output = sh.get_final_text(input_text, predicted_harakat)
        return {"result": final_output}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
'''


@app.get("/test")
def health_check():
    return {"status": "ok"}


@app.post("/catt")
def infer_catt(request: CattRequest):
    try:
        input_text = remove_non_arabic(request.text)
        batch_size = 16
        verbose = True
        if request.model_type == 'Encoder-Only':
            output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose)
        else:
            output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose)
        return {"result": output_text[0]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
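

# Example client call for the /catt endpoint (a minimal sketch, not part of the
# app itself). It assumes the server is run locally with something like
# `uvicorn main:app --port 8000`; the module name `main` and the port are
# assumptions, not values fixed by this file. The JSON fields and the "result"
# key match CattRequest and infer_catt above.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/catt",
#       json={"text": "<Arabic text here>", "model_type": "Encoder-Only"},
#   )
#   print(resp.json()["result"])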