File size: 2,802 Bytes
08abb06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bed8371
 
 
 
 
08abb06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
#from shakkala import Shakkala
from pathlib import Path
import torch
from eo_pl import TashkeelModel as TashkeelModelEO
from ed_pl import TashkeelModel as TashkeelModelED
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic

app = FastAPI()

# Global variables to store loaded models.
# Populated once by load_models() at startup; read by the request handlers
# so each request reuses the already-loaded models.
sh_model = None  # Shakkala model — currently disabled (loading is commented out in load_models)
sh_graph = None  # Shakkala graph — currently disabled
eo_model = None  # CaTT Encoder-Only model (TashkeelModelEO)
ed_model = None  # CaTT Encoder-Decoder model (TashkeelModelED)

#class ShakkalaRequest(BaseModel):
#    text: str

class CattRequest(BaseModel):
    """Request body for the /catt endpoint."""
    text: str  # Arabic text to diacritize; non-Arabic characters are stripped by the handler
    model_type: str  # "Encoder-Only" or "Encoder-Decoder"

@app.on_event("startup")
def load_models():
    """Load the CaTT tashkeel models once at application startup.

    Populates the module-level ``eo_model`` and ``ed_model`` globals so the
    request handlers can reuse the loaded models across requests. The
    Shakkala model loading that used to live here is disabled (see the
    commented-out /shakkala endpoint below), so ``sh_model``/``sh_graph``
    stay ``None``.
    """
    global eo_model, ed_model

    tokenizer = TashkeelTokenizer()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_seq_len = 1024

    model_dir = Path(__file__).parent / 'models'
    eo_ckpt_path = model_dir / 'best_eo_mlm_ns_epoch_193.pt'
    ed_ckpt_path = model_dir / 'best_ed_mlm_ns_epoch_178.pt'

    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — fine for these trusted local checkpoints, unsafe for
    # untrusted files.

    # Encoder-Only model (6 transformer layers).
    eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len,
                               n_layers=6, learnable_pos_emb=False)
    eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
    eo_model.eval().to(device)

    # Encoder-Decoder model (3 transformer layers).
    ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len,
                               n_layers=3, learnable_pos_emb=False)
    ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
    ed_model.eval().to(device)

# Disabled /shakkala endpoint, kept for reference. It was previously held in
# a module-level triple-quoted string — a no-op expression still evaluated at
# import time — and is converted to comments so nothing is built at runtime.
#
# @app.post("/shakkala")
# def infer_shakkala(request: ShakkalaRequest):
#     try:
#         input_text = request.text
#         sh = Shakkala(version=3)
#         input_int = sh.prepare_input(input_text)
#         logits = sh_model.predict(input_int)[0]
#         predicted_harakat = sh.logits_to_text(logits)
#         final_output = sh.get_final_text(input_text, predicted_harakat)
#         return {"result": final_output}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))

@app.get("/test")
def health_check():
    """Liveness probe: confirm the service is up and responding."""
    return dict(status="ok")

@app.post("/catt")
def infer_catt(request: CattRequest):
    """Diacritize Arabic text with one of the CaTT models.

    Strips non-Arabic characters from the input, runs the Encoder-Only model
    when ``model_type`` is "Encoder-Only" and the Encoder-Decoder model
    otherwise, and returns the diacritized text. Any failure surfaces as an
    HTTP 500 with the exception message.
    """
    batch_size = 16
    verbose = True
    try:
        cleaned = remove_non_arabic(request.text)
        model = eo_model if request.model_type == 'Encoder-Only' else ed_model
        results = model.do_tashkeel_batch([cleaned], batch_size, verbose)
        return {"result": results[0]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))