henok3878 commited on
Commit
a55bf24
·
1 Parent(s): 70e1f1d

refactor: update inference to use priming by default

Browse files
Files changed (1) hide show
  1. main.py +66 -46
main.py CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
8
  import logging
9
  import time
10
  from contextlib import asynccontextmanager
11
- from inference_utils import construct_alphabet_list, convert_offsets_to_absolute_coords, encode_text, get_alphabet_map
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
@@ -18,13 +18,13 @@ QUANTIZED_MODEL_NAME = "model.scripted.quantized.pt"
18
  SCRIPTED_MODEL_NAME = "model.scripted.pt"
19
  METADATA_MODEL_NAME = "model.pt"
20
 
21
- scripted_model: Optional[torch.jit.ScriptModule] = None
22
- model_metadata: Optional[dict] = None
23
- device: Optional[torch.device] = None
24
- alphabet_map: Optional[dict[str, int]] = None
25
  ALPHABET_LIST: Optional[list[str]] = None
26
  ALPHABET_SIZE: Optional[int] = None
27
- max_text_len: Optional[int] = None
28
  output_mixture_components: Optional[int] = None # To store num_mixtures for GMM sampling
29
  lstm_size: Optional[int] = None
30
  attention_mixture_components: Optional[int] = None
@@ -55,15 +55,15 @@ class HealthResponse(BaseModel):
55
  @asynccontextmanager
56
  async def lifespan(app: FastAPI):
57
  """Lifespan context manager for startup and shutdown events"""
58
- global scripted_model, model_metadata, device, alphabet_map, max_text_len, ALPHABET_LIST, output_mixture_components, lstm_size, attention_mixture_components, ALPHABET_SIZE
59
  logger.info("Attempting to load model resources during startup")
60
  try:
61
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
- logger.info(f"Using device: {device}")
63
 
64
  scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
65
  metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
66
- if device.type == "cpu":
67
  scripted_model_path = MODEL_DIR / QUANTIZED_MODEL_NAME
68
 
69
  if not scripted_model_path.exists():
@@ -74,18 +74,18 @@ async def lifespan(app: FastAPI):
74
  raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")
75
 
76
  # Load the traced model
77
- scripted_model = torch.jit.load(scripted_model_path, map_location=device)
78
- if scripted_model:
79
- scripted_model.eval()
80
  logger.info(f"Traced model loaded successfully from {scripted_model_path}")
81
 
82
  # Load the metadata
83
- model_metadata = torch.load(metadata_model_path, map_location='cpu')
84
- if model_metadata:
85
  logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
86
- logger.info(f"Model metadata keys: {list(model_metadata.keys())}")
87
 
88
- config_full = model_metadata['config_full']
89
  if not config_full or not isinstance(config_full, dict):
90
  raise ValueError(f"Key `config_full` not found or not a dict")
91
 
@@ -95,7 +95,7 @@ async def lifespan(app: FastAPI):
95
  if not dataset_config or not isinstance(dataset_config, dict):
96
  raise ValueError(f"Key `dataset` not found or not a dict in config_full")
97
  alphabet_str = dataset_config['alphabet_string']
98
- max_text_len = dataset_config['max_text_len']
99
  output_mixture_components = model_params['output_mixture_components']
100
 
101
  lstm_size = model_params['lstm_size']
@@ -103,7 +103,7 @@ async def lifespan(app: FastAPI):
103
 
104
  ALPHABET_LIST = construct_alphabet_list(alphabet_str)
105
  ALPHABET_SIZE = len(ALPHABET_LIST)
106
- alphabet_map = get_alphabet_map(ALPHABET_LIST)
107
 
108
  logger.info(f"Alphabet created. Size: {len(ALPHABET_LIST)}")
109
  logger.info("Model resources are loaded and ready")
@@ -112,16 +112,16 @@ async def lifespan(app: FastAPI):
112
 
113
  except Exception as e:
114
  logger.error(f"Error loading model resources: {e}", exc_info=True)
115
- scripted_model = None
116
- model_metadata = None
117
  raise
118
 
119
  yield
120
 
121
  # Cleanup on shutdown
122
  logger.info("Shutting down API and cleaning up resources")
123
- scripted_model = None
124
- model_metadata = None
125
 
126
  app = FastAPI(
127
  title="Scriptify API",
@@ -145,32 +145,31 @@ async def read_root():
145
 
146
  @app.get("/health", response_model=HealthResponse, tags=["General"])
147
  async def health_check():
148
- global scripted_model, model_metadata, device, alphabet_map, max_text_len, ALPHABET_LIST
149
 
150
- is_healthy = all([scripted_model, model_metadata, device, alphabet_map, max_text_len, ALPHABET_LIST])
151
 
152
  return HealthResponse(
153
  status="healthy" if is_healthy else "unhealthy",
154
- model_loaded=bool(scripted_model),
155
- device=str(device) if device else "unknown",
156
- model_metadata_keys=list(model_metadata.keys()) if model_metadata else None,
157
  )
158
 
159
- def text_to_tensor(text: str, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
160
  """Convert text to tensor format expected by the model"""
161
- global alphabet_map, max_text_len
162
- if alphabet_map is None:
163
  raise ValueError("Alphabet map not initialized during api startup")
164
- if max_text_len is None:
165
- raise ValueError("`max_text_len` is not initialized during api startup")
166
  padded_encoded_np, true_length = encode_text(
167
  text=text,
168
- char_to_index_map=alphabet_map,
169
- max_length=max_text_len
 
170
  )
171
 
172
- char_seq = torch.from_numpy(padded_encoded_np).to(device=device, dtype=torch.long)
173
- char_len = torch.tensor([true_length], device=device, dtype=torch.long)
174
 
175
  return char_seq, char_len
176
 
@@ -179,20 +178,38 @@ def generate_strokes(
179
  char_lengths: torch.Tensor,
180
  max_gen_len: int,
181
  api_bias: float,
182
- current_device: torch.device
183
  ) -> list[list[float]]:
184
  """Generate strokes using the model's built-in sample method"""
185
- global scripted_model
186
- if scripted_model is None:
187
  raise ValueError("Scripted model not initialized.")
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  with torch.no_grad():
190
  try:
191
- stroke_tensors = scripted_model.sample(
192
  char_seq,
193
  char_lengths,
194
  max_length=max_gen_len,
195
- bias=api_bias
 
196
  )
197
 
198
  if len(stroke_tensors) == 1 and stroke_tensors[0].dim() == 2:
@@ -217,21 +234,24 @@ def generate_strokes(
217
 
218
  @app.post("/generate", response_model=HandwritingResponse, tags=["Generation"])
219
  async def generate_handwriting_endpoint(request: HandwritingRequest):
220
- if not all([scripted_model, model_metadata, device, alphabet_map, max_text_len]):
221
  logger.error("API not fully initialized. Check /health endpoint.")
222
  raise HTTPException(
223
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
224
  detail="Model or required resources not loaded."
225
  )
226
 
227
- assert device is not None, "Device is None inside generate_handwriting"
228
  start_time = time.time()
229
 
230
  try:
231
- char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, device)
232
 
233
  relative_stroke_offsets = generate_strokes(
234
- char_seq_tensor, char_lengths_tensor, request.max_length, request.bias, device
 
 
 
235
  )
236
 
237
  if not relative_stroke_offsets:
 
8
  import logging
9
  import time
10
  from contextlib import asynccontextmanager
11
+ from inference_utils import PrimingData, construct_alphabet_list, convert_offsets_to_absolute_coords, encode_text, get_alphabet_map, load_priming_data
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
 
18
  SCRIPTED_MODEL_NAME = "model.scripted.pt"
19
  METADATA_MODEL_NAME = "model.pt"
20
 
21
+ SCRIPTED_MODEL: Optional[torch.jit.ScriptModule] = None
22
+ MODEL_METADATA: Optional[dict] = None
23
+ DEVICE: Optional[torch.device] = None
24
+ ALPHABET_MAP: Optional[dict[str, int]] = None
25
  ALPHABET_LIST: Optional[list[str]] = None
26
  ALPHABET_SIZE: Optional[int] = None
27
+ MAX_TEXT_LEN: Optional[int] = None
28
  output_mixture_components: Optional[int] = None # To store num_mixtures for GMM sampling
29
  lstm_size: Optional[int] = None
30
  attention_mixture_components: Optional[int] = None
 
55
  @asynccontextmanager
56
  async def lifespan(app: FastAPI):
57
  """Lifespan context manager for startup and shutdown events"""
58
+ global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST, output_mixture_components, lstm_size, attention_mixture_components, ALPHABET_SIZE
59
  logger.info("Attempting to load model resources during startup")
60
  try:
61
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ logger.info(f"Using device: {DEVICE}")
63
 
64
  scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
65
  metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
66
+ if DEVICE.type == "cpu":
67
  scripted_model_path = MODEL_DIR / QUANTIZED_MODEL_NAME
68
 
69
  if not scripted_model_path.exists():
 
74
  raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")
75
 
76
  # Load the traced model
77
+ SCRIPTED_MODEL = torch.jit.load(scripted_model_path, map_location=DEVICE)
78
+ if SCRIPTED_MODEL:
79
+ SCRIPTED_MODEL.eval()
80
  logger.info(f"Traced model loaded successfully from {scripted_model_path}")
81
 
82
  # Load the metadata
83
+ MODEL_METADATA = torch.load(metadata_model_path, map_location='cpu')
84
+ if MODEL_METADATA:
85
  logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
86
+ logger.info(f"Model metadata keys: {list(MODEL_METADATA.keys())}")
87
 
88
+ config_full = MODEL_METADATA['config_full']
89
  if not config_full or not isinstance(config_full, dict):
90
  raise ValueError(f"Key `config_full` not found or not a dict")
91
 
 
95
  if not dataset_config or not isinstance(dataset_config, dict):
96
  raise ValueError(f"Key `dataset` not found or not a dict in config_full")
97
  alphabet_str = dataset_config['alphabet_string']
98
+ MAX_TEXT_LEN = dataset_config['max_text_len']
99
  output_mixture_components = model_params['output_mixture_components']
100
 
101
  lstm_size = model_params['lstm_size']
 
103
 
104
  ALPHABET_LIST = construct_alphabet_list(alphabet_str)
105
  ALPHABET_SIZE = len(ALPHABET_LIST)
106
+ ALPHABET_MAP = get_alphabet_map(ALPHABET_LIST)
107
 
108
  logger.info(f"Alphabet created. Size: {len(ALPHABET_LIST)}")
109
  logger.info("Model resources are loaded and ready")
 
112
 
113
  except Exception as e:
114
  logger.error(f"Error loading model resources: {e}", exc_info=True)
115
+ SCRIPTED_MODEL = None
116
+ MODEL_METADATA = None
117
  raise
118
 
119
  yield
120
 
121
  # Cleanup on shutdown
122
  logger.info("Shutting down API and cleaning up resources")
123
+ SCRIPTED_MODEL = None
124
+ MODEL_METADATA = None
125
 
126
  app = FastAPI(
127
  title="Scriptify API",
 
145
 
146
  @app.get("/health", response_model=HealthResponse, tags=["General"])
147
  async def health_check():
148
+ global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST
149
 
150
+ is_healthy = all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST])
151
 
152
  return HealthResponse(
153
  status="healthy" if is_healthy else "unhealthy",
154
+ model_loaded=bool(SCRIPTED_MODEL),
155
+ device=str(DEVICE) if DEVICE else "unknown",
156
+ model_metadata_keys=list(MODEL_METADATA.keys()) if MODEL_METADATA else None,
157
  )
158
 
159
+ def text_to_tensor(text: str, max_text_length: int, add_eos: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
160
  """Convert text to tensor format expected by the model"""
161
+ if ALPHABET_MAP is None:
 
162
  raise ValueError("Alphabet map not initialized during api startup")
163
+
 
164
  padded_encoded_np, true_length = encode_text(
165
  text=text,
166
+ char_to_index_map=ALPHABET_MAP,
167
+ max_length=max_text_length,
168
+ add_eos = add_eos
169
  )
170
 
171
+ char_seq = torch.from_numpy(padded_encoded_np).to(device=DEVICE, dtype=torch.long)
172
+ char_len = torch.tensor([true_length], device=DEVICE, dtype=torch.long)
173
 
174
  return char_seq, char_len
175
 
 
178
  char_lengths: torch.Tensor,
179
  max_gen_len: int,
180
  api_bias: float,
181
+ style: Optional[int] = None
182
  ) -> list[list[float]]:
183
  """Generate strokes using the model's built-in sample method"""
184
+ global SCRIPTED_MODEL
185
+ if SCRIPTED_MODEL is None:
186
  raise ValueError("Scripted model not initialized.")
187
 
188
+ primingData = None
189
+
190
+ if style is not None:
191
+ priming_text, priming_strokes = load_priming_data(style)
192
+
193
+ priming_text_tensor, priming_text_len_tensor = text_to_tensor(
194
+ priming_text, max_text_length=len(priming_text), add_eos=False)
195
+
196
+ priming_stroke_tensor = torch.tensor(priming_strokes,
197
+ dtype=torch.float32,
198
+ device=DEVICE).unsqueeze(dim=0)
199
+
200
+ primingData = PrimingData(priming_stroke_tensor,
201
+ char_seq_tensors=priming_text_tensor,
202
+ char_seq_lengths=priming_text_len_tensor)
203
+
204
+
205
  with torch.no_grad():
206
  try:
207
+ stroke_tensors = SCRIPTED_MODEL.sample(
208
  char_seq,
209
  char_lengths,
210
  max_length=max_gen_len,
211
+ bias=api_bias,
212
+ prime=primingData
213
  )
214
 
215
  if len(stroke_tensors) == 1 and stroke_tensors[0].dim() == 2:
 
234
 
235
  @app.post("/generate", response_model=HandwritingResponse, tags=["Generation"])
236
  async def generate_handwriting_endpoint(request: HandwritingRequest):
237
+ if not all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN]):
238
  logger.error("API not fully initialized. Check /health endpoint.")
239
  raise HTTPException(
240
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
241
  detail="Model or required resources not loaded."
242
  )
243
 
244
+ assert DEVICE is not None, "Device is None inside generate_handwriting"
245
  start_time = time.time()
246
 
247
  try:
248
+ char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, max_text_length=MAX_TEXT_LEN) # type: ignore
249
 
250
  relative_stroke_offsets = generate_strokes(
251
+ char_seq_tensor, char_lengths_tensor,
252
+ request.max_length,
253
+ request.bias,
254
+ style=1 #TODO: style is hardcode since the current version is hosted on cpu
255
  )
256
 
257
  if not relative_stroke_offsets: