henok3878 committed on
Commit
a9beef1
·
1 Parent(s): 7a9620f

feat: init commit for hugging face space

Browse files
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,4 +34,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- packaged_models/*.pt filter=lfs diff=lfs merge=lfs -text
 
1
+ packaged_models/*.pt filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
 
34
  *.zip filter=lfs diff=lfs merge=lfs -text
35
  *.zst filter=lfs diff=lfs merge=lfs -text
36
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -1,12 +1,20 @@
1
  ---
2
  title: Scriptify Api
3
- emoji: 🏆
4
- colorFrom: blue
5
  colorTo: green
6
- sdk: docker
 
 
7
  pinned: false
8
  license: mit
9
  short_description: An API for generating realistic handwriting stroke points.
10
  ---
11
 
 
 
 
 
 
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Scriptify Api
3
+ emoji: ✍️
4
+ colorFrom: indigo
5
  colorTo: green
6
+ sdk: python
7
+ app_file: main.py
8
+ python_version: 3.9
9
  pinned: false
10
  license: mit
11
  short_description: An API for generating realistic handwriting stroke points.
12
  ---
13
 
14
+
15
+ # Scriptify Handwriting Generation API
16
+
17
+ This Space hosts an API for generating handwriting from text.
18
+ Use the `/generate` endpoint with a POST request.
19
+
20
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
inference_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+
4
# Character placed at index 0 of every alphabet list built below.
NULL_CHAR = '\x00'


def construct_alphabet_list(alphabet_string: str) -> list[str]:
    """Return the characters of ``alphabet_string`` preceded by ``NULL_CHAR``.

    Raises:
        TypeError: if ``alphabet_string`` is not a ``str``.
    """
    if isinstance(alphabet_string, str):
        return [NULL_CHAR, *alphabet_string]
    raise TypeError("alphabet_string must be a string")
12
+
13
def get_alphabet_map(alphabet_list: list[str]) -> Dict[str, int]:
    """Build a character -> index lookup from the full alphabet list.

    When a character occurs more than once, its last position is kept.
    """
    index_by_char: Dict[str, int] = {}
    for position, symbol in enumerate(alphabet_list):
        index_by_char[symbol] = position
    return index_by_char
16
+
17
def encode_text(text: str, char_to_index_map: Dict[str, int],
                max_length: int, add_eos: bool = True, eos_char_index: int = 0
                ) -> tuple[np.ndarray, int]:
    """Encode a text string into a fixed-length sequence of integer indices.

    Characters missing from ``char_to_index_map`` are encoded as
    ``eos_char_index``. The sequence is right-padded with ``eos_char_index``
    up to ``max_length``, or truncated to ``max_length`` when it is longer.

    Args:
        text: The string to encode.
        char_to_index_map: Lookup from character to integer index.
        max_length: Length of the returned sequence; must be >= 0.
        add_eos: Append ``eos_char_index`` after the encoded characters.
        eos_char_index: Index used for EOS, padding and unknown characters.

    Returns:
        ``(encoded, true_length)``: ``encoded`` is an ``int64`` array of shape
        ``(1, max_length)``; ``true_length`` is the number of non-padding
        entries, capped at ``max_length``.

    Raises:
        ValueError: if ``max_length`` is negative.
    """
    # A negative length would slice from the wrong end and yield a negative
    # true_length, so reject it explicitly.
    if max_length < 0:
        raise ValueError("max_length must be non-negative")

    encoded = [char_to_index_map.get(c, eos_char_index) for c in text]
    if add_eos:
        encoded.append(eos_char_index)

    was_truncated = len(encoded) > max_length
    true_length = min(len(encoded), max_length)

    padded_encoded = np.full(max_length, eos_char_index, dtype=np.int64)
    padded_encoded[:true_length] = encoded[:true_length]

    # Truncation cuts off the EOS appended above; restore it in the last slot
    # so a sequence encoded with add_eos=True always ends with the terminator.
    if add_eos and was_truncated and max_length > 0:
        padded_encoded[-1] = eos_char_index

    return np.array([padded_encoded]), true_length
35
+
36
+
37
def convert_offsets_to_absolute_coords(stroke_offsets: list[list[float]]) -> list[list[float]]:
    """Convert per-point offsets into absolute coordinates.

    Each row is expected to start with ``(dx, dy)``. The first two columns
    are replaced by their running sums down the rows; any further columns
    are passed through unchanged. The input is never modified.

    Args:
        stroke_offsets: A 2-D list (or array-like) of numeric rows.

    Returns:
        A new list of rows of Python floats, or ``[]`` for empty input.

    Raises:
        ValueError: if the input is not 2-D with at least two columns.
    """
    if stroke_offsets is None:
        return []

    # Force a float copy: guarantees float output for integer input and keeps
    # the caller's data untouched even when an ndarray is passed in.
    strokes_array = np.array(stroke_offsets, dtype=float)
    if strokes_array.size == 0:
        return []
    if strokes_array.ndim != 2 or strokes_array.shape[1] < 2:
        raise ValueError(
            "stroke_offsets must be a 2-D sequence with at least two columns (dx, dy)"
        )

    # One cumulative sum down the rows covers both dx and dy.
    strokes_array[:, :2] = np.cumsum(strokes_array[:, :2], axis=0)

    return strokes_array.tolist()
main.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from fastapi import FastAPI, HTTPException, status
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel, Field
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from pathlib import Path
8
+ import logging
9
+ import time
10
+ from contextlib import asynccontextmanager
11
+ from inference_utils import construct_alphabet_list, convert_offsets_to_absolute_coords, encode_text, get_alphabet_map
12
+
13
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Location of the packaged model artifacts. Resolved relative to this file
# (not the process working directory) so the app starts from any cwd. The
# directory and file names match the artifacts committed with this app
# (packaged_models/model.scripted.pt and packaged_models/model.pt).
MODEL_DIR = Path(__file__).resolve().parent / "packaged_models"
SCRIPTED_MODEL_NAME = "model.scripted.pt"
METADATA_MODEL_NAME = "model.pt"

# Module-level state populated by `lifespan` at startup; every entry stays
# None until loading has succeeded.
scripted_model: Optional[torch.jit.ScriptModule] = None
model_metadata: Optional[dict] = None
device: Optional[torch.device] = None
alphabet_map: Optional[dict[str, int]] = None
ALPHABET_LIST: Optional[list[str]] = None
ALPHABET_SIZE: Optional[int] = None
max_text_len: Optional[int] = None
output_mixture_components: Optional[int] = None  # To store num_mixtures for GMM sampling
lstm_size: Optional[int] = None
attention_mixture_components: Optional[int] = None

# Patience for early stopping in generate_strokes
PATIENCE_PEN_UP_EOS = 15
MIN_MOVEMENT_THRESHOLD = 0.02
34
+
35
+
36
# Request body for the POST /generate endpoint.
class HandwritingRequest(BaseModel):
    # Text to render; validated to be between 1 and 40 characters long.
    text: str = Field(..., min_length=1, max_length=40, description="Text to generate handwriting for")
    # Upper bound on generated stroke points (50-1500, default 700); forwarded
    # to the model's sample call as its `max_length` argument.
    max_length: int = Field(default=700, ge=50, le=1500, description="Maximum number of stroke points")
    # Sampling bias (0.1-2.0, default 0.75); forwarded to the model's sample call.
    bias: float = Field(default=0.75, ge=0.1, le=2.0, description="Sampling bias for generation")
40
# Response body returned by the POST /generate endpoint.
class HandwritingResponse(BaseModel):
    # False only when the endpoint reports that no strokes were generated.
    success: bool = True
    # Echo of the text supplied in the request.
    input_text: str
    # Elapsed time for encoding, sampling and coordinate conversion, in milliseconds.
    generation_time_ms: float
    # Number of rows in `strokes`.
    num_points: int
    # Stroke points after offsets have been accumulated into absolute coordinates.
    strokes: list[list[float]]
    # Human-readable status message.
    message: str = "Successfully generated handwriting."
47
+
48
# Response body returned by the GET /health endpoint.
class HealthResponse(BaseModel):
    # "healthy" when every startup resource is populated, otherwise "unhealthy".
    status: str
    # Whether the scripted model has been loaded.
    model_loaded: bool
    # String form of the selected torch device, or "unknown" before startup.
    device: str
    # Top-level keys of the loaded metadata, or None when it is not loaded.
    model_metadata_keys: Optional[list[str]] = None
53
+
54
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model resources at startup and release them at shutdown.

    Startup selects the torch device, loads the scripted model and the
    metadata file from ``MODEL_DIR``, and populates the module-level globals
    derived from the metadata's ``config_full`` entry (alphabet, maximum text
    length and model size parameters). Any startup failure is logged, the
    model and metadata globals are reset to ``None`` and the exception is
    re-raised.

    Raises:
        FileNotFoundError: if either model file is missing from ``MODEL_DIR``.
        ValueError: if the metadata file is empty, or ``config_full``,
            ``dataset`` or ``model_params`` is missing or not a dict.
    """
    global scripted_model, model_metadata, device, alphabet_map, max_text_len
    global ALPHABET_LIST, ALPHABET_SIZE
    global output_mixture_components, lstm_size, attention_mixture_components

    logger.info("Attempting to load model resources during startup")
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
        metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME

        if not scripted_model_path.exists():
            logger.error(f"Traced model not found at {scripted_model_path}")
            raise FileNotFoundError(f"Traced model not found at {scripted_model_path}")
        if not metadata_model_path.exists():
            logger.error(f"Metadata model file not found at {metadata_model_path}")
            raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")

        # Load the traced model and switch it to inference mode.
        scripted_model = torch.jit.load(scripted_model_path, map_location=device)
        scripted_model.eval()
        logger.info(f"Traced model loaded successfully from {scripted_model_path}")

        # Load the metadata.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model_metadata = torch.load(metadata_model_path, map_location='cpu')
        if not model_metadata:
            raise ValueError("Failed to load content from metadata file")
        logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
        logger.info(f"Model metadata keys: {list(model_metadata.keys())}")

        # Use .get() so a missing key reaches the descriptive ValueError
        # below instead of surfacing as a bare KeyError.
        config_full = model_metadata.get('config_full')
        if not config_full or not isinstance(config_full, dict):
            raise ValueError("Key `config_full` not found or not a dict")

        dataset_config = config_full.get('dataset')
        if not dataset_config or not isinstance(dataset_config, dict):
            raise ValueError("Key `dataset` not found or not a dict in config_full")

        model_params = config_full.get('model_params')
        if not model_params or not isinstance(model_params, dict):
            raise ValueError("Key `model_params` not found or not a dict in config_full")

        alphabet_str = dataset_config['alphabet_string']
        max_text_len = dataset_config['max_text_len']
        output_mixture_components = model_params['output_mixture_components']

        lstm_size = model_params['lstm_size']
        attention_mixture_components = model_params['attention_mixture_components']

        ALPHABET_LIST = construct_alphabet_list(alphabet_str)
        ALPHABET_SIZE = len(ALPHABET_LIST)
        alphabet_map = get_alphabet_map(ALPHABET_LIST)

        logger.info(f"Alphabet created. Size: {len(ALPHABET_LIST)}")
        logger.info("Model resources are loaded and ready")

    except Exception as e:
        logger.error(f"Error loading model resources: {e}", exc_info=True)
        scripted_model = None
        model_metadata = None
        raise

    try:
        yield
    finally:
        # Cleanup on shutdown; `finally` ensures it also runs when the
        # application exits through an exception.
        logger.info("Shutting down API and cleaning up resources")
        scripted_model = None
        model_metadata = None
123
# Application instance; `lifespan` handles startup loading and shutdown cleanup.
app = FastAPI(
    title="Scriptify API",
    description="API to generate handwriting from text using a PyTorch model.",
    version="0.1.0",
    lifespan=lifespan
)

# add CORS middleware
# NOTE(review): only local origins on port 5173 are allowed here; confirm
# whether a deployed browser client needs its own origin added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173","http://127.0.0.1:5173"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
138
+
139
@app.get("/", tags=["General"])
async def read_root():
    # Landing route: responds with a static welcome payload.
    welcome_payload = {"message": "Welcome to the Scriptify Handwriting Generation API!"}
    return welcome_payload
142
+
143
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    # Reports whether every resource populated at startup is available.
    # The module-level state is only read here, so no `global` is needed.
    required_resources = (
        scripted_model,
        model_metadata,
        device,
        alphabet_map,
        max_text_len,
        ALPHABET_LIST,
    )

    status_label = "unhealthy"
    if all(required_resources):
        status_label = "healthy"

    metadata_keys = None
    if model_metadata:
        metadata_keys = list(model_metadata.keys())

    return HealthResponse(
        status=status_label,
        model_loaded=bool(scripted_model),
        device="unknown" if not device else str(device),
        model_metadata_keys=metadata_keys,
    )
155
+
156
def text_to_tensor(text: str, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode ``text`` and wrap the result in the tensors passed to the model.

    Returns:
        ``(char_seq, char_len)``: the encoded characters and a one-element
        tensor holding the encoded length, both ``torch.long`` on ``device``.

    Raises:
        ValueError: if the alphabet map or ``max_text_len`` has not been
            initialised at startup.
    """
    char_index_lookup = alphabet_map
    length_limit = max_text_len

    if char_index_lookup is None:
        raise ValueError("Alphabet map not initialized during api startup")
    if length_limit is None:
        raise ValueError("`max_text_len` is not initialized during api startup")

    encoded_np, encoded_length = encode_text(
        text=text,
        char_to_index_map=char_index_lookup,
        max_length=length_limit,
    )

    placement = {"device": device, "dtype": torch.long}
    char_seq = torch.from_numpy(encoded_np).to(**placement)
    char_len = torch.tensor([encoded_length], **placement)
    return char_seq, char_len
173
+
174
def generate_strokes(
    char_seq: torch.Tensor,
    char_lengths: torch.Tensor,
    max_gen_len: int,
    api_bias: float,
    current_device: torch.device
) -> list[list[float]]:
    """Sample stroke offsets using the scripted model's ``sample`` method.

    The sampler output is normalised into a list of points. Any exception
    raised while sampling or converting is logged and produces an empty
    list. ``current_device`` is accepted but not used here.

    Raises:
        ValueError: if the scripted model has not been loaded.
    """
    model = scripted_model
    if model is None:
        raise ValueError("Scripted model not initialized.")

    with torch.no_grad():
        try:
            sampled = model.sample(
                char_seq,
                char_lengths,
                max_length=max_gen_len,
                bias=api_bias
            )

            # A single 2-D entry already holds one point per row.
            if len(sampled) == 1 and sampled[0].dim() == 2:
                return sampled[0].cpu().numpy().tolist()

            # Otherwise treat each entry as one point and keep only those
            # that flatten to exactly three values.
            points = []
            for entry in sampled:
                flattened = entry.squeeze(0) if entry.dim() == 2 else entry
                values = flattened.cpu().numpy().tolist()
                if len(values) == 3:
                    points.append(values)
            return points

        except Exception as e:
            logger.error(f"Error in model sampling: {e}", exc_info=True)
            return []
214
+
215
@app.post("/generate", response_model=HandwritingResponse, tags=["Generation"])
async def generate_handwriting_endpoint(request: HandwritingRequest):
    # Generates handwriting strokes for `request.text`.
    #
    # Responses:
    #   200 - strokes generated, or success=False when the sampler produced
    #         no strokes
    #   400 - a ValueError was raised while encoding or generating
    #   500 - any other unexpected error
    #   503 - model resources have not been loaded
    if not all([scripted_model, model_metadata, device, alphabet_map, max_text_len]):
        logger.error("API not fully initialized. Check /health endpoint.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model or required resources not loaded."
        )

    assert device is not None, "Device is None inside generate_handwriting"
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # (or go negative) if the system clock is adjusted mid-request.
    start_time = time.perf_counter()

    try:
        char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, device)

        relative_stroke_offsets = generate_strokes(
            char_seq_tensor, char_lengths_tensor, request.max_length, request.bias, device
        )

        if not relative_stroke_offsets:
            return HandwritingResponse(
                success=False,
                input_text=request.text,
                strokes=[],
                num_points=0,
                generation_time_ms=(time.perf_counter() - start_time) * 1000,
                message="No strokes generated."
            )

        absolute_stroke_coords = convert_offsets_to_absolute_coords(relative_stroke_offsets)
        generation_time_ms = (time.perf_counter() - start_time) * 1000

        return HandwritingResponse(
            input_text=request.text,
            strokes=absolute_stroke_coords,
            num_points=len(absolute_stroke_coords),
            generation_time_ms=generation_time_ms
        )
    except ValueError as ve:
        logger.error(f"ValueError during generation for '{request.text}': {ve}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
    except Exception as e:
        logger.error(f"Unexpected error for '{request.text}': {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
259
+
260
if __name__ == "__main__":
    import uvicorn

    # Direct-execution entry point: serve the app on all interfaces, port
    # 8000, with auto-reload enabled.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "app_dir": ".",
    }
    logger.info("Starting Uvicorn server for Scriptify API...")
    uvicorn.run("main:app", **server_options)
packaged_models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9430eccb030d1ad0458ea6bb19696346ad5b3998e658b78acdfd1f19779498a
3
+ size 17601066
packaged_models/model.scripted.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5deb88801c26ab924d0079d9e5522fd55114bd8429c180c7646dd7fbc0049f3e
3
+ size 17632110
packaged_models/model.scripted.quantized.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:793a525a5a8d4f62cc80ddbf0f0ca0fddc13ec202ef2fc6efd9bfaa32c78e306
3
+ size 17674936
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.7.0
2
+ anyio==4.9.0
3
+ click==8.1.8
4
+ exceptiongroup==1.3.0
5
+ fastapi==0.115.12
6
+ filelock==3.13.1
7
+ fsspec==2024.6.1
8
+ h11==0.16.0
9
+ httptools==0.6.4
10
+ idna==3.10
11
+ Jinja2==3.1.4
12
+ MarkupSafe==2.1.5
13
+ mpmath==1.3.0
14
+ networkx==3.2.1
15
+ numpy==2.0.2
16
+ pydantic==2.11.5
17
+ pydantic-settings==2.9.1
18
+ pydantic_core==2.33.2
19
+ python-dotenv==1.1.0
20
+ PyYAML==6.0.2
21
+ sniffio==1.3.1
22
+ starlette==0.46.2
23
+ sympy==1.13.1
24
+ torch==2.5.1+cpu
25
+ typing-inspection==0.4.1
26
+ typing_extensions==4.13.2
27
+ uvicorn==0.34.2
28
+ uvloop==0.21.0
29
+ watchfiles==1.0.5
30
+ websockets==15.0.1