henok3878 committed on
Commit
70e1f1d
·
1 Parent(s): a8bfa27

feature(utils): add utilities to support priming logic

Browse files
Files changed (1) hide show
  1. inference_utils.py +44 -3
inference_utils.py CHANGED
@@ -1,8 +1,17 @@
1
- from typing import Dict
2
- import numpy as np
 
 
3
 
4
  NULL_CHAR = '\x00'
5
 
 
 
 
 
 
 
 
6
  def construct_alphabet_list(alphabet_string: str) -> list[str]:
7
  if not isinstance(alphabet_string, str):
8
  raise TypeError("alphabet_string must be a string")
@@ -45,4 +54,36 @@ def convert_offsets_to_absolute_coords(stroke_offsets: list[list[float]]) -> lis
45
  strokes_array[:, 0] = np.cumsum(strokes_array[:, 0]) # cumulative dx
46
  strokes_array[:, 1] = np.cumsum(strokes_array[:, 1]) # cumulative dy
47
 
48
- return strokes_array.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, NamedTuple, Union
3
+ import numpy as np
4
+ import torch
5
 
6
  NULL_CHAR = '\x00'
7
 
8
+
9
class PrimingData(NamedTuple):
    """Bundle of tensors required to prime HandwritingRNN sampling.

    The shapes noted on each field are the ones stated by the original
    inline annotations; they are not enforced at construction time.
    """

    # Priming strokes, shape (batch_size, num_prime_strokes, 3).
    stroke_tensors: torch.Tensor
    # Priming character sequences, shape (batch_size, num_prime_chars).
    char_seq_tensors: torch.Tensor
    # Character sequence lengths, shape (batch_size,).
    char_seq_lengths: torch.Tensor
14
+
15
  def construct_alphabet_list(alphabet_string: str) -> list[str]:
16
  if not isinstance(alphabet_string, str):
17
  raise TypeError("alphabet_string must be a string")
 
54
  strokes_array[:, 0] = np.cumsum(strokes_array[:, 0]) # cumulative dx
55
  strokes_array[:, 1] = np.cumsum(strokes_array[:, 1]) # cumulative dy
56
 
57
+ return strokes_array.tolist()
58
+
59
+
60
def load_np_strokes(stroke_path: Union[Path, str]) -> np.ndarray:
    """Load a stroke sequence from a NumPy file.

    Args:
        stroke_path: Location of the saved stroke array.

    Returns:
        Whatever ``np.load`` yields for the file (an array for ``.npy`` files).

    Raises:
        FileNotFoundError: If nothing exists at ``stroke_path``.
        IsADirectoryError: If ``stroke_path`` exists but is not a regular file.
    """
    stroke_path = Path(stroke_path)
    if not stroke_path.exists():
        raise FileNotFoundError(f"style strokes file not found at {stroke_path}")
    # Same guard as load_text: reject directories up front with a clear
    # message rather than letting np.load fail in a platform-dependent way.
    if not stroke_path.is_file():
        raise IsADirectoryError(f"Path is a directory, not a file: {stroke_path}")

    return np.load(stroke_path)
67
+
68
def load_text(text_path: Union[Path, str]) -> str:
    """Read a UTF-8 text file and return its entire contents.

    Args:
        text_path: Location of the text file.

    Returns:
        The decoded file contents as a single string.

    Raises:
        FileNotFoundError: If nothing exists at ``text_path``.
        IsADirectoryError: If ``text_path`` exists but is not a regular file.
        IOError: If the file cannot be read or is not valid UTF-8; the
            underlying error is attached as ``__cause__``.
    """
    text_path = Path(text_path)
    if not text_path.exists():
        raise FileNotFoundError(f"Text file not found at {text_path}")
    if not text_path.is_file():
        # Include the offending path so the failure is actionable.
        raise IsADirectoryError(f"Path is a directory, not a file: {text_path}")

    try:
        return text_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError; keep converting
        # it so callers that catch IOError still see decode failures.
        raise IOError(f"Error reading text file {text_path}: {e}") from e
83
+
84
def load_priming_data(
    style: int,
    samples_dir: Union[Path, str] = "./data/samples",
) -> tuple[str, np.ndarray]:
    """Load the priming text and strokes for a given style sample.

    Reads ``sample{style}.txt`` and ``sample{style}.npy`` from ``samples_dir``.

    Args:
        style: Index used to build the sample file names.
        samples_dir: Directory holding the sample files. The default is
            relative to the current working directory; pass an explicit
            path when running from elsewhere.

    Returns:
        A ``(priming_text, priming_strokes)`` pair.

    Raises:
        FileNotFoundError: If either sample file is missing (raised by the
            underlying loaders).
    """
    samples_dir = Path(samples_dir)

    priming_text = load_text(samples_dir / f"sample{style}.txt")
    priming_strokes = load_np_strokes(samples_dir / f"sample{style}.npy")

    return priming_text, priming_strokes