rodrigomasini committed on
Commit
f3baeaf
·
verified ·
1 Parent(s): 2699f7d

Update mdr_pdf_parser.py

Browse files
Files changed (1) hide show
  1. mdr_pdf_parser.py +7 -8
mdr_pdf_parser.py CHANGED
@@ -1813,23 +1813,22 @@ class MDRLayoutReader:
1813
  # In class MDRLayoutReader:
1814
  def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
1815
  if self._model is None:
1816
- cache = mdr_ensure_directory(self._model_path) # This should be self._model_path / "layoutreader"
1817
- # Correct cache path for transformers
1818
- layoutreader_cache_dir = Path(self._model_dir) / "layoutreader" # Assuming _model_dir is the main one
1819
- mdr_ensure_directory(str(layoutreader_cache_dir))
1820
 
1821
  name = "microsoft/layoutlmv3-base"
1822
 
1823
  print(f"MDRLayoutReader: Attempting to load LayoutLMv3 model '{name}'. Cache dir: {layoutreader_cache_dir}")
1824
  try:
1825
  self._model = LayoutLMv3ForTokenClassification.from_pretrained(
1826
- name, # Use the HF model name
1827
  cache_dir=str(layoutreader_cache_dir),
1828
- local_files_only=False, # Allow download on first run
1829
  num_labels=_MDR_MAX_LEN+1
1830
  )
1831
- # Explicitly move model to the determined device
1832
- self._model.to(torch.device(self._device)) # ENSURE THIS LINE IS PRESENT AND CORRECT
1833
  self._model.eval()
1834
  print(f"MDR LayoutReader model '{name}' loaded successfully on device: {self._model.device}.")
1835
  except Exception as e:
 
1813
  # In class MDRLayoutReader:
1814
  def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
1815
  if self._model is None:
1816
+ # MODIFIED: Use self._model_path for the layoutreader's specific cache,
1817
+ # and ensure it's a directory. self._model_path is passed during MDRLayoutReader init.
1818
+ layoutreader_cache_dir = Path(self._model_path) # self._model_path is like "./mdr_models/layoutreader"
1819
+ mdr_ensure_directory(str(layoutreader_cache_dir)) # Ensure this specific directory exists
1820
 
1821
  name = "microsoft/layoutlmv3-base"
1822
 
1823
  print(f"MDRLayoutReader: Attempting to load LayoutLMv3 model '{name}'. Cache dir: {layoutreader_cache_dir}")
1824
  try:
1825
  self._model = LayoutLMv3ForTokenClassification.from_pretrained(
1826
+ name,
1827
  cache_dir=str(layoutreader_cache_dir),
1828
+ local_files_only=False,
1829
  num_labels=_MDR_MAX_LEN+1
1830
  )
1831
+ self._model.to(torch.device(self._device))
 
1832
  self._model.eval()
1833
  print(f"MDR LayoutReader model '{name}' loaded successfully on device: {self._model.device}.")
1834
  except Exception as e: