MohamedLotfy committed on
Commit
45542e5
·
1 Parent(s): 365abf6

replace main with tashkeel model

Browse files
Files changed (1) hide show
  1. main.py +54 -4
main.py CHANGED
@@ -1,7 +1,57 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
- @app.get("/")
6
- def read_root():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ #from shakkala import Shakkala
4
+ from pathlib import Path
5
+ import torch
6
+ from eo_pl import TashkeelModel as TashkeelModelEO
7
+ from ed_pl import TashkeelModel as TashkeelModelED
8
+ from tashkeel_tokenizer import TashkeelTokenizer
9
+ from utils import remove_non_arabic
10
 
11
app = FastAPI()

# --- CATT (Arabic tashkeel / diacritization) model setup, run once at import ---
tokenizer = TashkeelTokenizer()
# map_location below keeps CPU-only hosts working even for CUDA-saved checkpoints
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_seq_len = 1024

# Checkpoints are shipped alongside this file under models/
eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'


def _load_tashkeel_model(model_cls, ckpt_path, n_layers):
    """Build one tashkeel model, load its checkpoint, and return it on `device` in eval mode.

    Shared by both variants below — the original code duplicated this
    construct/load/eval/to-device sequence verbatim for each model.
    """
    model = model_cls(tokenizer, max_seq_len=max_seq_len,
                      n_layers=n_layers, learnable_pos_emb=False)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    return model.eval().to(device)


# Encoder-Only (6 layers) and Encoder-Decoder (3 layers) variants
eo_model = _load_tashkeel_model(TashkeelModelEO, eo_ckpt_path, n_layers=6)
ed_model = _load_tashkeel_model(TashkeelModelED, ed_ckpt_path, n_layers=3)
34
class CattRequest(BaseModel):
    """Request payload for the /catt diacritization endpoint."""

    # Input text; the handler passes it through remove_non_arabic before inference,
    # so presumably only Arabic characters are used — verify against utils.
    text: str
    # Model selector. NOTE(review): the handler treats any value other than
    # "Encoder-Only" as a request for the Encoder-Decoder model.
    model_type: str  # "Encoder-Only" or "Encoder-Decoder"
37
+
38
+
39
+
40
+
41
+
42
@app.post("/catt")
def infer_catt(request: CattRequest):
    """Diacritize (add tashkeel to) the Arabic text in `request`.

    Selects the Encoder-Only model when request.model_type == "Encoder-Only";
    any other value falls back to the Encoder-Decoder model (kept as-is for
    backward compatibility with existing callers).

    Returns:
        {"result": <diacritized text>}

    Raises:
        HTTPException: status 500 with the underlying error message on any failure.
    """
    try:
        # Models are trained on Arabic only; strip everything else first.
        input_text = remove_non_arabic(request.text)
        batch_size = 16
        verbose = True

        if request.model_type == 'Encoder-Only':
            output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose)
        else:
            output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose)

        # do_tashkeel_batch returns a list parallel to its input batch of one.
        return {"result": output_text[0]}
    except Exception as e:
        # Chain the original exception (`from e`) so server logs keep the full
        # traceback instead of only the stringified message.
        raise HTTPException(status_code=500, detail=str(e)) from e
57
+