fisherman611 committed on
Commit c70f97e · verified · 1 Parent(s): afdd9da

Upload 5 files

Files changed (5)
  1. .gitignore +177 -0
  2. README.md +98 -13
  3. app.py +771 -0
  4. config.json +29 -0
  5. requirements.txt +0 -0
.gitignore ADDED
@@ -0,0 +1,177 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ /data
+ /checkpoints
README.md CHANGED
@@ -1,13 +1,98 @@
- ---
- title: Handwritten Mathematical Expression Recognition
- emoji: 🐢
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 5.38.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # **Handwritten Mathematical Expression Recognition**
+
+ ## **Project Overview**
+ This project recognizes handwritten mathematical expressions and converts them into LaTeX. The system uses deep learning to process images of handwritten equations, interpret their structure, and generate the corresponding LaTeX code. The primary goal is high accuracy on complex expressions, which requires coping with varying handwriting styles and intricate symbol arrangements. The project is built with PyTorch and uses neural network architectures tailored to this task.
+
+ ## **Dataset**
+
+ The project uses the **CROHME (Competition on Recognition of Online Handwritten Mathematical Expressions)** dataset, a widely used benchmark for handwritten mathematical expression recognition. The dataset is organized into several subsets, each containing images and their corresponding LaTeX annotations.
+
+ Download the pre-split dataset ([CROHME Splitted](https://husteduvn-my.sharepoint.com/:f:/g/personal/thanh_lh225458_sis_hust_edu_vn/EviH0ckuHR9KiXftU5ETkPQBHvEL77YTscIHvfN7LBSrSg?e=CHwNxv)) and place it in the `data/` directory.
+
+ ## **Methods and Models**
+
+ ### **Preprocessing**
+ Steps to clean and standardize images (a sketch of the final resize-and-pad step follows this list):
+ * Load in grayscale.
+ * Apply Canny edge detection and dilate with a $7 \times 13$ kernel to connect edges.
+ * Crop with an F1-score-based method to focus on the expression.
+ * Binarize with adaptive thresholding; invert so the background is black if needed.
+ * Apply median blur (kernel size 3) several times to reduce noise.
+ * Add 5-pixel padding, resize to $128 \times 384$, and pad with black if needed.
+
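The exact implementation lives in `process_img` in `app.py`; the following is only a minimal sketch of the binarize/resize/pad steps above (the function name is illustrative and the target size is taken from this README):

```python
import cv2
import numpy as np

H_TARGET, W_TARGET = 128, 384  # target size listed above

def resize_and_pad(gray: np.ndarray) -> np.ndarray:
    """Binarize to a black background, resize to the target height, center-pad the width."""
    binimg = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY, 11, 2)
    if np.sum(binimg == 255) > np.sum(binimg == 0):   # background came out white
        binimg = 255 - binimg
    binimg = cv2.medianBlur(binimg, 3)                # light denoising
    binimg = cv2.copyMakeBorder(binimg, 5, 5, 5, 5,   # 5-pixel black padding
                                cv2.BORDER_CONSTANT, value=0)
    h, w = binimg.shape
    new_w = min(int(H_TARGET / h * w), W_TARGET)      # keep aspect ratio
    resized = cv2.resize(binimg, (new_w, H_TARGET), interpolation=cv2.INTER_AREA)
    canvas = np.zeros((H_TARGET, W_TARGET), dtype=np.uint8)
    x0 = (W_TARGET - new_w) // 2
    canvas[:, x0:x0 + new_w] = resized                # center horizontally on black
    return canvas
```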
+ ### **Augmentation**
+ Augmentation to handle handwriting variation (a sketch follows this list):
+ * Rotate by up to 5 degrees with border replication.
+ * Elastic transform for stroke variation.
+ * Random morphology: erode or dilate to vary stroke thickness.
+ * Normalize and convert to a tensor.
+
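The training-time augmentation code is not part of this commit; a minimal albumentations sketch of the steps above could look like the following (the kernel size and probabilities are illustrative):

```python
import random

import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2

def random_morphology(image, **kwargs):
    """Randomly erode or dilate to vary stroke thickness (3x3 kernel assumed)."""
    kernel = np.ones((3, 3), np.uint8)
    if random.random() < 0.5:
        return cv2.erode(image, kernel, iterations=1)
    return cv2.dilate(image, kernel, iterations=1)

train_transform = A.Compose([
    A.Rotate(limit=5, border_mode=cv2.BORDER_REPLICATE, p=0.5),  # up to 5 degrees
    A.ElasticTransform(p=0.3),                                   # stroke wobble
    A.Lambda(image=random_morphology, p=0.3),
    A.Normalize(mean=[0.0], std=[1.0]),                          # same normalization as app.py
    ToTensorV2(),
])
```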
+ ### **Model: Counting-Aware Network (CAN)**
+
+ CAN is an end-to-end model for HMER that combines recognition with symbol counting:
+ * **Backbone:**
+
+   * DenseNet (or ResNet)
+   * Takes a grayscale image $H' \times W' \times 1$ and outputs a feature map $\mathcal{F} \in \mathbb{R}^{H \times W \times 684}$, where $H = \frac{H'}{16}$ and $W = \frac{W'}{16}$.
+
+ * **Multi-Scale Counting Module (MSCM):**
+
+   * Uses $3 \times 3$ and $5 \times 5$ conv branches for multi-scale features.
+   * Channel attention: $$\mathcal{Q} = \sigma(W_1(G(\mathcal{H})) + b_1)$$ $$\mathcal{S} = \mathcal{Q} \otimes g(W_2 \mathcal{Q} + b_2)$$
+   * Concatenates the features and applies a $1 \times 1$ conv to obtain the counting map $$\mathcal{M} \in \mathbb{R}^{H \times W \times C}$$
+   * Sum-pooling gives the counting vector $$\mathcal{V}_i = \sum_{p=1}^H \sum_{q=1}^W \mathcal{M}_{i,pq}$$
+
+ * **Counting-Combined Attentional Decoder (CCAD):**
+
+   * A $1 \times 1$ conv on $\mathcal{F}$ gives $\mathcal{T} \in \mathbb{R}^{H \times W \times 512}$, to which positional encoding is added.
+   * A GRU produces the hidden state $h_t \in \mathbb{R}^{1 \times 256}$; attention weights: $$e_{t,ij} = w^T \tanh(\mathcal{T} + \mathcal{P} + W_a \mathcal{A} + W_h h_t) + b$$ $$\alpha_{t,ij} = \frac{\exp(e_{t,ij})}{\sum_{p=1}^H \sum_{q=1}^W \exp(e_{t,pq})}$$
+   * The context vector $\mathcal{C} \in \mathbb{R}^{1 \times 256}$ is used to predict the next token: $$p(y_t) = \text{softmax}(w_o^T (W_c \mathcal{C} + W_v \mathcal{V}^f + W_t h_t + W_e E(y_{t-1})) + b_o)$$
+   * Beam search (width = 5) is used for inference.
+
+ * **Loss** (a PyTorch sketch follows this list):
+
+   * Classification loss: $$\mathcal{L}_{\text{cls}} = -\frac{1}{T} \sum_{t=1}^T \log(p(y_t))$$
+   * Counting loss: $$\mathcal{L}_{\text{counting}} = \text{smooth}_{L_1}(\mathcal{V}^f, \hat{\mathcal{V}})$$
+   * Total loss: $$\mathcal{L} = \mathcal{L}_{\text{cls}} + \lambda \mathcal{L}_{\text{counting}}, \quad \lambda = 0.01$$
+
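The training loop is not included in this commit, but the total loss above translates directly into PyTorch; a minimal sketch with assumed tensor shapes ($\lambda$ corresponds to `lambda_count` in `config.json`):

```python
import torch.nn.functional as F

def can_loss(token_logits, target_tokens, pred_counts, true_counts,
             pad_idx, lambda_count=0.01):
    """Token cross-entropy plus lambda * Smooth-L1 counting loss.

    token_logits:  (B, T, num_classes) decoder outputs
    target_tokens: (B, T) ground-truth token indices
    pred_counts:   (B, C) predicted counting vector V^f
    true_counts:   (B, C) ground-truth symbol counts
    """
    cls_loss = F.cross_entropy(token_logits.flatten(0, 1),
                               target_tokens.flatten(),
                               ignore_index=pad_idx)       # averages -log p(y_t)
    count_loss = F.smooth_l1_loss(pred_counts, true_counts)
    return cls_loss + lambda_count * count_loss
```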
+ ## **Results**
+ | Model | ExpRate | ExpRate-Leq1 | ExpRate-Leq2 | ExpRate-Leq3 |
+ |-------------|--------|---------|---------|---------|
+ | Customized DenseNet-CAN | 0.4248 | 0.6385 | 0.7313 | 0.8036 |
+ | Customized ResNet-CAN | 0.4511 | 0.6459 | 0.7288 | 0.7888 |
+ | Pretrained ResNet-CAN | 0.4240 | 0.6220 | 0.7214 | 0.7888 |
+ | Pretrained DenseNet-CAN | **0.5316** | **0.7149** | **0.8069** | **0.8521** |
+
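ExpRate is the share of test expressions whose predicted token sequence matches the ground truth exactly; ExpRate-Leq$k$ tolerates up to $k$ errors. A minimal sketch of such a metric via token-level edit distance (this is not the project's evaluation code, and the tolerance definition is an assumption):

```python
def edit_distance(pred, truth):
    """Token-level Levenshtein distance between two token lists."""
    m, n = len(pred), len(truth)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                              # deletion
                        dp[j - 1] + 1,                          # insertion
                        prev + (pred[i - 1] != truth[j - 1]))   # substitution
            prev = cur
    return dp[n]

def exp_rate(predictions, references, max_errors=0):
    """Fraction of expressions whose edit distance to the reference is <= max_errors."""
    hits = sum(edit_distance(p, r) <= max_errors
               for p, r in zip(predictions, references))
    return hits / len(references)
```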
+ ## **Conclusion**
+ CAN works well for handwritten mathematical expression recognition on the CROHME dataset, handling complex expressions through its counting module and attention mechanism. Future directions: try transformer decoders, add synthetic training data, and improve preprocessing for noisy images.
+
+ ## **Installation**
+ Clone the repository:
+ ```bash
+ git clone https://github.com/fisherman611/handwritten-mathematical-expression-recognition.git
+ ```
+
+ Navigate to the project directory:
+ ```bash
+ cd handwritten-mathematical-expression-recognition
+ ```
+
+ Install the required dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## **Download the pretrained model**
+ Download the pretrained model checkpoints from this [OneDrive link](https://husteduvn-my.sharepoint.com/:f:/g/personal/thanh_lh225458_sis_hust_edu_vn/EvWQqIjJQtNKuQwwH1G8EMkBcRPM8s3msiI7-IBERbve1A?e=6SeGHB).
+
+ Place the downloaded checkpoints in the `checkpoints/` directory within the repository.
+
+ ## **Inference**
+ Run `app.py` to launch the demo and run inference:
+ ```bash
+ python app.py
+ ```
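The helper functions defined in `app.py` can also be called without the Gradio UI; a minimal sketch (the image path is illustrative, and importing `app` loads the checkpoint selected by `config.json`):

```python
import cv2
from app import gradio_process_img, recognize_image  # loads the model on import

gray = cv2.imread("example.png", cv2.IMREAD_GRAYSCALE)  # illustrative input image
image_tensor, processed_img, crop = gradio_process_img(gray)
latex, rendered = recognize_image(image_tensor, processed_img, crop)
print(latex)
```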
+ ## **References**
+ [1] B. Li, Y. Yuan, D. Liang, X. Liu, Z. Ji, J. Bai, W. Liu, and X. Bai, "When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition," arXiv preprint arXiv:2207.11463, 2022. [Online]. Available: https://arxiv.org/abs/2207.11463
+ ## **License**
+ This project is licensed under the [MIT License](LICENSE).
app.py ADDED
@@ -0,0 +1,771 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import cv2
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import torch.nn.functional as F
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ import json
+ from models.can.can import CAN, create_can_model
+ from models.can.can_dataloader import Vocabulary, INPUT_HEIGHT, INPUT_WIDTH
+
+ # Load configuration
+ with open("config.json", "r") as json_file:
+     cfg = json.load(json_file)
+ CAN_CONFIG = cfg["can"]
+
+ # Global constants
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
+ PRETRAINED_BACKBONE = True if CAN_CONFIG["pretrained_backbone"] == 1 else False
+ CHECKPOINT_PATH = f'checkpoints/{BACKBONE_TYPE}_can_best.pth' if not PRETRAINED_BACKBONE else f'checkpoints/p_{BACKBONE_TYPE}_can_best.pth'
+
+ # Modified process_img to accept numpy array and validate shapes
+ def process_img(image, convert_to_rgb=False):
+     """
+     Process a numpy array image: binarize, ensure black background, resize, and apply padding.
+
+     Args:
+         image: Numpy array (grayscale)
+         convert_to_rgb: Whether to convert to RGB
+
+     Returns:
+         Processed image and crop information, or None if invalid
+     """
+     def is_effectively_binary(img, threshold_percentage=0.9):
+         dark_pixels = np.sum(img < 20)
+         bright_pixels = np.sum(img > 235)
+         total_pixels = img.size
+         return (dark_pixels + bright_pixels) / total_pixels > threshold_percentage
+
+     def before_padding(image):
+         if image.shape[0] < 2 or image.shape[1] < 2:
+             return None, None  # Invalid image size
+
+         # Ensure image is uint8
+         if image.dtype != np.uint8:
+             if image.max() <= 1.0:  # If image is normalized (0-1)
+                 image = (image * 255).astype(np.uint8)
+             else:  # If image is in other float format
+                 image = np.clip(image, 0, 255).astype(np.uint8)
+
+         edges = cv2.Canny(image, 50, 150)
+         kernel = np.ones((7, 13), np.uint8)
+         dilated = cv2.dilate(edges, kernel, iterations=8)
+         num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilated, connectivity=8)
+         sorted_components = sorted(range(1, num_labels), key=lambda i: stats[i, cv2.CC_STAT_AREA], reverse=True)
+         best_f1 = 0
+         best_crop = (0, 0, image.shape[1], image.shape[0])
+         total_white_pixels = np.sum(dilated > 0)
+         current_mask = np.zeros_like(dilated)
+         x_min, y_min = image.shape[1], image.shape[0]
+         x_max, y_max = 0, 0
+
+         for component_idx in sorted_components:
+             component_mask = labels == component_idx
+             current_mask = np.logical_or(current_mask, component_mask)
+             comp_y, comp_x = np.where(component_mask)
+             if len(comp_x) > 0 and len(comp_y) > 0:
+                 x_min = min(x_min, np.min(comp_x))
+                 y_min = min(y_min, np.min(comp_y))
+                 x_max = max(x_max, np.max(comp_x))
+                 y_max = max(y_max, np.max(comp_y))
+             width = x_max - x_min + 1
+             height = y_max - y_min + 1
+             if width < 2 or height < 2:
+                 continue
+             crop_area = width * height
+             crop_mask = np.zeros_like(dilated)
+             crop_mask[y_min:y_max + 1, x_min:x_max + 1] = 1
+             white_in_crop = np.sum(np.logical_and(dilated > 0, crop_mask > 0))
+             precision = white_in_crop / crop_area if crop_area > 0 else 0
+             recall = white_in_crop / total_white_pixels if total_white_pixels > 0 else 0
+             f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+             if f1 > best_f1:
+                 best_f1 = f1
+                 best_crop = (x_min, y_min, x_max, y_max)
+
+         x_min, y_min, x_max, y_max = best_crop
+         cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
+         if cropped_image.shape[0] < 2 or cropped_image.shape[1] < 2:
+             return None, None
+         if is_effectively_binary(cropped_image):
+             _, thresh = cv2.threshold(cropped_image, 127, 255, cv2.THRESH_BINARY)
+         else:
+             thresh = cv2.adaptiveThreshold(cropped_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
+         white = np.sum(thresh == 255)
+         black = np.sum(thresh == 0)
+         if white > black:
+             thresh = 255 - thresh
+         denoised = cv2.medianBlur(thresh, 3)
+         for _ in range(3):
+             denoised = cv2.medianBlur(denoised, 3)
+         result = cv2.copyMakeBorder(denoised, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=0)
+         return result, best_crop
+
+     if len(image.shape) != 2:
+         return None, None  # Expect grayscale image
+
+     # Ensure image is uint8 before processing
+     if image.dtype != np.uint8:
+         if image.max() <= 1.0:  # If image is normalized (0-1)
+             image = (image * 255).astype(np.uint8)
+         else:  # If image is in other float format
+             image = np.clip(image, 0, 255).astype(np.uint8)
+
+     bin_img, best_crop = before_padding(image)
+     if bin_img is None:
+         return None, None
+     h, w = bin_img.shape
+     if h < 2 or w < 2:
+         return None, None
+     new_w = int((INPUT_HEIGHT / h) * w)
+
+     if new_w > INPUT_WIDTH:
+         resized_img = cv2.resize(bin_img, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_AREA)
+     else:
+         resized_img = cv2.resize(bin_img, (new_w, INPUT_HEIGHT), interpolation=cv2.INTER_AREA)
+         padded_img = np.zeros((INPUT_HEIGHT, INPUT_WIDTH), dtype=np.uint8)
+         x_offset = (INPUT_WIDTH - new_w) // 2
+         padded_img[:, x_offset:x_offset + new_w] = resized_img
+         resized_img = padded_img
+
+     if convert_to_rgb:
+         resized_img = cv2.cvtColor(resized_img, cv2.COLOR_GRAY2BGR)
+
+     return resized_img, best_crop
+
+ # Load model and vocabulary
+ def load_checkpoint(checkpoint_path, device, pretrained_backbone=True, backbone='densenet'):
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+     vocab = checkpoint.get('vocab')
+     if vocab is None:
+         vocab_path = os.path.join(os.path.dirname(checkpoint_path), 'hmer_vocab.pth')
+         if os.path.exists(vocab_path):
+             vocab_data = torch.load(vocab_path)
+             vocab = Vocabulary()
+             vocab.word2idx = vocab_data['word2idx']
+             vocab.idx2word = vocab_data['idx2word']
+             vocab.idx = vocab_data['idx']
+             vocab.pad_token = vocab.word2idx['<pad>']
+             vocab.start_token = vocab.word2idx['<start>']
+             vocab.end_token = vocab.word2idx['<end>']
+             vocab.unk_token = vocab.word2idx['<unk>']
+         else:
+             raise ValueError(f"Vocabulary not found in checkpoint and {vocab_path} does not exist")
+
+     hidden_size = checkpoint.get('hidden_size', 256)
+     embedding_dim = checkpoint.get('embedding_dim', 256)
+     use_coverage = checkpoint.get('use_coverage', True)
+
+     model = create_can_model(
+         num_classes=len(vocab),
+         hidden_size=hidden_size,
+         embedding_dim=embedding_dim,
+         use_coverage=use_coverage,
+         pretrained_backbone=pretrained_backbone,
+         backbone_type=backbone
+     ).to(device)
+
+     model.load_state_dict(checkpoint['model'])
+     model.eval()
+     return model, vocab
+
+ model, vocab = load_checkpoint(CHECKPOINT_PATH, DEVICE, PRETRAINED_BACKBONE, BACKBONE_TYPE)
+
+ # Image processing function for Gradio
+ def gradio_process_img(image, convert_to_rgb=False):
+     # Convert Gradio image (PIL, numpy, or dict from Sketchpad) to grayscale numpy array
+     if isinstance(image, dict):  # Handle Sketchpad input
+         # The Sketchpad component returns a dict with 'background' and 'layers' keys
+         # We need to combine the background and layers to get the final image
+         background = np.array(image['background'])
+         layers = image['layers']
+
+         # Start with the background
+         final_image = background.copy()
+
+         # Add each layer on top
+         for layer in layers:
+             if layer is not None:  # Some layers might be None
+                 layer_img = np.array(layer)
+                 # Create a mask for non-transparent pixels
+                 mask = layer_img[..., 3] > 0
+                 # Replace pixels in final_image where mask is True, keeping the alpha channel
+                 final_image[mask] = layer_img[mask]
+
+         # Convert to grayscale using the alpha channel
+         if len(final_image.shape) == 3:
+             # Use alpha channel to determine which pixels to keep
+             alpha_mask = final_image[..., 3] > 0
+             # Convert to grayscale using standard formula
+             gray = np.dot(final_image[..., :3], [0.299, 0.587, 0.114])
+             # Create a white background
+             final_image = np.ones_like(gray) * 255
+             # Apply the drawing where alpha > 0
+             final_image[alpha_mask] = gray[alpha_mask]
+             # Invert the image to get black on white
+             final_image = 255 - final_image
+     elif isinstance(image, Image.Image):
+         image = np.array(image.convert('L'))
+     elif isinstance(image, np.ndarray):
+         if len(image.shape) == 3:
+             image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+         elif len(image.shape) != 2:
+             raise ValueError("Invalid image format: Expected grayscale or RGB image")
+     else:
+         raise ValueError("Unsupported image input type")
+
+     # For Sketchpad input, use the final_image we created
+     if isinstance(image, dict):
+         image = final_image
+
+     # Apply modified process_img
+     processed_img, best_crop = process_img(image, convert_to_rgb=False)
+     if processed_img is None:
+         raise ValueError("Image processing failed: Resulted in invalid image size")
+
+     # Prepare for model input
+     transform = A.Compose([
+         A.Normalize(mean=[0.0], std=[1.0]),
+         ToTensorV2()
+     ])
+     processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
+     image_tensor = transform(image=processed_img)['image'].unsqueeze(0).to(DEVICE)
+
+     return image_tensor, processed_img, best_crop
+
+ # Model inference
+ def recognize_image(image_tensor, processed_img, best_crop):
+     with torch.no_grad():
+         predictions, _ = model.recognize(
+             image_tensor,
+             max_length=150,
+             start_token=vocab.start_token,
+             end_token=vocab.end_token,
+             beam_width=5
+         )
+
+     # Convert indices to LaTeX tokens
+     latex_tokens = []
+     for idx in predictions:
+         if idx == vocab.end_token:
+             break
+         if idx != vocab.start_token:
+             latex_tokens.append(vocab.idx2word[idx])
+
+     latex = ' '.join(latex_tokens)
+
+     # Format LaTeX for rendering
+     rendered_latex = f"$${latex}$$"
+
+     return latex, rendered_latex
+
+ # Gradio interface function
+ def process_draw(image):
+     if image is None:
+         return "Please draw an expression", ""
+
+     try:
+         # Process image
+         image_tensor, processed_img, best_crop = gradio_process_img(image)
+
+         # Recognize
+         latex, rendered_latex = recognize_image(image_tensor, processed_img, best_crop)
+
+         return latex, rendered_latex
+     except Exception as e:
+         return f"Error processing image: {str(e)}", ""
+
+ def process_upload(image):
+     if image is None:
+         return "Please upload an image", ""
+
+     try:
+         # Process image
+         image_tensor, processed_img, best_crop = gradio_process_img(image)
+
+         # Recognize
+         latex, rendered_latex = recognize_image(image_tensor, processed_img, best_crop)
+
+         return latex, rendered_latex
+     except Exception as e:
+         return f"Error processing image: {str(e)}", ""
+
+ # Enhanced custom CSS with expanded input areas
+ custom_css = """
+ /* Global styles */
+ .gradio-container {
+     max-width: 1400px !important;
+     margin: 0 auto !important;
+     font-family: 'Segoe UI', 'Roboto', sans-serif !important;
+     padding: 1rem !important;
+     box-sizing: border-box !important;
+ }
+
+ /* Header styling */
+ .header-title {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+     -webkit-background-clip: text !important;
+     -webkit-text-fill-color: transparent !important;
+     background-clip: text !important;
+     text-align: center !important;
+     font-size: clamp(1.8rem, 5vw, 2.5rem) !important;
+     font-weight: 700 !important;
+     margin-bottom: 1.5rem !important;
+     text-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
+     padding: 0 1rem !important;
+ }
+
+ /* Main container styling */
+ .main-container {
+     background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%) !important;
+     border-radius: 20px !important;
+     padding: clamp(1rem, 3vw, 2rem) !important;
+     box-shadow: 0 10px 30px rgba(0,0,0,0.1) !important;
+     margin: 1rem 0 !important;
+ }
+
+ /* Input section styling - RESPONSIVE */
+ .input-section {
+     background: white !important;
+     border-radius: 15px !important;
+     padding: clamp(1rem, 3vw, 2rem) !important;
+     box-shadow: 0 4px 15px rgba(0,0,0,0.05) !important;
+     border: 1px solid #e1e8ed !important;
+     min-height: min(700px, 80vh) !important;
+     width: 100% !important;
+     box-sizing: border-box !important;
+ }
+
+ /* Output section styling - RESPONSIVE */
+ .output-section {
+     background: white !important;
+     border-radius: 15px !important;
+     padding: clamp(1rem, 3vw, 1.5rem) !important;
+     box-shadow: 0 4px 15px rgba(0,0,0,0.05) !important;
+     border: 1px solid #e1e8ed !important;
+     min-height: min(700px, 80vh) !important;
+     width: 100% !important;
+     box-sizing: border-box !important;
+ }
+
+ /* Tab styling - RESPONSIVE */
+ .tab-nav {
+     background: #f8f9fa !important;
+     border-radius: 10px !important;
+     padding: 0.5rem !important;
+     margin-bottom: 1.5rem !important;
+     display: flex !important;
+     flex-wrap: wrap !important;
+     gap: 0.5rem !important;
+ }
+
+ .tab-nav button {
+     border-radius: 8px !important;
+     padding: clamp(0.5rem, 2vw, 0.75rem) clamp(1rem, 3vw, 1.5rem) !important;
+     font-weight: 600 !important;
+     transition: all 0.3s ease !important;
+     border: none !important;
+     background: transparent !important;
+     color: #6c757d !important;
+     font-size: clamp(0.9rem, 2vw, 1rem) !important;
+     white-space: nowrap !important;
+ }
+
+ /* Sketchpad styling - RESPONSIVE */
+ .sketchpad-container {
+     border: 3px dashed #667eea !important;
+     border-radius: 15px !important;
+     background: #fafbfc !important;
+     transition: all 0.3s ease !important;
+     overflow: hidden !important;
+     min-height: min(500px, 60vh) !important;
+     height: min(500px, 60vh) !important;
+     width: 100% !important;
+     box-sizing: border-box !important;
+ }
+
+ .sketchpad-container canvas {
+     width: 100% !important;
+     height: 100% !important;
+     min-height: min(500px, 60vh) !important;
+     touch-action: none !important;
+ }
+
+ /* Upload area styling - RESPONSIVE */
+ .upload-container {
+     border: 3px dashed #667eea !important;
+     border-radius: 15px !important;
+     background: #fafbfc !important;
+     padding: clamp(1.5rem, 5vw, 3rem) !important;
+     text-align: center !important;
+     transition: all 0.3s ease !important;
+     min-height: min(500px, 60vh) !important;
+     display: flex !important;
+     flex-direction: column !important;
+     justify-content: center !important;
+     align-items: center !important;
+     box-sizing: border-box !important;
+ }
+
+ .upload-container img {
+     max-height: min(400px, 50vh) !important;
+     max-width: 100% !important;
+     object-fit: contain !important;
+ }
+
+ /* Button styling - RESPONSIVE */
+ .process-button {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+     border: none !important;
+     border-radius: 12px !important;
+     padding: clamp(0.8rem, 2vw, 1.2rem) clamp(1.5rem, 4vw, 2.5rem) !important;
+     font-size: clamp(1rem, 2.5vw, 1.2rem) !important;
+     font-weight: 600 !important;
+     color: white !important;
+     cursor: pointer !important;
+     transition: all 0.3s ease !important;
+     box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3) !important;
+     text-transform: uppercase !important;
+     letter-spacing: 0.5px !important;
+     width: 100% !important;
+     margin-top: 1.5rem !important;
+     white-space: nowrap !important;
+ }
+
+ /* Output text styling - RESPONSIVE */
+ .latex-output {
+     background: #f8f9fa !important;
+     border: 1px solid #e9ecef !important;
+     border-radius: 10px !important;
+     padding: clamp(1rem, 3vw, 1.5rem) !important;
+     font-family: 'Monaco', 'Consolas', monospace !important;
+     font-size: clamp(0.9rem, 2vw, 1rem) !important;
+     line-height: 1.6 !important;
+     min-height: min(200px, 30vh) !important;
+     overflow-x: auto !important;
+     white-space: pre-wrap !important;
+     word-break: break-word !important;
+ }
+
+ .rendered-output {
+     background: white !important;
+     border: 1px solid #e9ecef !important;
+     border-radius: 10px !important;
+     padding: clamp(1.5rem, 4vw, 2.5rem) !important;
+     text-align: center !important;
+     min-height: min(300px, 40vh) !important;
+     display: flex !important;
+     align-items: center !important;
+     justify-content: center !important;
+     font-size: clamp(1.2rem, 3vw, 1.8rem) !important;
+     overflow-x: auto !important;
+ }
+
+ /* Instructions styling - RESPONSIVE */
+ .instructions {
+     background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%) !important;
+     border-radius: 12px !important;
+     padding: clamp(1rem, 3vw, 1.5rem) !important;
+     margin-bottom: clamp(1rem, 3vw, 2rem) !important;
+     border-left: 4px solid #28a745 !important;
+ }
+
+ .instructions h3 {
+     color: #155724 !important;
+     margin-bottom: 0.8rem !important;
+     font-weight: 600 !important;
+     font-size: clamp(1rem, 2.5vw, 1.1rem) !important;
+ }
+
+ .instructions p {
+     color: #155724 !important;
+     margin: 0.5rem 0 !important;
+     font-size: clamp(0.9rem, 2vw, 1rem) !important;
+     line-height: 1.5 !important;
+ }
+
+ /* Drawing tips styling - RESPONSIVE */
+ .drawing-tips {
+     background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%) !important;
+     border-radius: 10px !important;
+     padding: clamp(0.8rem, 2vw, 1rem) !important;
+     margin-top: 1rem !important;
+     border-left: 4px solid #fd7e14 !important;
+ }
+
+ .drawing-tips h4 {
+     color: #8a4100 !important;
+     margin-bottom: 0.5rem !important;
+     font-weight: 600 !important;
+     font-size: clamp(0.9rem, 2vw, 1rem) !important;
+ }
+
+ .drawing-tips ul {
+     color: #8a4100 !important;
+     margin: 0 !important;
+     padding-left: clamp(1rem, 3vw, 1.5rem) !important;
+ }
+
+ .drawing-tips li {
+     margin: 0.3rem 0 !important;
+     font-size: clamp(0.8rem, 1.8vw, 0.9rem) !important;
+ }
+
+ /* Full-width layout adjustments - RESPONSIVE */
+ .input-output-container {
+     display: grid !important;
+     grid-template-columns: repeat(auto-fit, minmax(min(100%, 600px), 1fr)) !important;
+     gap: clamp(1rem, 3vw, 2rem) !important;
+     align-items: start !important;
+     width: 100% !important;
+     box-sizing: border-box !important;
+ }
+
+ /* Examples section - RESPONSIVE */
+ .examples-grid {
+     display: grid !important;
+     grid-template-columns: repeat(auto-fit, minmax(min(100%, 250px), 1fr)) !important;
+     gap: clamp(1rem, 3vw, 1.5rem) !important;
+     text-align: center !important;
+ }
+
+ .example-card {
+     padding: clamp(1rem, 3vw, 1.5rem) !important;
+     background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%) !important;
+     border-radius: 12px !important;
+     border: 2px solid #dee2e6 !important;
+ }
+
+ .example-card strong {
+     color: #495057 !important;
+     font-size: clamp(0.9rem, 2.5vw, 1.1rem) !important;
+     display: block !important;
+     margin-bottom: 0.5rem !important;
+ }
+
+ .example-card span {
+     font-family: monospace !important;
+     color: #6c757d !important;
+     font-size: clamp(0.8rem, 2vw, 0.9rem) !important;
+     line-height: 1.6 !important;
+ }
+
+ /* Performance metrics section - RESPONSIVE */
+ .metrics-grid {
+     display: grid !important;
+     grid-template-columns: repeat(auto-fit, minmax(min(100%, 200px), 1fr)) !important;
+     gap: clamp(0.8rem, 2vw, 1rem) !important;
+ }
+
+ .metric-item {
+     text-align: center !important;
+     padding: clamp(0.5rem, 2vw, 1rem) !important;
+ }
+
+ .metric-item strong {
+     color: #e65100 !important;
+     font-size: clamp(0.9rem, 2.5vw, 1rem) !important;
+     display: block !important;
+     margin-bottom: 0.3rem !important;
+ }
+
+ .metric-item span {
+     color: #bf360c !important;
+     font-size: clamp(0.8rem, 2vw, 0.9rem) !important;
+ }
+
+ /* Responsive breakpoints */
+ @media (max-width: 1200px) {
+     .gradio-container {
+         padding: 0.8rem !important;
+     }
+ }
+
+ @media (max-width: 768px) {
+     .gradio-container {
+         padding: 0.5rem !important;
+     }
+
+     .main-container {
+         padding: 0.8rem !important;
+         margin: 0.5rem 0 !important;
+     }
+
+     .input-section, .output-section {
+         padding: 0.8rem !important;
+     }
+
+     .tab-nav {
+         flex-direction: column !important;
+     }
+
+     .tab-nav button {
+         width: 100% !important;
+     }
+ }
+
+ @media (max-width: 480px) {
+     .gradio-container {
+         padding: 0.3rem !important;
+     }
+
+     .main-container {
+         padding: 0.5rem !important;
+         margin: 0.3rem 0 !important;
+     }
+
+     .input-section, .output-section {
+         padding: 0.5rem !important;
+     }
+
+     .process-button {
+         padding: 0.8rem 1.2rem !important;
+         font-size: 0.9rem !important;
+     }
+ }
+
+ /* Touch device optimizations */
+ @media (hover: none) {
+     .process-button:hover {
+         transform: none !important;
+     }
+
+     .sketchpad-container {
+         touch-action: none !important;
+         -webkit-touch-callout: none !important;
+         -webkit-user-select: none !important;
+         user-select: none !important;
+     }
+
+     .tab-nav button {
+         padding: 1rem !important;
+     }
+ }
+
+ /* Print styles */
+ @media print {
+     .gradio-container {
+         max-width: 100% !important;
+         padding: 0 !important;
+     }
+
+     .input-section, .output-section {
+         break-inside: avoid !important;
+     }
+
+     .process-button, .tab-nav {
+         display: none !important;
+     }
+ }
+ """
+
+ # Create the enhanced Gradio interface with expanded input
+ with gr.Blocks(css=custom_css, title="Math Expression Recognition") as demo:
+     gr.HTML('<h1 class="header-title">🧮 Handwritten Mathematical Expression Recognition</h1>')
+
+     # Enhanced Instructions
+     gr.HTML("""
+     <div class="instructions">
+         <h3>📝 How to use this expanded interface:</h3>
+         <p><strong>✏️ Draw Tab:</strong> Use the large drawing canvas (800x500px) to draw mathematical expressions with your mouse or touch device</p>
+         <p><strong>📁 Upload Tab:</strong> Upload high-resolution images containing handwritten mathematical expressions</p>
+         <p><strong>🎯 Tips:</strong> Write clearly, use proper mathematical notation, and ensure good contrast between your writing and the background</p>
+     </div>
+     """)
+
+     with gr.Row(elem_classes="input-output-container"):
+         # Expanded Input Section
+         with gr.Column(elem_classes="input-section"):
+             gr.HTML('<h2 style="text-align: center; color: #667eea; margin-bottom: 1.5rem; font-size: 1.5rem;">📥 Input Area</h2>')
+
+             with gr.Tabs():
+                 with gr.TabItem("✏️ Draw Expression"):
+                     gr.HTML("""
+                     <div class="drawing-tips">
+                         <h4>🎨 Drawing Tips:</h4>
+                         <ul>
+                             <li>Use clear, legible handwriting</li>
+                             <li>Draw symbols at reasonable sizes</li>
+                             <li>Leave space between different parts</li>
+                             <li>Use standard mathematical notation</li>
+                             <li>Avoid overlapping symbols</li>
+                         </ul>
+                     </div>
+                     """)
+
+                     draw_input = gr.Sketchpad(
+                         label="Draw your mathematical expression here",
+                         elem_classes="sketchpad-container",
+                         height=500,
+                         width=800,
+                         canvas_size=(800, 500)
+                     )
+                     draw_button = gr.Button("🚀 Recognize Drawn Expression", elem_classes="process-button")
+
+                 with gr.TabItem("📁 Upload Image"):
+                     gr.HTML("""
+                     <div class="drawing-tips">
+                         <h4>📷 Upload Tips:</h4>
+                         <ul>
+                             <li>Use high-resolution images (minimum 300 DPI)</li>
+                             <li>Ensure good lighting and contrast</li>
+                             <li>Crop the image to focus on the expression</li>
+                             <li>Avoid shadows or glare</li>
+                             <li>Supported formats: PNG, JPG, JPEG</li>
+                         </ul>
+                     </div>
+                     """)
+
+                     upload_input = gr.Image(
+                         label="Upload your mathematical expression image",
+                         elem_classes="upload-container",
+                         height=500,
+                         type="pil"
+                     )
+                     upload_button = gr.Button("🚀 Recognize Uploaded Expression", elem_classes="process-button")
+
+         # Output Section
+         with gr.Column(elem_classes="output-section"):
+             gr.HTML('<h2 style="text-align: center; color: #667eea; margin-bottom: 1.5rem; font-size: 1.5rem;">📤 Recognition Results</h2>')
+
+             with gr.Tabs():
+                 with gr.TabItem("📄 LaTeX Code"):
+                     latex_output = gr.Textbox(
+                         label="Generated LaTeX Code",
+                         elem_classes="latex-output",
+                         lines=8,
+                         placeholder="Your LaTeX code will appear here...\n\nThis is the raw LaTeX markup that represents your mathematical expression. You can copy this code and use it in any LaTeX document or LaTeX-compatible system.",
+                         interactive=False
+                     )
+
+                 with gr.TabItem("🎨 Rendered Expression"):
+                     rendered_output = gr.Markdown(
+                         label="Rendered Mathematical Expression",
+                         elem_classes="rendered-output",
+                         value="*Your beautifully rendered mathematical expression will appear here...*\n\n*Draw or upload an expression to see the magic happen!*"
+                     )
+
+     # Connect the buttons to their respective functions
+     draw_button.click(
+         fn=process_draw,
+         inputs=[draw_input],
+         outputs=[latex_output, rendered_output]
+     )
+
+     upload_button.click(
+         fn=process_upload,
+         inputs=[upload_input],
+         outputs=[latex_output, rendered_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True, inbrowser=True)
config.json ADDED
@@ -0,0 +1,29 @@
+ {
+     "can": {
+         "input_height": 128,
+         "input_width": 1024,
+         "base_dir": "data/CROHME_splitted",
+         "batch_size": 32,
+         "num_workers": 4,
+         "seed": 1337,
+         "checkpoint_dir": "checkpoints",
+         "pretrained_backbone": 1,
+         "backbone_type": "densenet",
+         "hidden_size": 256,
+         "embedding_dim": 256,
+         "use_coverage": 1,
+         "lambda_count": 0.01,
+         "lr": 3e-4,
+         "epochs": 100,
+         "grad_clip": 5.0,
+         "print_freq": 20,
+         "t": 5,
+         "t_mult": 2,
+         "visualize": 1,
+         "test_folder": "data/CROHME_splitted/test/img",
+         "label_file": "data/CROHME_splitted/test/caption.txt",
+         "relative_image_path": "18_em_18.bmp",
+         "mode": "evaluate",
+         "classifier": "frac"
+     }
+ }
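For reference, `app.py` reads only the `can` block of this file and derives the checkpoint filename from `backbone_type` and `pretrained_backbone`; a condensed sketch of that lookup:

```python
import json

# Only the "can" block is consumed by app.py.
with open("config.json", "r") as f:
    can_cfg = json.load(f)["can"]

prefix = "p_" if can_cfg["pretrained_backbone"] == 1 else ""
checkpoint = f"checkpoints/{prefix}{can_cfg['backbone_type']}_can_best.pth"
print(checkpoint)  # checkpoints/p_densenet_can_best.pth with the values above
```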
requirements.txt ADDED
Binary file (3.07 kB).