Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +2 -35
- README.md +190 -12
- app.py +211 -0
- poetry.lock +0 -0
- pyproject.toml +53 -0
- requirements.txt +167 -0
- src/handlers/__init__.py +84 -0
- src/handlers/audio_models/whisper_handler.py +14 -0
- src/handlers/base_handler.py +148 -0
- src/handlers/img_models/image_classification_handler.py +14 -0
- src/handlers/nlp_models/causal_lm_handler.py +46 -0
- src/handlers/nlp_models/embedding_model_handler.py +49 -0
- src/handlers/nlp_models/masked_lm_handler.py +39 -0
- src/handlers/nlp_models/multiple_choice_handler.py +39 -0
- src/handlers/nlp_models/question_answering_handler.py +59 -0
- src/handlers/nlp_models/seq2seq_lm_handler.py +34 -0
- src/handlers/nlp_models/sequence_classification_handler.py +49 -0
- src/handlers/nlp_models/token_classification_handler.py +57 -0
- src/optimizations/onnx_conversion.py +64 -0
- src/optimizations/quantize.py +109 -0
- src/utilities/push_to_hub.py +85 -0
- src/utilities/resources.py +54 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
# Auto detect text files and perform LF normalization
|
2 |
+
* text=auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,12 +1,190 @@
|
|
1 |
-
---
|
2 |
-
title: AutoQuantNX
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: AutoQuantNX
|
3 |
+
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.44.1
|
6 |
+
---
|
7 |
+
# π€ AutoQuantNX (**Still under testing and improvement phase**)
|
8 |
+
|
9 |
+
## Overview
|
10 |
+
AutoQuantNX is a powerful Gradio-based web application designed to simplify the process of optimizing and deploying Hugging Face models. It supports a wide range of tasks, including quantization, ONNX conversion, and seamless integration with the Hugging Face Hub. With AutoQuantNX, you can easily convert models to ONNX format, apply quantization techniques, and push the optimized models to your Hugging Face accountβall through an intuitive user interface.
|
11 |
+
|
12 |
+
## Features
|
13 |
+
|
14 |
+
### Supported Tasks
|
15 |
+
AutoQuantNX supports the following tasks:
|
16 |
+
|
17 |
+
* Text Classification
|
18 |
+
* Named Entity Recognition (NER)
|
19 |
+
* Question Answering
|
20 |
+
* Causal Language Modeling
|
21 |
+
* Masked Language Modeling
|
22 |
+
* Sequence-to-Sequence Language Modeling
|
23 |
+
* Multiple Choice
|
24 |
+
* Whisper (Speech-to-Text)
|
25 |
+
* Embedding Fine-Tuning
|
26 |
+
* Image Classification (Placeholder for future implementation)
|
27 |
+
|
28 |
+
### Quantization Options
|
29 |
+
* None (default)
|
30 |
+
* 4-bit
|
31 |
+
* 8-bit
|
32 |
+
* 16-bit-float
|
33 |
+
|
34 |
+
### ONNX Conversion
|
35 |
+
Converts models to ONNX format for optimized deployment.
|
36 |
+
|
37 |
+
Supports optional ONNX quantization:
|
38 |
+
* 8-bit
|
39 |
+
* 16-bit-int
|
40 |
+
* 16-bit-float
|
41 |
+
|
42 |
+
### Hugging Face Hub Integration
|
43 |
+
* Automatically pushes optimized models to your Hugging Face Hub repository
|
44 |
+
* Tags models with metadata for easy identification (e.g., onnx, quantized, task type)
|
45 |
+
|
46 |
+
### Performance Testing
|
47 |
+
Compares original and quantized models using metrics like:
|
48 |
+
* Mean Squared Error (MSE)
|
49 |
+
* Spearman Correlation
|
50 |
+
* Cosine Similarity
|
51 |
+
* Inference Time
|
52 |
+
* Model Size
|
53 |
+
|
54 |
+
## File Structure
|
55 |
+
```
|
56 |
+
AutoQuantNX/
|
57 |
+
βββ src/
|
58 |
+
β βββ handlers/
|
59 |
+
β β βββ audio_models/
|
60 |
+
β β β βββ whisper_handler.py
|
61 |
+
β β βββ img_models/
|
62 |
+
β β β βββ image_classification_handler.py
|
63 |
+
β β βββ nlp_models/
|
64 |
+
β β β βββ causal_lm_handler.py
|
65 |
+
β β β βββ embedding_model_handler.py
|
66 |
+
β β β βββ masked_lm_handler.py
|
67 |
+
β β β βββ multiple_choice_handler.py
|
68 |
+
β β β βββ question_answering_handler.py
|
69 |
+
β β β βββ seq2seq_lm_handler.py
|
70 |
+
β β β βββ sequence_classification_handler.py
|
71 |
+
β β β βββ token_classification_handler.py
|
72 |
+
β β βββ __init__.py
|
73 |
+
β β βββ base_handler.py
|
74 |
+
β βββ optimizations/
|
75 |
+
β β βββ onnx_conversion.py
|
76 |
+
β β βββ quantize.py
|
77 |
+
β βββ utilities/
|
78 |
+
β βββ push_to_hub.py
|
79 |
+
β βββ resources.py
|
80 |
+
βββ README.md
|
81 |
+
βββ app.py
|
82 |
+
βββ poetry.lock
|
83 |
+
βββ pyproject.toml
|
84 |
+
βββ requirements.txt
|
85 |
+
```
|
86 |
+
|
87 |
+
## Prerequisites
|
88 |
+
|
89 |
+
### Using requirements.txt (Not preferable to me atleast)
|
90 |
+
* Python 3.8 or higher
|
91 |
+
* Install dependencies:
|
92 |
+
```bash
|
93 |
+
pip install -r requirements.txt
|
94 |
+
```
|
95 |
+
|
96 |
+
### Using Poetry
|
97 |
+
1. Install Poetry (if not already installed):
|
98 |
+
|
99 |
+
Linux:
|
100 |
+
```bash
|
101 |
+
curl -sSL https://install.python-poetry.org | python3 -
|
102 |
+
```
|
103 |
+
Other platforms: Follow the official instructions.
|
104 |
+
|
105 |
+
2. Install dependencies:
|
106 |
+
```bash
|
107 |
+
poetry install
|
108 |
+
```
|
109 |
+
|
110 |
+
3. Activate the virtual environment:
|
111 |
+
```bash
|
112 |
+
poetry shell
|
113 |
+
```
|
114 |
+
|
115 |
+
## Usage
|
116 |
+
|
117 |
+
### Launch the App
|
118 |
+
Run the following command to start the Gradio web application:
|
119 |
+
```bash
|
120 |
+
python src/app.py
|
121 |
+
```
|
122 |
+
The app will be accessible at http://localhost:7860 by default.
|
123 |
+
|
124 |
+
### Steps to Use the App
|
125 |
+
1. Enter Model Details:
|
126 |
+
* Provide the Hugging Face model name
|
127 |
+
* Select the task type (e.g., text classification, question answering)
|
128 |
+
|
129 |
+
2. Select Optimization Options:
|
130 |
+
* Choose quantization type (e.g., 4-bit, 8-bit)
|
131 |
+
* Enable ONNX conversion and select quantization options if needed
|
132 |
+
|
133 |
+
3. Provide Hugging Face Token:
|
134 |
+
* Enter your Hugging Face token for accessing and pushing models to the Hub
|
135 |
+
|
136 |
+
4. Start Conversion:
|
137 |
+
* Click the "Start Conversion" button to process the model
|
138 |
+
|
139 |
+
5. Monitor Progress:
|
140 |
+
* View real-time status updates, resource usage, and results directly in the app
|
141 |
+
|
142 |
+
6. Push to Hub:
|
143 |
+
* Optimized models are automatically pushed to your specified Hugging Face repository
|
144 |
+
|
145 |
+
### Example
|
146 |
+
For a model like bert-base-uncased performing text classification:
|
147 |
+
1. Select text_classification as the task
|
148 |
+
2. Enable quantization (e.g., 8-bit)
|
149 |
+
3. Enable ONNX conversion with optimization
|
150 |
+
4. Click "Start Conversion" and monitor progress
|
151 |
+
|
152 |
+
## Key Functions
|
153 |
+
|
154 |
+
### app.py
|
155 |
+
* `process_model`: Main function handling model quantization, ONNX conversion, and Hugging Face Hub integration
|
156 |
+
* `update_memory_info`: Monitors and displays system resource usage
|
157 |
+
|
158 |
+
### optimization/onnx_conversion.py
|
159 |
+
* `convert_to_onnx`: Converts models to ONNX format
|
160 |
+
* `quantize_onnx_model`: Quantizes ONNX models for optimized inference
|
161 |
+
|
162 |
+
### optimization/quantize.py
|
163 |
+
* `ModelQuantizer`: Handles quantization of PyTorch models and performance testing
|
164 |
+
|
165 |
+
### utilities/push_to_hub.py
|
166 |
+
* `push_to_hub`: Pushes models to the Hugging Face Hub
|
167 |
+
|
168 |
+
### utilities/resources.py
|
169 |
+
* `ResourceManager`: Manages temporary files and memory usage
|
170 |
+
|
171 |
+
## Notes
|
172 |
+
* Ensure you have sufficient system resources for model conversion and quantization
|
173 |
+
* Use a Hugging Face Hub token with proper write permissions for pushing models
|
174 |
+
|
175 |
+
## Troubleshooting
|
176 |
+
* Model Conversion Fails: Ensure the model and task are supported
|
177 |
+
* Insufficient Resources: Free up memory or reduce optimization levels
|
178 |
+
* ONNX Quantization Errors: Verify that the selected quantization type is supported for the model
|
179 |
+
|
180 |
+
## License
|
181 |
+
This project is licensed under the MIT License. See the LICENSE file for details.
|
182 |
+
|
183 |
+
## Contributions
|
184 |
+
Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
|
185 |
+
|
186 |
+
## Acknowledgments
|
187 |
+
* Hugging Face Transformers
|
188 |
+
* Optimum Library
|
189 |
+
* Gradio
|
190 |
+
* ONNX Runtime
|
app.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import logging
|
3 |
+
from typing import Tuple, Dict, Any
|
4 |
+
from src.utilities.resources import ResourceManager
|
5 |
+
from src.utilities.push_to_hub import push_to_hub
|
6 |
+
from src.optimizations.onnx_conversion import convert_to_onnx
|
7 |
+
from src.optimizations.quantize import quantize_onnx_model
|
8 |
+
from src.handlers import get_model_handler, TASK_CONFIGS
|
9 |
+
logging.basicConfig(level=logging.INFO)
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
import json
|
12 |
+
|
13 |
+
def process_model(
|
14 |
+
model_name: str,
|
15 |
+
task: str,
|
16 |
+
quantization_type: str,
|
17 |
+
enable_onnx: bool,
|
18 |
+
onnx_quantization: str,
|
19 |
+
hf_token: str,
|
20 |
+
repo_name: str,
|
21 |
+
test_text: str
|
22 |
+
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
|
23 |
+
try:
|
24 |
+
resource_manager = ResourceManager()
|
25 |
+
status_updates = []
|
26 |
+
status = {
|
27 |
+
"status": "Processing",
|
28 |
+
"progress": 0,
|
29 |
+
"current_step": "Initializing",
|
30 |
+
}
|
31 |
+
|
32 |
+
metrics = {}
|
33 |
+
|
34 |
+
if not model_name or not hf_token or not repo_name:
|
35 |
+
return (
|
36 |
+
{"status": "Error", "progress": 0, "current_step": "Validation Failed"},
|
37 |
+
"Model name, HuggingFace token, and repository name are required.",
|
38 |
+
metrics
|
39 |
+
)
|
40 |
+
|
41 |
+
status["progress"] = 0.2
|
42 |
+
status["current_step"] = "Initialization"
|
43 |
+
status_updates.append("Initialization complete")
|
44 |
+
|
45 |
+
quantized_model_path = None
|
46 |
+
|
47 |
+
if quantization_type != "None":
|
48 |
+
status.update({"progress": 0.4, "current_step": "Quantization"})
|
49 |
+
status_updates.append(f"Applying {quantization_type} quantization")
|
50 |
+
|
51 |
+
if not test_text:
|
52 |
+
test_text = TASK_CONFIGS[task]["example_text"]
|
53 |
+
|
54 |
+
try:
|
55 |
+
handler = get_model_handler(task, model_name, quantization_type, test_text)
|
56 |
+
quantized_model = handler.compare()
|
57 |
+
metrics = handler.get_metrics()
|
58 |
+
metrics = json.loads(json.dumps(metrics))
|
59 |
+
|
60 |
+
quantized_model_path = str(resource_manager.temp_dirs["quantized"] / "model")
|
61 |
+
quantized_model.save_pretrained(quantized_model_path)
|
62 |
+
status_updates.append("Quantization completed successfully")
|
63 |
+
except Exception as e:
|
64 |
+
logger.error(f"Quantization error: {str(e)}", exc_info=True)
|
65 |
+
return (
|
66 |
+
{"status": "Error", "progress": 0.4, "current_step": "Quantization Failed"},
|
67 |
+
f"Quantization failed: {str(e)}",
|
68 |
+
metrics
|
69 |
+
)
|
70 |
+
|
71 |
+
if enable_onnx:
|
72 |
+
status.update({"progress": 0.6, "current_step": "ONNX Conversion"})
|
73 |
+
status_updates.append("Converting to ONNX format")
|
74 |
+
|
75 |
+
try:
|
76 |
+
output_dir = str(resource_manager.temp_dirs["onnx"])
|
77 |
+
onnx_result = convert_to_onnx(model_name, task, output_dir)
|
78 |
+
|
79 |
+
if onnx_result is None:
|
80 |
+
return (
|
81 |
+
{"status": "Error", "progress": 0.6, "current_step": "ONNX Conversion Failed"},
|
82 |
+
"ONNX conversion failed.",
|
83 |
+
metrics
|
84 |
+
)
|
85 |
+
|
86 |
+
if onnx_quantization != "None":
|
87 |
+
status_updates.append(f"Applying {onnx_quantization} ONNX quantization")
|
88 |
+
quantize_onnx_model(output_dir, onnx_quantization)
|
89 |
+
|
90 |
+
status.update({"progress": 0.8, "current_step": "Pushing ONNX Model"})
|
91 |
+
status_updates.append("Pushing ONNX model to Hub")
|
92 |
+
result, push_message = push_to_hub(
|
93 |
+
local_path=output_dir,
|
94 |
+
repo_name=f"{repo_name}-optimized",
|
95 |
+
hf_token=hf_token,
|
96 |
+
tags=["onnx", "optimum", task],
|
97 |
+
)
|
98 |
+
status_updates.append(push_message)
|
99 |
+
except Exception as e:
|
100 |
+
logger.error(f"ONNX error: {str(e)}", exc_info=True)
|
101 |
+
return (
|
102 |
+
{"status": "Error", "progress": 0.6, "current_step": "ONNX Processing Failed"},
|
103 |
+
f"ONNX processing failed: {str(e)}",
|
104 |
+
metrics
|
105 |
+
)
|
106 |
+
|
107 |
+
if quantization_type != "None" and quantized_model_path:
|
108 |
+
status.update({"progress": 0.9, "current_step": "Pushing Quantized Model"})
|
109 |
+
status_updates.append("Pushing quantized model to Hub")
|
110 |
+
result, push_message = push_to_hub(
|
111 |
+
local_path=quantized_model_path,
|
112 |
+
repo_name=f"{repo_name}-optimized",
|
113 |
+
hf_token=hf_token,
|
114 |
+
tags=["quantized", task, quantization_type],
|
115 |
+
)
|
116 |
+
status_updates.append(push_message)
|
117 |
+
|
118 |
+
status.update({"progress": 1.0, "status": "Complete", "current_step": "Completed"})
|
119 |
+
cleanup_message = resource_manager.cleanup_temp_files()
|
120 |
+
status_updates.append(cleanup_message)
|
121 |
+
|
122 |
+
return (
|
123 |
+
status,
|
124 |
+
"\n".join(status_updates),
|
125 |
+
metrics
|
126 |
+
)
|
127 |
+
|
128 |
+
except Exception as e:
|
129 |
+
logger.error(f"Error during processing: {str(e)}", exc_info=True)
|
130 |
+
return (
|
131 |
+
{"status": "Error", "progress": 0, "current_step": "Process Failed"},
|
132 |
+
f"An error occurred: {str(e)}",
|
133 |
+
metrics
|
134 |
+
)
|
135 |
+
|
136 |
+
# Gradio Interface
|
137 |
+
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
138 |
+
gr.Markdown("""
|
139 |
+
# π€ Model Conversion Hub
|
140 |
+
Convert and optimize your Hugging Face models with quantization and ONNX support.
|
141 |
+
""")
|
142 |
+
|
143 |
+
with gr.Row():
|
144 |
+
with gr.Column(scale=2):
|
145 |
+
model_name = gr.Textbox(label="Model Name", placeholder="e.g., bert-base-uncased")
|
146 |
+
task = gr.Dropdown(choices=list(TASK_CONFIGS.keys()), label="Task", value="text_classification")
|
147 |
+
|
148 |
+
with gr.Group():
|
149 |
+
gr.Markdown("### Quantization Settings")
|
150 |
+
quantization_type = gr.Dropdown(choices=["None", "4-bit", "8-bit", "16-bit-float"], label="Quantization Type", value="None")
|
151 |
+
test_text = gr.Textbox(label="Test Text", placeholder="Enter text for model evaluation", lines=3, visible=False)
|
152 |
+
|
153 |
+
with gr.Group():
|
154 |
+
gr.Markdown("### ONNX Settings")
|
155 |
+
enable_onnx = gr.Checkbox(label="Enable ONNX Conversion")
|
156 |
+
with gr.Group(visible=False) as onnx_group:
|
157 |
+
onnx_quantization = gr.Dropdown(choices=["None", "8-bit", "16-bit-int", "16-bit-float"], label="ONNX Quantization", value="None")
|
158 |
+
|
159 |
+
with gr.Group():
|
160 |
+
gr.Markdown("### HuggingFace Settings")
|
161 |
+
hf_token = gr.Textbox(label="HuggingFace Token (Required)", type="password")
|
162 |
+
repo_name = gr.Textbox(label="Repository Name")
|
163 |
+
|
164 |
+
with gr.Column(scale=1):
|
165 |
+
status_output = gr.JSON(label="Status", value={"status": "Ready", "progress": 0, "current_step": "Waiting"})
|
166 |
+
message_output = gr.Markdown(label="Progress Messages")
|
167 |
+
|
168 |
+
gr.Markdown("### Metrics")
|
169 |
+
with gr.Group():
|
170 |
+
metrics_output = gr.JSON(
|
171 |
+
value={
|
172 |
+
"model_sizes": {"original": 0.0, "quantized": 0.0},
|
173 |
+
"inference_times": {"original": 0.0, "quantized": 0.0},
|
174 |
+
"comparison_metrics": {}
|
175 |
+
},
|
176 |
+
show_label=True
|
177 |
+
)
|
178 |
+
|
179 |
+
memory_info = gr.JSON(label="Resource Usage")
|
180 |
+
convert_btn = gr.Button("π Start Conversion", variant="primary")
|
181 |
+
|
182 |
+
with gr.Accordion("βΉοΈ Help", open=False):
|
183 |
+
gr.Markdown("""
|
184 |
+
### Quick Guide
|
185 |
+
1. Enter your model name and HuggingFace token.
|
186 |
+
2. Select the appropriate task.
|
187 |
+
3. Choose optimization options.
|
188 |
+
4. Click Start Conversion.
|
189 |
+
|
190 |
+
### Tips
|
191 |
+
- Ensure sufficient system resources.
|
192 |
+
- Use test text to validate conversions.
|
193 |
+
""")
|
194 |
+
|
195 |
+
def update_memory_info():
|
196 |
+
resource_manager = ResourceManager()
|
197 |
+
return resource_manager.get_memory_info()
|
198 |
+
|
199 |
+
quantization_type.change(lambda x: gr.update(visible=x != "None"), inputs=[quantization_type], outputs=[test_text])
|
200 |
+
task.change(lambda x: gr.update(value=TASK_CONFIGS[x]["example_text"]), inputs=[task], outputs=[test_text])
|
201 |
+
enable_onnx.change(lambda x: gr.update(visible=x), inputs=[enable_onnx], outputs=[onnx_group])
|
202 |
+
|
203 |
+
convert_btn.click(
|
204 |
+
process_model,
|
205 |
+
inputs=[model_name, task, quantization_type, enable_onnx, onnx_quantization, hf_token, repo_name, test_text],
|
206 |
+
outputs=[status_output, message_output, metrics_output]
|
207 |
+
)
|
208 |
+
app.load(update_memory_info, outputs=[memory_info], every=30)
|
209 |
+
|
210 |
+
if __name__ == "__main__":
|
211 |
+
app.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)
|
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "autoquantnx"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Webapp to quantize and convert to ONNX HF models in go and compare them"
|
5 |
+
authors = ["kartikbhtt7 <[email protected]>"]
|
6 |
+
readme = "README.md"
|
7 |
+
|
8 |
+
[tool.poetry.dependencies]
|
9 |
+
python = ">=3.10,<3.12"
|
10 |
+
gradio = "^4.00.0"
|
11 |
+
transformers = "^4.31.0"
|
12 |
+
pandas = "^2.2.2"
|
13 |
+
torch = "^2.3.0"
|
14 |
+
onnx = "^1.16.2"
|
15 |
+
onnxruntime = "^1.18.1"
|
16 |
+
onnxconverter-common = ">=1.14.0"
|
17 |
+
optimum = "^1.21.3"
|
18 |
+
huggingface-hub = "^0.24.6"
|
19 |
+
sentence-transformers = "^3.0.1"
|
20 |
+
bitsandbytes = "^0.43.3"
|
21 |
+
evaluate = "^0.4.0"
|
22 |
+
faiss-gpu = "^1.7.2"
|
23 |
+
faiss-cpu = "^1.8.0.post1"
|
24 |
+
azure-cognitiveservices-speech = "^1.40.0"
|
25 |
+
gdown = "^5.2.0"
|
26 |
+
jiwer = "^3.0.4"
|
27 |
+
pydub = "^0.25.1"
|
28 |
+
librosa = "^0.10.2.post1"
|
29 |
+
soundfile = "^0.12.1"
|
30 |
+
catalogue = "^2.0.10"
|
31 |
+
langchain-core = "^0.1.40"
|
32 |
+
langchain-openai = "^0.1.0"
|
33 |
+
fast-pytorch-kmeans = "^0.2.0.1"
|
34 |
+
typing-extensions = "^4.12.2"
|
35 |
+
textwrap3 = "^0.9.2"
|
36 |
+
pynvml = "^11.5.3"
|
37 |
+
psutil = "^6.1.1"
|
38 |
+
accelerate = "^0.26.0"
|
39 |
+
|
40 |
+
[tool.poetry.dev-dependencies]
|
41 |
+
black = "^23.7.0"
|
42 |
+
flake8 = "^6.1.0"
|
43 |
+
pytest = "^7.4.3"
|
44 |
+
pytest-asyncio = "^0.21.1"
|
45 |
+
pytest-django = "^4.8.0"
|
46 |
+
pytest-cov = "^4.1.0"
|
47 |
+
pytest-testmon = "^2.1.0"
|
48 |
+
pytest-watch = "^4.2.0"
|
49 |
+
coverage = "^7.3.2"
|
50 |
+
|
51 |
+
[build-system]
|
52 |
+
requires = ["poetry>=1.5.1"]
|
53 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.26.1
|
2 |
+
aiofiles==23.2.1
|
3 |
+
aiohappyeyeballs==2.4.4
|
4 |
+
aiohttp==3.11.11
|
5 |
+
aiosignal==1.3.2
|
6 |
+
annotated-types==0.7.0
|
7 |
+
anyio==4.7.0
|
8 |
+
async-timeout==5.0.1
|
9 |
+
attrs==24.3.0
|
10 |
+
audioread==3.0.1
|
11 |
+
azure-cognitiveservices-speech==1.41.1
|
12 |
+
beautifulsoup4==4.12.3
|
13 |
+
bitsandbytes==0.43.3
|
14 |
+
black==23.12.1
|
15 |
+
catalogue==2.0.10
|
16 |
+
certifi==2024.12.14
|
17 |
+
cffi==1.17.1
|
18 |
+
charset-normalizer==3.4.1
|
19 |
+
click==8.1.8
|
20 |
+
colorama==0.4.6
|
21 |
+
coloredlogs==15.0.1
|
22 |
+
contourpy==1.3.1
|
23 |
+
coverage==7.6.10
|
24 |
+
cycler==0.12.1
|
25 |
+
datasets==2.14.4
|
26 |
+
decorator==5.1.1
|
27 |
+
dill==0.3.7
|
28 |
+
distro==1.9.0
|
29 |
+
docopt==0.6.2
|
30 |
+
evaluate==0.4.3
|
31 |
+
exceptiongroup==1.2.2
|
32 |
+
faiss-cpu==1.9.0.post1
|
33 |
+
faiss-gpu==1.7.2
|
34 |
+
fast_pytorch_kmeans==0.2.2
|
35 |
+
fastapi==0.115.6
|
36 |
+
ffmpy==0.5.0
|
37 |
+
filelock==3.16.1
|
38 |
+
flake8==6.1.0
|
39 |
+
flatbuffers==24.12.23
|
40 |
+
fonttools==4.55.3
|
41 |
+
frozenlist==1.5.0
|
42 |
+
fsspec==2024.12.0
|
43 |
+
gdown==5.2.0
|
44 |
+
gradio==4.44.1
|
45 |
+
gradio_client==1.3.0
|
46 |
+
h11==0.14.0
|
47 |
+
httpcore==1.0.7
|
48 |
+
httpx==0.28.1
|
49 |
+
huggingface-hub==0.24.7
|
50 |
+
humanfriendly==10.0
|
51 |
+
idna==3.10
|
52 |
+
importlib_resources==6.4.5
|
53 |
+
iniconfig==2.0.0
|
54 |
+
Jinja2==3.1.5
|
55 |
+
jiter==0.8.2
|
56 |
+
jiwer==3.0.5
|
57 |
+
joblib==1.4.2
|
58 |
+
jsonpatch==1.33
|
59 |
+
jsonpointer==3.0.0
|
60 |
+
kiwisolver==1.4.8
|
61 |
+
langchain-core==0.1.53
|
62 |
+
langchain-openai==0.1.7
|
63 |
+
langsmith==0.1.147
|
64 |
+
lazy_loader==0.4
|
65 |
+
librosa==0.10.2.post1
|
66 |
+
llvmlite==0.43.0
|
67 |
+
markdown-it-py==3.0.0
|
68 |
+
MarkupSafe==2.1.5
|
69 |
+
matplotlib==3.10.0
|
70 |
+
mccabe==0.7.0
|
71 |
+
mdurl==0.1.2
|
72 |
+
mpmath==1.3.0
|
73 |
+
msgpack==1.1.0
|
74 |
+
multidict==6.1.0
|
75 |
+
multiprocess==0.70.15
|
76 |
+
mypy-extensions==1.0.0
|
77 |
+
networkx==3.4.2
|
78 |
+
numba==0.60.0
|
79 |
+
numpy==1.26.4
|
80 |
+
nvidia-cublas-cu12==12.1.3.1
|
81 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
82 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
83 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
84 |
+
nvidia-cudnn-cu12==9.1.0.70
|
85 |
+
nvidia-cufft-cu12==11.0.2.54
|
86 |
+
nvidia-curand-cu12==10.3.2.106
|
87 |
+
nvidia-cusolver-cu12==11.4.5.107
|
88 |
+
nvidia-cusparse-cu12==12.1.0.106
|
89 |
+
nvidia-nccl-cu12==2.20.5
|
90 |
+
nvidia-nvjitlink-cu12==12.6.85
|
91 |
+
nvidia-nvtx-cu12==12.1.105
|
92 |
+
onnx==1.17.0
|
93 |
+
onnxconverter-common==1.14.0
|
94 |
+
onnxruntime==1.20.1
|
95 |
+
openai==1.58.1
|
96 |
+
optimum==1.23.3
|
97 |
+
orjson==3.10.13
|
98 |
+
packaging==23.2
|
99 |
+
pandas==2.2.3
|
100 |
+
pathspec==0.12.1
|
101 |
+
pillow==10.4.0
|
102 |
+
platformdirs==4.3.6
|
103 |
+
pluggy==1.5.0
|
104 |
+
pooch==1.8.2
|
105 |
+
propcache==0.2.1
|
106 |
+
protobuf==3.20.2
|
107 |
+
psutil==6.1.1
|
108 |
+
pyarrow==18.1.0
|
109 |
+
pycodestyle==2.11.1
|
110 |
+
pycparser==2.22
|
111 |
+
pydantic==2.10.4
|
112 |
+
pydantic_core==2.27.2
|
113 |
+
pydub==0.25.1
|
114 |
+
pyflakes==3.1.0
|
115 |
+
Pygments==2.19.1
|
116 |
+
pynvml==11.5.3
|
117 |
+
pyparsing==3.2.1
|
118 |
+
PySocks==1.7.1
|
119 |
+
pytest==7.4.4
|
120 |
+
pytest-asyncio==0.21.2
|
121 |
+
pytest-cov==4.1.0
|
122 |
+
pytest-django==4.9.0
|
123 |
+
pytest-testmon==2.1.3
|
124 |
+
pytest-watch==4.2.0
|
125 |
+
python-dateutil==2.9.0.post0
|
126 |
+
python-multipart==0.0.20
|
127 |
+
pytz==2024.2
|
128 |
+
PyYAML==6.0.2
|
129 |
+
RapidFuzz==3.11.0
|
130 |
+
regex==2024.11.6
|
131 |
+
requests==2.32.3
|
132 |
+
requests-toolbelt==1.0.0
|
133 |
+
rich==13.9.4
|
134 |
+
ruff==0.9.6
|
135 |
+
safetensors==0.4.5
|
136 |
+
scikit-learn==1.6.0
|
137 |
+
scipy==1.14.1
|
138 |
+
semantic-version==2.10.0
|
139 |
+
sentence-transformers==3.3.1
|
140 |
+
shellingham==1.5.4
|
141 |
+
six==1.17.0
|
142 |
+
sniffio==1.3.1
|
143 |
+
soundfile==0.12.1
|
144 |
+
soupsieve==2.6
|
145 |
+
soxr==0.5.0.post1
|
146 |
+
starlette==0.41.3
|
147 |
+
sympy==1.13.3
|
148 |
+
tenacity==8.5.0
|
149 |
+
textwrap3==0.9.2
|
150 |
+
threadpoolctl==3.5.0
|
151 |
+
tiktoken==0.8.0
|
152 |
+
tokenizers==0.21.0
|
153 |
+
tomli==2.2.1
|
154 |
+
tomlkit==0.12.0
|
155 |
+
torch==2.4.1
|
156 |
+
tqdm==4.67.1
|
157 |
+
transformers==4.47.1
|
158 |
+
triton==3.0.0
|
159 |
+
typer==0.15.1
|
160 |
+
typing_extensions==4.12.2
|
161 |
+
tzdata==2024.2
|
162 |
+
urllib3==2.3.0
|
163 |
+
uvicorn==0.34.0
|
164 |
+
watchdog==6.0.0
|
165 |
+
websockets==11.0.3
|
166 |
+
xxhash==3.5.0
|
167 |
+
yarl==1.18.3
|
src/handlers/__init__.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_handler import ModelHandler
|
2 |
+
from .nlp_models.sequence_classification_handler import SequenceClassificationHandler
|
3 |
+
from .nlp_models.question_answering_handler import QuestionAnsweringHandler
|
4 |
+
from .nlp_models.token_classification_handler import TokenClassificationHandler
|
5 |
+
from .nlp_models.causal_lm_handler import CausalLMHandler
|
6 |
+
from .nlp_models.embedding_model_handler import EmbeddingModelHandler
|
7 |
+
from .audio_models.whisper_handler import WhisperHandler
|
8 |
+
from .nlp_models.masked_lm_handler import MaskedLMHandler
|
9 |
+
from .nlp_models.seq2seq_lm_handler import Seq2SeqLMHandler
|
10 |
+
from .nlp_models.multiple_choice_handler import MultipleChoiceHandler
|
11 |
+
from .img_models.image_classification_handler import ImageClassificationHandler
|
12 |
+
|
13 |
+
from transformers import (
|
14 |
+
AutoModel,
|
15 |
+
AutoModelForTokenClassification,
|
16 |
+
AutoModelForSequenceClassification,
|
17 |
+
AutoModelForQuestionAnswering,
|
18 |
+
AutoModelForCausalLM,
|
19 |
+
AutoModelForMaskedLM,
|
20 |
+
AutoModelForSeq2SeqLM,
|
21 |
+
AutoModelForMultipleChoice,
|
22 |
+
)
|
23 |
+
|
24 |
+
TASK_CONFIGS = {
|
25 |
+
"embedding": {
|
26 |
+
"model_class": AutoModel,
|
27 |
+
"handler_class": EmbeddingModelHandler,
|
28 |
+
"example_text": "Hey, I am feeling way to good to be true.",
|
29 |
+
},
|
30 |
+
"ner": {
|
31 |
+
"model_class": AutoModelForTokenClassification,
|
32 |
+
"handler_class": TokenClassificationHandler,
|
33 |
+
"example_text": "John works at Google in New York as a software engineer.",
|
34 |
+
},
|
35 |
+
"text_classification": {
|
36 |
+
"model_class": AutoModelForSequenceClassification,
|
37 |
+
"handler_class": SequenceClassificationHandler,
|
38 |
+
"example_text": "This movie was great and I loved it.",
|
39 |
+
},
|
40 |
+
"question_answering": {
|
41 |
+
"model_class": AutoModelForQuestionAnswering,
|
42 |
+
"handler_class": QuestionAnsweringHandler,
|
43 |
+
"example_text": "The pyramids were built in ancient Egypt. QUES: Where were the pyramids built?",
|
44 |
+
},
|
45 |
+
"causal_lm": {
|
46 |
+
"model_class": AutoModelForCausalLM,
|
47 |
+
"handler_class": CausalLMHandler,
|
48 |
+
"example_text": "Once upon a time, there was ",
|
49 |
+
},
|
50 |
+
"mask_lm": {
|
51 |
+
"model_class": AutoModelForMaskedLM,
|
52 |
+
"handler_class": MaskedLMHandler,
|
53 |
+
"example_text": "The quick brown [MASK] jumps over the lazy dog.",
|
54 |
+
},
|
55 |
+
"seq2seq_lm": {
|
56 |
+
"model_class": AutoModelForSeq2SeqLM,
|
57 |
+
"handler_class": Seq2SeqLMHandler,
|
58 |
+
"example_text": "Translate English to French: The house is wonderful.",
|
59 |
+
},
|
60 |
+
"multiple_choice": {
|
61 |
+
"model_class": AutoModelForMultipleChoice,
|
62 |
+
"handler_class": MultipleChoiceHandler,
|
63 |
+
"example_text": "What is the capital of France? (A) Paris (B) London (C) Berlin (D) Rome",
|
64 |
+
},
|
65 |
+
"whisper_finetuning": {
|
66 |
+
"model_class": None, # Not implemented
|
67 |
+
"handler_class": WhisperHandler,
|
68 |
+
"example_text": "!!!!!NOT IMPLEMENTED!!!!!",
|
69 |
+
},
|
70 |
+
"image_classification": {
|
71 |
+
"model_class": None, # Not implemented
|
72 |
+
"handler_class": ImageClassificationHandler,
|
73 |
+
"example_text": "!!!!!NOT IMPLEMENTED!!!!!",
|
74 |
+
},
|
75 |
+
}
|
76 |
+
|
77 |
+
def get_model_handler(task: str, model_name: str, quantization_type: str, test_text: str):
|
78 |
+
task_config = TASK_CONFIGS.get(task)
|
79 |
+
if not task_config:
|
80 |
+
raise ValueError(f"No configuration found for task: {task}")
|
81 |
+
|
82 |
+
handler_class = task_config["handler_class"]
|
83 |
+
model_class = task_config["model_class"]
|
84 |
+
return handler_class(model_name, model_class, quantization_type, test_text)
|
src/handlers/audio_models/whisper_handler.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
|
3 |
+
class WhisperHandler(ModelHandler):
|
4 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
5 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
6 |
+
|
7 |
+
def run_inference(self, model, text):
|
8 |
+
raise NotImplementedError("STT is not implemented.")
|
9 |
+
|
10 |
+
def decode_output(self, outputs):
|
11 |
+
raise NotImplementedError("STT is not implemented.")
|
12 |
+
|
13 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
14 |
+
raise NotImplementedError("STT is not implemented.")
|
src/handlers/base_handler.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from optimizations.quantize import ModelQuantizer
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
import numpy as np
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Dict, Any, Optional
|
7 |
+
import json
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class ModelMetrics:
|
12 |
+
model_sizes: Dict[str, float]
|
13 |
+
inference_times: Dict[str, float]
|
14 |
+
comparison_metrics: Dict[str, Any]
|
15 |
+
|
16 |
+
class ModelHandler:
|
17 |
+
"""Base class for handling different types of models"""
|
18 |
+
|
19 |
+
def __init__(self, model_name, model_class, quantization_type, test_text=None):
|
20 |
+
self.model_name = model_name
|
21 |
+
self.model_class = model_class
|
22 |
+
self.quantization_type = quantization_type
|
23 |
+
self.test_text = test_text
|
24 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
25 |
+
|
26 |
+
# Load models
|
27 |
+
self.original_model = self._load_original_model()
|
28 |
+
self.quantized_model = self._load_quantized_model()
|
29 |
+
self.metrics: Optional[ModelMetrics] = None
|
30 |
+
|
31 |
+
def _load_original_model(self):
|
32 |
+
"""Load the original model"""
|
33 |
+
model = self.model_class.from_pretrained(self.model_name)
|
34 |
+
return model.to(self.device)
|
35 |
+
|
36 |
+
def _load_quantized_model(self):
|
37 |
+
"""Load the quantized model using ModelQuantizer"""
|
38 |
+
model = ModelQuantizer.quantize_model(
|
39 |
+
self.model_class,
|
40 |
+
self.model_name,
|
41 |
+
self.quantization_type
|
42 |
+
)
|
43 |
+
if self.quantization_type not in ["4-bit", "8-bit"]:
|
44 |
+
model = model.to(self.device)
|
45 |
+
return model
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def _convert_to_serializable(obj):
|
49 |
+
"""Serialization for metrics"""
|
50 |
+
if isinstance(obj, np.generic):
|
51 |
+
return obj.item()
|
52 |
+
if isinstance(obj, (np.float32, np.float64)):
|
53 |
+
return float(obj)
|
54 |
+
if isinstance(obj, (np.int32, np.int64)):
|
55 |
+
return int(obj)
|
56 |
+
if isinstance(obj, np.ndarray):
|
57 |
+
return obj.tolist()
|
58 |
+
if isinstance(obj, torch.Tensor):
|
59 |
+
return obj.cpu().numpy().tolist()
|
60 |
+
if isinstance(obj, dict):
|
61 |
+
return {k: ModelHandler._convert_to_serializable(v) for k, v in obj.items()}
|
62 |
+
if isinstance(obj, list):
|
63 |
+
return [ModelHandler._convert_to_serializable(v) for v in obj]
|
64 |
+
return obj
|
65 |
+
|
66 |
+
def _format_metric_value(self, value):
|
67 |
+
"""Format metric value based on its type"""
|
68 |
+
if isinstance(value, (float, np.float32, np.float64)):
|
69 |
+
return f"{value:.8f}"
|
70 |
+
elif isinstance(value, (int, np.int32, np.int64)):
|
71 |
+
return str(value)
|
72 |
+
elif isinstance(value, list):
|
73 |
+
return "\n" + "\n".join([f" - {item}" for item in value])
|
74 |
+
elif isinstance(value, dict):
|
75 |
+
return "\n" + "\n".join([f" {k}: {v}" for k, v in value.items()])
|
76 |
+
else:
|
77 |
+
return str(value)
|
78 |
+
|
79 |
+
def run_inference(self, model, text):
|
80 |
+
"""Run model inference - to be implemented by subclasses"""
|
81 |
+
raise NotImplementedError
|
82 |
+
|
83 |
+
def decode_output(self, outputs):
|
84 |
+
"""Decode model outputs - to be implemented by subclasses"""
|
85 |
+
raise NotImplementedError
|
86 |
+
|
87 |
+
def compare(self):
|
88 |
+
"""Compare original and quantized models"""
|
89 |
+
try:
|
90 |
+
if self.test_text is None:
|
91 |
+
logger.warning("No test text provided. Skipping inference testing.")
|
92 |
+
return self.quantized_model
|
93 |
+
|
94 |
+
# Run inference
|
95 |
+
original_outputs, original_time = self.run_inference(self.original_model, self.test_text)
|
96 |
+
quantized_outputs, quantized_time = self.run_inference(self.quantized_model, self.test_text)
|
97 |
+
|
98 |
+
original_size = ModelQuantizer.get_model_size(self.original_model)
|
99 |
+
quantized_size = ModelQuantizer.get_model_size(self.quantized_model)
|
100 |
+
|
101 |
+
logger.info(f"Original Model Size: {original_size:.2f} MB")
|
102 |
+
logger.info(f"Quantized Model Size: {quantized_size:.2f} MB")
|
103 |
+
logger.info(f"Original Inference Time: {original_time:.4f} seconds")
|
104 |
+
logger.info(f"Quantized Inference Time: {quantized_time:.4f} seconds")
|
105 |
+
|
106 |
+
# Compare outputs
|
107 |
+
comparison_metrics = self.compare_outputs(original_outputs, quantized_outputs) or {}
|
108 |
+
|
109 |
+
for key, value in comparison_metrics.items():
|
110 |
+
comparison_metrics[key] = self._convert_to_serializable(value)
|
111 |
+
|
112 |
+
self.metrics = {
|
113 |
+
"model_sizes": {
|
114 |
+
"original": float(original_size),
|
115 |
+
"quantized": float(quantized_size)
|
116 |
+
},
|
117 |
+
"inference_times": {
|
118 |
+
"original": float(original_time),
|
119 |
+
"quantized": float(quantized_time)
|
120 |
+
},
|
121 |
+
"comparison_metrics": comparison_metrics
|
122 |
+
}
|
123 |
+
|
124 |
+
return self.quantized_model
|
125 |
+
except Exception as e:
|
126 |
+
logger.error(f"Quantization and comparison failed: {str(e)}")
|
127 |
+
raise e
|
128 |
+
|
129 |
+
def get_metrics(self) -> Dict[str, Any]:
|
130 |
+
"""Return the metrics dictionary"""
|
131 |
+
if self.metrics is None:
|
132 |
+
return {
|
133 |
+
"model_sizes": {"original": 0.0, "quantized": 0.0},
|
134 |
+
"inference_times": {"original": 0.0, "quantized": 0.0},
|
135 |
+
"comparison_metrics": {}
|
136 |
+
}
|
137 |
+
serializable_metrics = self._convert_to_serializable(self.metrics)
|
138 |
+
try:
|
139 |
+
json.dumps(serializable_metrics)
|
140 |
+
return serializable_metrics
|
141 |
+
except (TypeError, ValueError) as e:
|
142 |
+
logger.error(f"Error serializing metrics: {str(e)}")
|
143 |
+
return {
|
144 |
+
"model_sizes": {"original": 0.0, "quantized": 0.0},
|
145 |
+
"inference_times": {"original": 0.0, "quantized": 0.0},
|
146 |
+
"comparison_metrics": {}
|
147 |
+
}
|
148 |
+
|
src/handlers/img_models/image_classification_handler.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
|
3 |
+
class ImageClassificationHandler(ModelHandler):
|
4 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
5 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
6 |
+
|
7 |
+
def run_inference(self, model, text):
|
8 |
+
raise NotImplementedError("Image classification is not implemented.")
|
9 |
+
|
10 |
+
def decode_output(self, outputs):
|
11 |
+
raise NotImplementedError("Image classification is not implemented.")
|
12 |
+
|
13 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
14 |
+
raise NotImplementedError("Image classification is not implemented.")
|
src/handlers/nlp_models/causal_lm_handler.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
from scipy.stats import spearmanr
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class CausalLMHandler(ModelHandler):
|
9 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
10 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
11 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
12 |
+
|
13 |
+
def run_inference(self, model, text):
|
14 |
+
inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
|
15 |
+
start_time = time.time()
|
16 |
+
with torch.no_grad():
|
17 |
+
outputs = model.generate(**inputs, max_length=50)
|
18 |
+
end_time = time.time()
|
19 |
+
return outputs, end_time - start_time
|
20 |
+
|
21 |
+
def decode_output(self, outputs):
|
22 |
+
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
23 |
+
|
24 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
25 |
+
"""Compare outputs for causal language models"""
|
26 |
+
if original_outputs is None or quantized_outputs is None:
|
27 |
+
return None
|
28 |
+
|
29 |
+
original_tokens = original_outputs[0].cpu().numpy()
|
30 |
+
quantized_tokens = quantized_outputs[0].cpu().numpy()
|
31 |
+
|
32 |
+
metrics = {
|
33 |
+
'sequence_similarity': np.mean(original_tokens == quantized_tokens),
|
34 |
+
'sequence_length_diff': abs(len(original_tokens) - len(quantized_tokens)),
|
35 |
+
'vocab_distribution_correlation': spearmanr(
|
36 |
+
np.bincount(original_tokens),
|
37 |
+
np.bincount(quantized_tokens)
|
38 |
+
)[0] if len(original_tokens) == len(quantized_tokens) else 0.0
|
39 |
+
}
|
40 |
+
|
41 |
+
original_text = self.decode_output(original_outputs)
|
42 |
+
quantized_text = self.decode_output(quantized_outputs)
|
43 |
+
metrics['decoded_text_match'] = float(original_text == quantized_text)
|
44 |
+
metrics['original_model_text'] = original_text
|
45 |
+
metrics['quantized_model_text'] = quantized_text
|
46 |
+
return metrics
|
src/handlers/nlp_models/embedding_model_handler.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
from scipy.stats import spearmanr
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
+
|
9 |
+
class EmbeddingModelHandler(ModelHandler):
|
10 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
11 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
12 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
13 |
+
|
14 |
+
def run_inference(self, model, text):
|
15 |
+
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
16 |
+
start_time = time.time()
|
17 |
+
with torch.no_grad():
|
18 |
+
outputs = model(**inputs)
|
19 |
+
end_time = time.time()
|
20 |
+
return outputs, end_time - start_time
|
21 |
+
|
22 |
+
def decode_output(self, outputs):
|
23 |
+
return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
24 |
+
|
25 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
26 |
+
"""Compare outputs for embedding models"""
|
27 |
+
if original_outputs is None or quantized_outputs is None:
|
28 |
+
return None
|
29 |
+
|
30 |
+
original_embeds = original_outputs.last_hidden_state.cpu().numpy()
|
31 |
+
quantized_embeds = quantized_outputs.last_hidden_state.cpu().numpy()
|
32 |
+
|
33 |
+
metrics = {
|
34 |
+
'mse': ((original_embeds - quantized_embeds) ** 2).mean(),
|
35 |
+
'cosine_similarity': cosine_similarity(
|
36 |
+
original_embeds.reshape(1, -1),
|
37 |
+
quantized_embeds.reshape(1, -1)
|
38 |
+
)[0][0],
|
39 |
+
'correlation': spearmanr(
|
40 |
+
original_embeds.flatten(),
|
41 |
+
quantized_embeds.flatten()
|
42 |
+
)[0],
|
43 |
+
'norm_difference': np.abs(
|
44 |
+
np.linalg.norm(original_embeds) -
|
45 |
+
np.linalg.norm(quantized_embeds)
|
46 |
+
)
|
47 |
+
}
|
48 |
+
|
49 |
+
return metrics
|
src/handlers/nlp_models/masked_lm_handler.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class MaskedLMHandler(ModelHandler):
|
8 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
9 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
11 |
+
|
12 |
+
def run_inference(self, model, text):
|
13 |
+
inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
|
14 |
+
start_time = time.time()
|
15 |
+
with torch.no_grad():
|
16 |
+
outputs = model(**inputs)
|
17 |
+
end_time = time.time()
|
18 |
+
return outputs, inputs, end_time - start_time
|
19 |
+
|
20 |
+
def decode_output(self, outputs, inputs):
|
21 |
+
logits = outputs.logits
|
22 |
+
masked_index = torch.where(inputs['input_ids'] == self.tokenizer.mask_token_id)[1]
|
23 |
+
predicted_token_id = logits[0, masked_index].argmax(axis=-1)
|
24 |
+
return self.tokenizer.decode(predicted_token_id)
|
25 |
+
|
26 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
27 |
+
if original_outputs is None or quantized_outputs is None:
|
28 |
+
return None
|
29 |
+
|
30 |
+
original_logits = original_outputs.logits.detach().cpu().numpy()
|
31 |
+
quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
|
32 |
+
|
33 |
+
metrics = {
|
34 |
+
'mse': ((original_logits - quantized_logits) ** 2).mean(),
|
35 |
+
'top_1_accuracy': np.mean(
|
36 |
+
np.argmax(original_logits, axis=-1) == np.argmax(quantized_logits, axis=-1)
|
37 |
+
),
|
38 |
+
}
|
39 |
+
return metrics
|
src/handlers/nlp_models/multiple_choice_handler.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class MultipleChoiceHandler(ModelHandler):
|
8 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
9 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
11 |
+
|
12 |
+
def run_inference(self, model, text):
|
13 |
+
choices = [text.split(f"({chr(65 + i)})")[1].strip() for i in range(4)]
|
14 |
+
inputs = self.tokenizer(choices, return_tensors='pt', padding=True).to(self.device)
|
15 |
+
start_time = time.time()
|
16 |
+
with torch.no_grad():
|
17 |
+
outputs = model(**inputs)
|
18 |
+
end_time = time.time()
|
19 |
+
return outputs, end_time - start_time
|
20 |
+
|
21 |
+
def decode_output(self, outputs):
|
22 |
+
logits = outputs.logits
|
23 |
+
predicted_choice = chr(65 + logits.argmax().item())
|
24 |
+
return f"Predicted choice: {predicted_choice}"
|
25 |
+
|
26 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
27 |
+
if original_outputs is None or quantized_outputs is None:
|
28 |
+
return None
|
29 |
+
|
30 |
+
original_logits = original_outputs.logits.detach().cpu().numpy()
|
31 |
+
quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
|
32 |
+
|
33 |
+
metrics = {
|
34 |
+
'mse': ((original_logits - quantized_logits) ** 2).mean(),
|
35 |
+
'top_1_accuracy': np.mean(
|
36 |
+
np.argmax(original_logits, axis=-1) == np.argmax(quantized_logits, axis=-1)
|
37 |
+
),
|
38 |
+
}
|
39 |
+
return metrics
|
src/handlers/nlp_models/question_answering_handler.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
|
6 |
+
class QuestionAnsweringHandler(ModelHandler):
|
7 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
8 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
9 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
10 |
+
|
11 |
+
def run_inference(self, model, text):
|
12 |
+
parts = text.split('QUES')
|
13 |
+
context = parts[0].strip()
|
14 |
+
question = parts[1].strip()
|
15 |
+
inputs = self.tokenizer(question, context, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
16 |
+
start_time = time.time()
|
17 |
+
with torch.no_grad():
|
18 |
+
outputs = model(**inputs)
|
19 |
+
end_time = time.time()
|
20 |
+
return outputs, end_time - start_time
|
21 |
+
|
22 |
+
def decode_output(self, outputs):
|
23 |
+
start_logits = outputs.start_logits
|
24 |
+
end_logits = outputs.end_logits
|
25 |
+
answer_start = torch.argmax(start_logits)
|
26 |
+
answer_end = torch.argmax(end_logits) + 1
|
27 |
+
input_ids = self.tokenizer.encode(self.test_text)
|
28 |
+
answer = self.tokenizer.decode(input_ids[answer_start:answer_end])
|
29 |
+
return f"Answer: {answer}"
|
30 |
+
|
31 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
32 |
+
"""Compare outputs for question answering models"""
|
33 |
+
if original_outputs is None or quantized_outputs is None:
|
34 |
+
return None
|
35 |
+
|
36 |
+
orig_start = original_outputs.start_logits.cpu().numpy()
|
37 |
+
orig_end = original_outputs.end_logits.cpu().numpy()
|
38 |
+
quant_start = quantized_outputs.start_logits.cpu().numpy()
|
39 |
+
quant_end = quantized_outputs.end_logits.cpu().numpy()
|
40 |
+
|
41 |
+
orig_start_pos = orig_start.argmax()
|
42 |
+
orig_end_pos = orig_end.argmax()
|
43 |
+
quant_start_pos = quant_start.argmax()
|
44 |
+
quant_end_pos = quant_end.argmax()
|
45 |
+
|
46 |
+
input_ids = self.tokenizer.encode(self.test_text)
|
47 |
+
original_answer = self.tokenizer.decode(input_ids[orig_start_pos:orig_end_pos + 1])
|
48 |
+
quantized_answer = self.tokenizer.decode(input_ids[quant_start_pos:quant_end_pos + 1])
|
49 |
+
|
50 |
+
metrics = {
|
51 |
+
'original_answer': original_answer,
|
52 |
+
'quantized_answer': quantized_answer,
|
53 |
+
'start_position_match': float(orig_start_pos == quant_start_pos),
|
54 |
+
'end_position_match': float(orig_end_pos == quant_end_pos),
|
55 |
+
'start_logits_mse': ((orig_start - quant_start) ** 2).mean(),
|
56 |
+
'end_logits_mse': ((orig_end - quant_end) ** 2).mean(),
|
57 |
+
}
|
58 |
+
|
59 |
+
return metrics
|
src/handlers/nlp_models/seq2seq_lm_handler.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class Seq2SeqLMHandler(ModelHandler):
|
8 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
9 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
11 |
+
|
12 |
+
def run_inference(self, model, text):
|
13 |
+
inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
|
14 |
+
start_time = time.time()
|
15 |
+
with torch.no_grad():
|
16 |
+
outputs = model.generate(**inputs, max_length=50)
|
17 |
+
end_time = time.time()
|
18 |
+
return outputs, end_time - start_time
|
19 |
+
|
20 |
+
def decode_output(self, outputs):
|
21 |
+
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
22 |
+
|
23 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
24 |
+
if original_outputs is None or quantized_outputs is None:
|
25 |
+
return None
|
26 |
+
|
27 |
+
original_tokens = original_outputs[0].cpu().numpy()
|
28 |
+
quantized_tokens = quantized_outputs[0].cpu().numpy()
|
29 |
+
|
30 |
+
metrics = {
|
31 |
+
'sequence_similarity': np.mean(original_tokens == quantized_tokens),
|
32 |
+
'sequence_length_diff': abs(len(original_tokens) - len(quantized_tokens)),
|
33 |
+
}
|
34 |
+
return metrics
|
src/handlers/nlp_models/sequence_classification_handler.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class SequenceClassificationHandler(ModelHandler):
|
8 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
9 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
11 |
+
|
12 |
+
def run_inference(self, model, text):
|
13 |
+
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
14 |
+
start_time = time.time()
|
15 |
+
with torch.no_grad():
|
16 |
+
outputs = model(**inputs)
|
17 |
+
end_time = time.time()
|
18 |
+
return outputs, end_time - start_time
|
19 |
+
|
20 |
+
def decode_output(self, outputs):
|
21 |
+
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
22 |
+
predicted_class = torch.argmax(probabilities, dim=-1).item()
|
23 |
+
return f"Predicted class: {predicted_class}"
|
24 |
+
|
25 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
26 |
+
"""Compare outputs for sequence classification models"""
|
27 |
+
if original_outputs is None or quantized_outputs is None:
|
28 |
+
return None
|
29 |
+
|
30 |
+
orig_logits = original_outputs.logits.cpu().numpy()
|
31 |
+
quant_logits = quantized_outputs.logits.cpu().numpy()
|
32 |
+
|
33 |
+
orig_probs = torch.nn.functional.softmax(torch.tensor(orig_logits), dim=-1).numpy()
|
34 |
+
quant_probs = torch.nn.functional.softmax(torch.tensor(quant_logits), dim=-1).numpy()
|
35 |
+
|
36 |
+
orig_pred = orig_probs.argmax(axis=-1)
|
37 |
+
quant_pred = quant_probs.argmax(axis=-1)
|
38 |
+
|
39 |
+
metrics = {
|
40 |
+
'class_match': float(orig_pred == quant_pred),
|
41 |
+
'logits_mse': ((orig_logits - quant_logits) ** 2).mean(),
|
42 |
+
'probability_mse': ((orig_probs - quant_probs) ** 2).mean(),
|
43 |
+
'max_probability_diff': abs(orig_probs.max() - quant_probs.max()),
|
44 |
+
'kl_divergence': float(
|
45 |
+
(orig_probs * (np.log(orig_probs + 1e-10) - np.log(quant_probs + 1e-10))).sum()
|
46 |
+
)
|
47 |
+
}
|
48 |
+
|
49 |
+
return metrics
|
src/handlers/nlp_models/token_classification_handler.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base_handler import ModelHandler
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
|
6 |
+
class TokenClassificationHandler(ModelHandler):
|
7 |
+
def __init__(self, model_name, model_class, quantization_type, test_text):
|
8 |
+
super().__init__(model_name, model_class, quantization_type, test_text)
|
9 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
10 |
+
|
11 |
+
def run_inference(self, model, text):
|
12 |
+
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
13 |
+
start_time = time.time()
|
14 |
+
with torch.no_grad():
|
15 |
+
outputs = model(**inputs)
|
16 |
+
end_time = time.time()
|
17 |
+
return outputs, end_time - start_time
|
18 |
+
|
19 |
+
def decode_output(self, model, outputs):
|
20 |
+
tokens = self.tokenizer.convert_ids_to_tokens(outputs['input_ids'][0])
|
21 |
+
labels = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
|
22 |
+
decoded_labels = [model.config.id2label[label] for label in labels]
|
23 |
+
return dict(zip(tokens, decoded_labels))
|
24 |
+
|
25 |
+
def compare_outputs(self, original_outputs, quantized_outputs):
|
26 |
+
"""Compare outputs for token classification models"""
|
27 |
+
if original_outputs is None or quantized_outputs is None:
|
28 |
+
return None
|
29 |
+
|
30 |
+
orig_logits = original_outputs.logits.cpu().numpy()
|
31 |
+
quant_logits = quantized_outputs.logits.cpu().numpy()
|
32 |
+
|
33 |
+
orig_preds = orig_logits.argmax(axis=-1)
|
34 |
+
quant_preds = quant_logits.argmax(axis=-1)
|
35 |
+
|
36 |
+
input_tokens = self.tokenizer.convert_ids_to_tokens(
|
37 |
+
self.tokenizer(self.test_text, return_tensors='pt')['input_ids'][0]
|
38 |
+
)
|
39 |
+
|
40 |
+
orig_labels = [self.original_model.config.id2label[p] for p in orig_preds[0]]
|
41 |
+
quant_labels = [self.quantized_model.config.id2label[p] for p in quant_preds[0]]
|
42 |
+
|
43 |
+
original_results = list(zip(input_tokens, orig_labels))
|
44 |
+
quantized_results = list(zip(input_tokens, quant_labels))
|
45 |
+
|
46 |
+
token_matches = sum(o_label == q_label for o_label, q_label in zip(orig_labels, quant_labels))
|
47 |
+
total_tokens = len(orig_labels)
|
48 |
+
|
49 |
+
metrics = {
|
50 |
+
'original_predictions': original_results,
|
51 |
+
'quantized_predictions': quantized_results,
|
52 |
+
'token_level_accuracy': float(token_matches) / total_tokens if total_tokens > 0 else 0.0,
|
53 |
+
'sequence_exact_match': float((orig_preds == quant_preds).all()),
|
54 |
+
'logits_mse': ((orig_logits - quant_logits) ** 2).mean(),
|
55 |
+
}
|
56 |
+
|
57 |
+
return metrics
|
src/optimizations/onnx_conversion.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from transformers import AutoTokenizer, WhisperProcessor, AutoFeatureExtractor
|
3 |
+
from optimum.onnxruntime import (
|
4 |
+
ORTModelForQuestionAnswering,
|
5 |
+
ORTModelForCausalLM,
|
6 |
+
ORTModelForSequenceClassification,
|
7 |
+
ORTModelForTokenClassification,
|
8 |
+
ORTModelForSpeechSeq2Seq,
|
9 |
+
ORTOptimizer,
|
10 |
+
ORTModelForMaskedLM,
|
11 |
+
ORTModelForSeq2SeqLM,
|
12 |
+
ORTModelForMultipleChoice,
|
13 |
+
ORTModelForImageClassification,
|
14 |
+
)
|
15 |
+
import logging
|
16 |
+
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
TASK_MAPPING = {
|
21 |
+
# NLP models
|
22 |
+
"ner": (ORTModelForTokenClassification, AutoTokenizer),
|
23 |
+
"text_classification": (ORTModelForSequenceClassification, AutoTokenizer),
|
24 |
+
"question_answering": (ORTModelForQuestionAnswering, AutoTokenizer),
|
25 |
+
"causal_lm": (ORTModelForCausalLM, AutoTokenizer),
|
26 |
+
"mask_lm": (ORTModelForMaskedLM, AutoTokenizer),
|
27 |
+
"seq2seq_lm": (ORTModelForSeq2SeqLM, AutoTokenizer),
|
28 |
+
"multiple_choice": (ORTModelForMultipleChoice, AutoTokenizer),
|
29 |
+
# Audio models
|
30 |
+
"whisper_finetuning": (ORTModelForSpeechSeq2Seq, WhisperProcessor),
|
31 |
+
# Vision models
|
32 |
+
"image_classification": (ORTModelForImageClassification, AutoFeatureExtractor),
|
33 |
+
}
|
34 |
+
|
35 |
+
def convert_to_onnx(model_name, task, output_dir):
|
36 |
+
"""
|
37 |
+
Convert model to ONNX format for the specified task.
|
38 |
+
"""
|
39 |
+
logger.info(f"Converting model: {model_name} for task: {task}")
|
40 |
+
|
41 |
+
os.makedirs(output_dir, exist_ok=True)
|
42 |
+
|
43 |
+
if task not in TASK_MAPPING:
|
44 |
+
logger.error(f"Task {task} is not supported for ONNX conversion in this script.")
|
45 |
+
return None
|
46 |
+
|
47 |
+
ORTModelClass, ProcessorClass = TASK_MAPPING[task]
|
48 |
+
|
49 |
+
try:
|
50 |
+
if task == "embedding":
|
51 |
+
ort_optimizer = ORTOptimizer.from_pretrained(model_name)
|
52 |
+
ort_optimizer.export(output_dir=output_dir, task="feature-extraction")
|
53 |
+
else:
|
54 |
+
ort_model = ORTModelClass.from_pretrained(model_name, export=True)
|
55 |
+
ort_model.save_pretrained(output_dir)
|
56 |
+
|
57 |
+
processor = ProcessorClass.from_pretrained(model_name)
|
58 |
+
processor.save_pretrained(output_dir)
|
59 |
+
|
60 |
+
logger.info(f"Conversion complete. Model saved to: {output_dir}")
|
61 |
+
return output_dir
|
62 |
+
except Exception as e:
|
63 |
+
logger.error(f"Conversion failed: {str(e)}")
|
64 |
+
return None
|
src/optimizations/quantize.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import onnx
|
4 |
+
import logging
|
5 |
+
from scipy.stats import spearmanr
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
from transformers import BitsAndBytesConfig
|
8 |
+
from onnxconverter_common import float16
|
9 |
+
from onnxruntime.quantization import quantize_dynamic, QuantType
|
10 |
+
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
class ModelQuantizer:
|
15 |
+
"""Handles model quantization and comparison operations"""
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def quantize_model(model_class, model_name, quantization_type):
|
19 |
+
"""Quantizes a model based on specified quantization type"""
|
20 |
+
try:
|
21 |
+
if quantization_type == "4-bit":
|
22 |
+
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
23 |
+
model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
|
24 |
+
elif quantization_type == "8-bit":
|
25 |
+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
26 |
+
model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
|
27 |
+
elif quantization_type == "16-bit-float":
|
28 |
+
model = model_class.from_pretrained(model_name)
|
29 |
+
model = model.to(torch.float16)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Unsupported quantization type: {quantization_type}")
|
32 |
+
|
33 |
+
return model
|
34 |
+
except Exception as e:
|
35 |
+
logger.error(f"Quantization failed: {str(e)}")
|
36 |
+
raise
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def get_model_size(model):
|
40 |
+
"""Calculate model size in MB"""
|
41 |
+
try:
|
42 |
+
torch.save(model.state_dict(), "temp.pth")
|
43 |
+
size = os.path.getsize("temp.pth") / (1024 * 1024)
|
44 |
+
os.remove("temp.pth")
|
45 |
+
return size
|
46 |
+
except Exception as e:
|
47 |
+
logger.error(f"Failed to get model size: {str(e)}")
|
48 |
+
raise
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def compare_model_outputs(original_outputs, quantized_outputs):
|
52 |
+
"""Compare outputs between original and quantized models"""
|
53 |
+
try:
|
54 |
+
if original_outputs is None or quantized_outputs is None:
|
55 |
+
return None
|
56 |
+
|
57 |
+
if hasattr(original_outputs, 'logits') and hasattr(quantized_outputs, 'logits'):
|
58 |
+
original_logits = original_outputs.logits.detach().cpu().numpy()
|
59 |
+
quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
|
60 |
+
|
61 |
+
metrics = {
|
62 |
+
'mse': ((original_logits - quantized_logits) ** 2).mean(),
|
63 |
+
'spearman_corr': spearmanr(original_logits.flatten(), quantized_logits.flatten())[0],
|
64 |
+
'cosine_sim': cosine_similarity(original_logits.reshape(1, -1), quantized_logits.reshape(1, -1))[0][0]
|
65 |
+
}
|
66 |
+
return metrics
|
67 |
+
return None
|
68 |
+
except Exception as e:
|
69 |
+
logger.error(f"Output comparison failed: {str(e)}")
|
70 |
+
raise
|
71 |
+
|
72 |
+
def quantize_onnx_model(model_dir, quantization_type):
|
73 |
+
"""
|
74 |
+
Quantize ONNX model in the specified directory.
|
75 |
+
"""
|
76 |
+
logger.info(f"Quantizing ONNX model in: {model_dir}")
|
77 |
+
for filename in os.listdir(model_dir):
|
78 |
+
if filename.endswith('.onnx'):
|
79 |
+
input_model_path = os.path.join(model_dir, filename)
|
80 |
+
output_model_path = os.path.join(model_dir, f"quantized_{filename}")
|
81 |
+
|
82 |
+
try:
|
83 |
+
model = onnx.load(input_model_path)
|
84 |
+
|
85 |
+
if quantization_type == "16-bit-float":
|
86 |
+
model_fp16 = float16.convert_float_to_float16(model)
|
87 |
+
onnx.save(model_fp16, output_model_path)
|
88 |
+
elif quantization_type in ["8-bit", "16-bit-int"]:
|
89 |
+
quant_type_mapping = {
|
90 |
+
"8-bit": QuantType.QInt8,
|
91 |
+
"16-bit-int": QuantType.QInt16,
|
92 |
+
}
|
93 |
+
quantize_dynamic(
|
94 |
+
model_input=input_model_path,
|
95 |
+
model_output=output_model_path,
|
96 |
+
weight_type=quant_type_mapping[quantization_type]
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
logger.error(f"Unsupported quantization type: {quantization_type}")
|
100 |
+
continue
|
101 |
+
|
102 |
+
os.remove(input_model_path)
|
103 |
+
os.rename(output_model_path, input_model_path)
|
104 |
+
|
105 |
+
logger.info(f"Quantized ONNX model saved to: {input_model_path}")
|
106 |
+
except Exception as e:
|
107 |
+
logger.error(f"Error during ONNX quantization: {str(e)}")
|
108 |
+
if os.path.exists(output_model_path):
|
109 |
+
os.remove(output_model_path)
|
src/utilities/push_to_hub.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Dict, Tuple
|
5 |
+
from huggingface_hub import HfApi, create_repo
|
6 |
+
|
7 |
+
logging.basicConfig(level=logging.INFO)
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
def push_to_hub(local_path: str, repo_name: str, hf_token: str, commit_message: Optional[str] = None, tags: Optional[list] = None) -> Tuple[Optional[str], str]:
|
11 |
+
"""
|
12 |
+
Pushes a folder containing model files to the HuggingFace Hub.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
local_path (str): Local directory containing the model files to upload.
|
16 |
+
repo_name (str): The repository name (not the full username/repo_name).
|
17 |
+
hf_token (str): HuggingFace authentication token.
|
18 |
+
commit_message (str, optional): Commit message for the upload.
|
19 |
+
tags (list, optional): Tags to include in the model card.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Tuple[Optional[str], str]: (repository_name, status_message)
|
23 |
+
"""
|
24 |
+
try:
|
25 |
+
api = HfApi(token=hf_token)
|
26 |
+
|
27 |
+
# Validate token
|
28 |
+
try:
|
29 |
+
user_info = api.whoami()
|
30 |
+
username = user_info["name"]
|
31 |
+
except Exception as e:
|
32 |
+
return None, f"β Authentication failed: Invalid token or network error ({str(e)})"
|
33 |
+
|
34 |
+
# Full repository name with the username
|
35 |
+
full_repo_name = f"{username}/{repo_name}"
|
36 |
+
|
37 |
+
# Create the repo
|
38 |
+
try:
|
39 |
+
create_repo(full_repo_name, token=hf_token, exist_ok=True)
|
40 |
+
logger.info(f"Repository created/verified: {full_repo_name}")
|
41 |
+
except Exception as e:
|
42 |
+
return None, f"β Repository creation failed: {str(e)}"
|
43 |
+
|
44 |
+
# Create model card
|
45 |
+
try:
|
46 |
+
tags_list = tags or []
|
47 |
+
tags_section = "\n".join(f"- {tag}" for tag in tags_list)
|
48 |
+
model_card = f"""---
|
49 |
+
tags:
|
50 |
+
{tags_section}
|
51 |
+
library_name: optimum
|
52 |
+
---
|
53 |
+
|
54 |
+
# Model - {repo_name}
|
55 |
+
|
56 |
+
This model has been optimized and uploaded to the HuggingFace Hub.
|
57 |
+
|
58 |
+
## Model Details
|
59 |
+
- Optimization Tags: {', '.join(tags_list)}
|
60 |
+
"""
|
61 |
+
with open(os.path.join(local_path, "README.md"), "w") as f:
|
62 |
+
f.write(model_card)
|
63 |
+
except Exception as e:
|
64 |
+
logger.warning(f"Model card creation warning: {str(e)}")
|
65 |
+
|
66 |
+
# Upload the folder
|
67 |
+
try:
|
68 |
+
api.upload_folder(
|
69 |
+
folder_path=local_path,
|
70 |
+
repo_id=full_repo_name,
|
71 |
+
repo_type="model",
|
72 |
+
commit_message=commit_message or "Upload optimized model"
|
73 |
+
)
|
74 |
+
success_msg = f"β
Model successfully pushed to: {full_repo_name}"
|
75 |
+
logger.info(success_msg)
|
76 |
+
return full_repo_name, success_msg
|
77 |
+
except Exception as e:
|
78 |
+
error_msg = f"β Upload failed: {str(e)}"
|
79 |
+
logger.error(error_msg)
|
80 |
+
return None, error_msg
|
81 |
+
|
82 |
+
except Exception as e:
|
83 |
+
error_msg = f"β Unexpected error during push: {str(e)}"
|
84 |
+
logger.error(error_msg)
|
85 |
+
return None, error_msg
|
src/utilities/resources.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import psutil
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Optional, Dict
|
7 |
+
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
class ResourceManager:
|
12 |
+
def __init__(self, temp_dir: str = "temp"):
|
13 |
+
self.temp_dir = Path(temp_dir)
|
14 |
+
self.temp_dirs = {
|
15 |
+
"onnx": self.temp_dir / "onnx_output",
|
16 |
+
"quantized": self.temp_dir / "quantized_models",
|
17 |
+
"cache": self.temp_dir / "model_cache"
|
18 |
+
}
|
19 |
+
self.setup_directories()
|
20 |
+
|
21 |
+
def setup_directories(self):
|
22 |
+
for dir_path in self.temp_dirs.values():
|
23 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
24 |
+
|
25 |
+
def cleanup_temp_files(self, specific_dir: Optional[str] = None) -> str:
|
26 |
+
try:
|
27 |
+
if specific_dir:
|
28 |
+
if specific_dir in self.temp_dirs:
|
29 |
+
shutil.rmtree(self.temp_dirs[specific_dir], ignore_errors=True)
|
30 |
+
self.temp_dirs[specific_dir].mkdir(exist_ok=True)
|
31 |
+
else:
|
32 |
+
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
33 |
+
self.setup_directories()
|
34 |
+
return "β¨ Cleanup successful!"
|
35 |
+
except Exception as e:
|
36 |
+
logger.error(f"Cleanup failed: {str(e)}")
|
37 |
+
return f"β Cleanup failed: {str(e)}"
|
38 |
+
|
39 |
+
def get_memory_info(self) -> Dict[str, float]:
|
40 |
+
vm = psutil.virtual_memory()
|
41 |
+
memory_info = {
|
42 |
+
"total_ram": vm.total / (1024 ** 3),
|
43 |
+
"available_ram": vm.available / (1024 ** 3),
|
44 |
+
"used_ram": vm.used / (1024 ** 3)
|
45 |
+
}
|
46 |
+
|
47 |
+
if torch.cuda.is_available():
|
48 |
+
device = torch.cuda.current_device()
|
49 |
+
memory_info.update({
|
50 |
+
"gpu_total": torch.cuda.get_device_properties(device).total_memory / (1024 ** 3),
|
51 |
+
"gpu_used": torch.cuda.memory_allocated(device) / (1024 ** 3)
|
52 |
+
})
|
53 |
+
|
54 |
+
return memory_info
|