Uddipan Basu Bir committed on
Commit
0d4b0fc
·
1 Parent(s): 5a0deb7

Download checkpoint from HF hub in OcrReorderPipeline

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py CHANGED
@@ -34,6 +34,48 @@ tokenizer = AutoTokenizer.from_pretrained(
34
  repo, subfolder="preprocessor"
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Projection head: load from checkpoint
38
  ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
39
  ckpt = torch.load(ckpt_file, map_location="cpu")
 
34
  repo, subfolder="preprocessor"
35
  )
36
 
37
+ # Ensure decoder_start_token_id is set
38
+ if t5_model.config.decoder_start_token_id is None:
39
+ # Fallback to bos_token_id if present
40
+ t5_model.config.decoder_start_token_id = tokenizer.bos_token_id
41
+
42
+ # Projection head: load from checkpoint
43
+ ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
44
+ ckpt = torch.load(ckpt_file, map_location="cpu")
45
+ proj_state= ckpt["projection"]
46
+ projection = torch.nn.Sequential(
47
+ torch.nn.Linear(768, t5_model.config.d_model),
48
+ torch.nn.LayerNorm(t5_model.config.d_model),
49
+ torch.nn.GELU()
50
+ )
51
+ projection.load_state_dict(proj_state)
52
+ projection.eval()
53
+
54
+ # Move models to CPU (Spaces are CPU-only)
55
+ device = torch.device("cpu")
56
+ layout_model.to(device)
57
+ t5_model.to(device)
58
+ projection.to(device)
59
+ repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
60
+
61
+ # Processor for LayoutLMv3
62
+ processor = AutoProcessor.from_pretrained(
63
+ repo,
64
+ subfolder="preprocessor",
65
+ apply_ocr=False
66
+ )
67
+
68
+ # LayoutLMv3 encoder
69
+ layout_model = LayoutLMv3Model.from_pretrained(repo)
70
+ layout_model.eval()
71
+
72
+ # T5 decoder & tokenizer
73
+ t5_model = T5ForConditionalGeneration.from_pretrained(repo)
74
+ t5_model.eval()
75
+ tokenizer = AutoTokenizer.from_pretrained(
76
+ repo, subfolder="preprocessor"
77
+ )
78
+
79
  # Projection head: load from checkpoint
80
  ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
81
  ckpt = torch.load(ckpt_file, map_location="cpu")