smokxy committed on
Commit
9bf1d31
·
verified ·
1 Parent(s): fe21c85

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
 
README.md CHANGED
@@ -1,12 +1,190 @@
1
- ---
2
- title: AutoQuantNX
3
- emoji: 🐒
4
- colorFrom: purple
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.15.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: AutoQuantNX
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 4.44.1
6
+ ---
7
+ # 🤗 AutoQuantNX (**still under active testing and improvement**)
8
+
9
+ ## Overview
10
+ AutoQuantNX is a Gradio-based web application that simplifies optimizing and deploying Hugging Face models. It covers a wide range of tasks and handles quantization, ONNX conversion, and integration with the Hugging Face Hub. With AutoQuantNX you can convert models to ONNX, apply quantization, and push the optimized models to your Hugging Face account, all through an intuitive user interface.
11
+
12
+ ## Features
13
+
14
+ ### Supported Tasks
15
+ AutoQuantNX supports the following tasks:
16
+
17
+ * Text Classification
18
+ * Named Entity Recognition (NER)
19
+ * Question Answering
20
+ * Causal Language Modeling
21
+ * Masked Language Modeling
22
+ * Sequence-to-Sequence Language Modeling
23
+ * Multiple Choice
24
+ * Whisper Speech-to-Text (placeholder for future implementation)
25
+ * Embedding Fine-Tuning
26
+ * Image Classification (Placeholder for future implementation)
27
+
28
+ ### Quantization Options
29
+ * None (default)
30
+ * 4-bit
31
+ * 8-bit
32
+ * 16-bit-float (see the loading sketch after this list)
33
+
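+ For reference, this is a minimal sketch of how the three non-default options are applied internally (see `src/optimizations/quantize.py` in this commit); the model name is only an illustration:
+ ```python
+ from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
+ import torch
+
+ name = "bert-base-uncased"  # illustrative model
+ model_4bit = AutoModelForSequenceClassification.from_pretrained(
+     name, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
+ )
+ model_8bit = AutoModelForSequenceClassification.from_pretrained(
+     name, quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ )
+ model_fp16 = AutoModelForSequenceClassification.from_pretrained(name).to(torch.float16)
+ ```
+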
34
+ ### ONNX Conversion
35
+ Converts models to ONNX format for optimized deployment.
36
+
37
+ Supports optional ONNX quantization (a usage sketch follows the list):
38
+ * 8-bit
39
+ * 16-bit-int
40
+ * 16-bit-float
41
+
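+ A minimal usage sketch of the underlying helpers (both defined under `src/optimizations/` in this commit); the output directory name is a placeholder:
+ ```python
+ from src.optimizations.onnx_conversion import convert_to_onnx
+ from src.optimizations.quantize import quantize_onnx_model
+
+ output_dir = convert_to_onnx("bert-base-uncased", "text_classification", "onnx_output")
+ if output_dir is not None:  # convert_to_onnx returns None when the task or model is unsupported
+     quantize_onnx_model(output_dir, "8-bit")  # rewrites the exported .onnx files in place
+ ```
+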
42
+ ### Hugging Face Hub Integration
43
+ * Automatically pushes optimized models to your Hugging Face Hub repository
44
+ * Tags models with metadata for easy identification (e.g., onnx, quantized, task type)
45
+
46
+ ### Performance Testing
47
+ Compares original and quantized models using metrics such as the following (an example of the returned structure appears after the list):
48
+ * Mean Squared Error (MSE)
49
+ * Spearman Correlation
50
+ * Cosine Similarity
51
+ * Inference Time
52
+ * Model Size
53
+
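+ The handlers collect these into a dictionary shaped roughly like the sketch below; the exact comparison keys depend on the task, and the numbers are illustrative only:
+ ```python
+ {
+     "model_sizes": {"original": 438.0, "quantized": 181.4},        # MB
+     "inference_times": {"original": 0.042, "quantized": 0.029},    # seconds
+     "comparison_metrics": {"mse": 1.2e-4, "spearman_corr": 0.999, "cosine_sim": 0.998},
+ }
+ ```
+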
54
+ ## File Structure
55
+ ```
56
+ AutoQuantNX/
57
+ ├── src/
58
+ │   ├── handlers/
59
+ │   │   ├── audio_models/
60
+ │   │   │   └── whisper_handler.py
61
+ │   │   ├── img_models/
62
+ │   │   │   └── image_classification_handler.py
63
+ │   │   ├── nlp_models/
64
+ │   │   │   ├── causal_lm_handler.py
65
+ │   │   │   ├── embedding_model_handler.py
66
+ │   │   │   ├── masked_lm_handler.py
67
+ │   │   │   ├── multiple_choice_handler.py
68
+ │   │   │   ├── question_answering_handler.py
69
+ │   │   │   ├── seq2seq_lm_handler.py
70
+ │   │   │   ├── sequence_classification_handler.py
71
+ │   │   │   └── token_classification_handler.py
72
+ │   │   ├── __init__.py
73
+ │   │   └── base_handler.py
74
+ │   ├── optimizations/
75
+ │   │   ├── onnx_conversion.py
76
+ │   │   └── quantize.py
77
+ │   └── utilities/
78
+ │       ├── push_to_hub.py
79
+ │       └── resources.py
80
+ ├── README.md
81
+ ├── app.py
82
+ ├── poetry.lock
83
+ ├── pyproject.toml
84
+ └── requirements.txt
85
+ ```
86
+
87
+ ## Prerequisites
88
+
89
+ ### Using requirements.txt (not my preference, at least)
90
+ * Python 3.10 or 3.11 (the `pyproject.toml` pins `>=3.10,<3.12`)
91
+ * Install dependencies:
92
+ ```bash
93
+ pip install -r requirements.txt
94
+ ```
95
+
96
+ ### Using Poetry
97
+ 1. Install Poetry (if not already installed):
98
+
99
+ Linux:
100
+ ```bash
101
+ curl -sSL https://install.python-poetry.org | python3 -
102
+ ```
103
+ Other platforms: follow the official instructions at https://python-poetry.org/docs/#installation.
104
+
105
+ 2. Install dependencies:
106
+ ```bash
107
+ poetry install
108
+ ```
109
+
110
+ 3. Activate the virtual environment:
111
+ ```bash
112
+ poetry shell
113
+ ```
114
+
115
+ ## Usage
116
+
117
+ ### Launch the App
118
+ Run the following command to start the Gradio web application:
119
+ ```bash
120
+ python app.py
121
+ ```
122
+ The app will be accessible at http://localhost:7860 by default.
123
+
124
+ ### Steps to Use the App
125
+ 1. Enter Model Details:
126
+ * Provide the Hugging Face model name
127
+ * Select the task type (e.g., text classification, question answering)
128
+
129
+ 2. Select Optimization Options:
130
+ * Choose quantization type (e.g., 4-bit, 8-bit)
131
+ * Enable ONNX conversion and select quantization options if needed
132
+
133
+ 3. Provide Hugging Face Token:
134
+ * Enter your Hugging Face token for accessing and pushing models to the Hub
135
+
136
+ 4. Start Conversion:
137
+ * Click the "Start Conversion" button to process the model
138
+
139
+ 5. Monitor Progress:
140
+ * View real-time status updates, resource usage, and results directly in the app
141
+
142
+ 6. Push to Hub:
143
+ * Optimized models are automatically pushed to your specified Hugging Face repository
144
+
145
+ ### Example
146
+ For a model like `bert-base-uncased` on the text classification task (a programmatic sketch follows this list):
147
+ 1. Select text_classification as the task
148
+ 2. Enable quantization (e.g., 8-bit)
149
+ 3. Enable ONNX conversion with optimization
150
+ 4. Click "Start Conversion" and monitor progress
151
+
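+ The same run can also be reproduced programmatically through the app's `process_model` function; this is only a sketch, and the token and repository name below are placeholders:
+ ```python
+ from app import process_model
+
+ status, messages, metrics = process_model(
+     model_name="bert-base-uncased",
+     task="text_classification",
+     quantization_type="8-bit",
+     enable_onnx=True,
+     onnx_quantization="8-bit",
+     hf_token="hf_xxx",                # placeholder token with write access
+     repo_name="bert-tc-demo",         # pushed as <username>/bert-tc-demo-optimized
+     test_text="This movie was great and I loved it.",
+ )
+ print(status["status"], metrics.get("model_sizes"))
+ ```
+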
152
+ ## Key Functions
153
+
154
+ ### app.py
155
+ * `process_model`: Main function handling model quantization, ONNX conversion, and Hugging Face Hub integration
156
+ * `update_memory_info`: Monitors and displays system resource usage
157
+
158
+ ### src/optimizations/onnx_conversion.py
159
+ * `convert_to_onnx`: Converts models to ONNX format
160
+
161
+ ### src/optimizations/quantize.py
162
+ * `ModelQuantizer`: Handles quantization of PyTorch models and performance testing
163
+ * `quantize_onnx_model`: Applies quantization to exported ONNX models
164
+
165
+ ### src/utilities/push_to_hub.py
166
+ * `push_to_hub`: Pushes models to the Hugging Face Hub (a combined usage sketch follows this section)
167
+
168
+ ### src/utilities/resources.py
169
+ * `ResourceManager`: Manages temporary files and memory usage
170
+
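+ A rough sketch of calling the two lower-level utilities directly from the repository root; the repository name and token are placeholders:
+ ```python
+ from transformers import AutoModelForSequenceClassification
+ from src.optimizations.quantize import ModelQuantizer
+ from src.utilities.push_to_hub import push_to_hub
+
+ model = ModelQuantizer.quantize_model(
+     AutoModelForSequenceClassification, "bert-base-uncased", "8-bit"
+ )
+ model.save_pretrained("quantized_model")
+
+ repo, message = push_to_hub(
+     local_path="quantized_model",
+     repo_name="bert-8bit-demo",   # created as <username>/bert-8bit-demo
+     hf_token="hf_xxx",            # placeholder
+     tags=["quantized", "text_classification", "8-bit"],
+ )
+ print(message)
+ ```
+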
171
+ ## Notes
172
+ * Ensure you have sufficient system resources for model conversion and quantization
173
+ * Use a Hugging Face Hub token with proper write permissions for pushing models
174
+
175
+ ## Troubleshooting
176
+ * Model Conversion Fails: Ensure the model and task are supported
177
+ * Insufficient Resources: Free up memory or reduce optimization levels
178
+ * ONNX Quantization Errors: Verify that the selected quantization type is supported for the model
179
+
180
+ ## License
181
+ This project is licensed under the MIT License. See the LICENSE file for details.
182
+
183
+ ## Contributions
184
+ Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
185
+
186
+ ## Acknowledgments
187
+ * Hugging Face Transformers
188
+ * Optimum Library
189
+ * Gradio
190
+ * ONNX Runtime
app.py ADDED
@@ -0,0 +1,211 @@
1
+ import gradio as gr
2
+ import logging
3
+ from typing import Tuple, Dict, Any
4
+ from src.utilities.resources import ResourceManager
5
+ from src.utilities.push_to_hub import push_to_hub
6
+ from src.optimizations.onnx_conversion import convert_to_onnx
7
+ from src.optimizations.quantize import quantize_onnx_model
8
+ from src.handlers import get_model_handler, TASK_CONFIGS
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+ import json
12
+
13
+ def process_model(
14
+ model_name: str,
15
+ task: str,
16
+ quantization_type: str,
17
+ enable_onnx: bool,
18
+ onnx_quantization: str,
19
+ hf_token: str,
20
+ repo_name: str,
21
+ test_text: str
22
+ ) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
23
+ metrics = {}  # defined before the try block so every error path can return it
+ try:
24
+ resource_manager = ResourceManager()
25
+ status_updates = []
26
+ status = {
27
+ "status": "Processing",
28
+ "progress": 0,
29
+ "current_step": "Initializing",
30
+ }
31
+
32
+ metrics = {}
33
+
34
+ if not model_name or not hf_token or not repo_name:
35
+ return (
36
+ {"status": "Error", "progress": 0, "current_step": "Validation Failed"},
37
+ "Model name, HuggingFace token, and repository name are required.",
38
+ metrics
39
+ )
40
+
41
+ status["progress"] = 0.2
42
+ status["current_step"] = "Initialization"
43
+ status_updates.append("Initialization complete")
44
+
45
+ quantized_model_path = None
46
+
47
+ if quantization_type != "None":
48
+ status.update({"progress": 0.4, "current_step": "Quantization"})
49
+ status_updates.append(f"Applying {quantization_type} quantization")
50
+
51
+ if not test_text:
52
+ test_text = TASK_CONFIGS[task]["example_text"]
53
+
54
+ try:
55
+ handler = get_model_handler(task, model_name, quantization_type, test_text)
56
+ quantized_model = handler.compare()
57
+ metrics = handler.get_metrics()
58
+ metrics = json.loads(json.dumps(metrics))
59
+
60
+ quantized_model_path = str(resource_manager.temp_dirs["quantized"] / "model")
61
+ quantized_model.save_pretrained(quantized_model_path)
62
+ status_updates.append("Quantization completed successfully")
63
+ except Exception as e:
64
+ logger.error(f"Quantization error: {str(e)}", exc_info=True)
65
+ return (
66
+ {"status": "Error", "progress": 0.4, "current_step": "Quantization Failed"},
67
+ f"Quantization failed: {str(e)}",
68
+ metrics
69
+ )
70
+
71
+ if enable_onnx:
72
+ status.update({"progress": 0.6, "current_step": "ONNX Conversion"})
73
+ status_updates.append("Converting to ONNX format")
74
+
75
+ try:
76
+ output_dir = str(resource_manager.temp_dirs["onnx"])
77
+ onnx_result = convert_to_onnx(model_name, task, output_dir)
78
+
79
+ if onnx_result is None:
80
+ return (
81
+ {"status": "Error", "progress": 0.6, "current_step": "ONNX Conversion Failed"},
82
+ "ONNX conversion failed.",
83
+ metrics
84
+ )
85
+
86
+ if onnx_quantization != "None":
87
+ status_updates.append(f"Applying {onnx_quantization} ONNX quantization")
88
+ quantize_onnx_model(output_dir, onnx_quantization)
89
+
90
+ status.update({"progress": 0.8, "current_step": "Pushing ONNX Model"})
91
+ status_updates.append("Pushing ONNX model to Hub")
92
+ result, push_message = push_to_hub(
93
+ local_path=output_dir,
94
+ repo_name=f"{repo_name}-optimized",
95
+ hf_token=hf_token,
96
+ tags=["onnx", "optimum", task],
97
+ )
98
+ status_updates.append(push_message)
99
+ except Exception as e:
100
+ logger.error(f"ONNX error: {str(e)}", exc_info=True)
101
+ return (
102
+ {"status": "Error", "progress": 0.6, "current_step": "ONNX Processing Failed"},
103
+ f"ONNX processing failed: {str(e)}",
104
+ metrics
105
+ )
106
+
107
+ if quantization_type != "None" and quantized_model_path:
108
+ status.update({"progress": 0.9, "current_step": "Pushing Quantized Model"})
109
+ status_updates.append("Pushing quantized model to Hub")
110
+ result, push_message = push_to_hub(
111
+ local_path=quantized_model_path,
112
+ repo_name=f"{repo_name}-optimized",
113
+ hf_token=hf_token,
114
+ tags=["quantized", task, quantization_type],
115
+ )
116
+ status_updates.append(push_message)
117
+
118
+ status.update({"progress": 1.0, "status": "Complete", "current_step": "Completed"})
119
+ cleanup_message = resource_manager.cleanup_temp_files()
120
+ status_updates.append(cleanup_message)
121
+
122
+ return (
123
+ status,
124
+ "\n".join(status_updates),
125
+ metrics
126
+ )
127
+
128
+ except Exception as e:
129
+ logger.error(f"Error during processing: {str(e)}", exc_info=True)
130
+ return (
131
+ {"status": "Error", "progress": 0, "current_step": "Process Failed"},
132
+ f"An error occurred: {str(e)}",
133
+ metrics
134
+ )
135
+
136
+ # Gradio Interface
137
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
138
+ gr.Markdown("""
139
+ # 🤗 Model Conversion Hub
140
+ Convert and optimize your Hugging Face models with quantization and ONNX support.
141
+ """)
142
+
143
+ with gr.Row():
144
+ with gr.Column(scale=2):
145
+ model_name = gr.Textbox(label="Model Name", placeholder="e.g., bert-base-uncased")
146
+ task = gr.Dropdown(choices=list(TASK_CONFIGS.keys()), label="Task", value="text_classification")
147
+
148
+ with gr.Group():
149
+ gr.Markdown("### Quantization Settings")
150
+ quantization_type = gr.Dropdown(choices=["None", "4-bit", "8-bit", "16-bit-float"], label="Quantization Type", value="None")
151
+ test_text = gr.Textbox(label="Test Text", placeholder="Enter text for model evaluation", lines=3, visible=False)
152
+
153
+ with gr.Group():
154
+ gr.Markdown("### ONNX Settings")
155
+ enable_onnx = gr.Checkbox(label="Enable ONNX Conversion")
156
+ with gr.Group(visible=False) as onnx_group:
157
+ onnx_quantization = gr.Dropdown(choices=["None", "8-bit", "16-bit-int", "16-bit-float"], label="ONNX Quantization", value="None")
158
+
159
+ with gr.Group():
160
+ gr.Markdown("### HuggingFace Settings")
161
+ hf_token = gr.Textbox(label="HuggingFace Token (Required)", type="password")
162
+ repo_name = gr.Textbox(label="Repository Name")
163
+
164
+ with gr.Column(scale=1):
165
+ status_output = gr.JSON(label="Status", value={"status": "Ready", "progress": 0, "current_step": "Waiting"})
166
+ message_output = gr.Markdown(label="Progress Messages")
167
+
168
+ gr.Markdown("### Metrics")
169
+ with gr.Group():
170
+ metrics_output = gr.JSON(
171
+ value={
172
+ "model_sizes": {"original": 0.0, "quantized": 0.0},
173
+ "inference_times": {"original": 0.0, "quantized": 0.0},
174
+ "comparison_metrics": {}
175
+ },
176
+ show_label=True
177
+ )
178
+
179
+ memory_info = gr.JSON(label="Resource Usage")
180
+ convert_btn = gr.Button("🚀 Start Conversion", variant="primary")
181
+
182
+ with gr.Accordion("ℹ️ Help", open=False):
183
+ gr.Markdown("""
184
+ ### Quick Guide
185
+ 1. Enter your model name and HuggingFace token.
186
+ 2. Select the appropriate task.
187
+ 3. Choose optimization options.
188
+ 4. Click Start Conversion.
189
+
190
+ ### Tips
191
+ - Ensure sufficient system resources.
192
+ - Use test text to validate conversions.
193
+ """)
194
+
195
+ def update_memory_info():
196
+ resource_manager = ResourceManager()
197
+ return resource_manager.get_memory_info()
198
+
199
+ quantization_type.change(lambda x: gr.update(visible=x != "None"), inputs=[quantization_type], outputs=[test_text])
200
+ task.change(lambda x: gr.update(value=TASK_CONFIGS[x]["example_text"]), inputs=[task], outputs=[test_text])
201
+ enable_onnx.change(lambda x: gr.update(visible=x), inputs=[enable_onnx], outputs=[onnx_group])
202
+
203
+ convert_btn.click(
204
+ process_model,
205
+ inputs=[model_name, task, quantization_type, enable_onnx, onnx_quantization, hf_token, repo_name, test_text],
206
+ outputs=[status_output, message_output, metrics_output]
207
+ )
208
+ app.load(update_memory_info, outputs=[memory_info], every=30)
209
+
210
+ if __name__ == "__main__":
211
+ app.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,53 @@
1
+ [tool.poetry]
2
+ name = "autoquantnx"
3
+ version = "0.1.0"
4
+ description = "Web app to quantize Hugging Face models, convert them to ONNX in one go, and compare the results"
5
+ authors = ["kartikbhtt7 <[email protected]>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = ">=3.10,<3.12"
10
+ gradio = "^4.0.0"
11
+ transformers = "^4.31.0"
12
+ pandas = "^2.2.2"
13
+ torch = "^2.3.0"
14
+ onnx = "^1.16.2"
15
+ onnxruntime = "^1.18.1"
16
+ onnxconverter-common = ">=1.14.0"
17
+ optimum = "^1.21.3"
18
+ huggingface-hub = "^0.24.6"
19
+ sentence-transformers = "^3.0.1"
20
+ bitsandbytes = "^0.43.3"
21
+ evaluate = "^0.4.0"
22
+ faiss-gpu = "^1.7.2"
23
+ faiss-cpu = "^1.8.0.post1"
24
+ azure-cognitiveservices-speech = "^1.40.0"
25
+ gdown = "^5.2.0"
26
+ jiwer = "^3.0.4"
27
+ pydub = "^0.25.1"
28
+ librosa = "^0.10.2.post1"
29
+ soundfile = "^0.12.1"
30
+ catalogue = "^2.0.10"
31
+ langchain-core = "^0.1.40"
32
+ langchain-openai = "^0.1.0"
33
+ fast-pytorch-kmeans = "^0.2.0.1"
34
+ typing-extensions = "^4.12.2"
35
+ textwrap3 = "^0.9.2"
36
+ pynvml = "^11.5.3"
37
+ psutil = "^6.1.1"
38
+ accelerate = "^0.26.0"
39
+
40
+ [tool.poetry.dev-dependencies]
41
+ black = "^23.7.0"
42
+ flake8 = "^6.1.0"
43
+ pytest = "^7.4.3"
44
+ pytest-asyncio = "^0.21.1"
45
+ pytest-django = "^4.8.0"
46
+ pytest-cov = "^4.1.0"
47
+ pytest-testmon = "^2.1.0"
48
+ pytest-watch = "^4.2.0"
49
+ coverage = "^7.3.2"
50
+
51
+ [build-system]
52
+ requires = ["poetry-core"]
53
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,167 @@
1
+ accelerate==0.26.1
2
+ aiofiles==23.2.1
3
+ aiohappyeyeballs==2.4.4
4
+ aiohttp==3.11.11
5
+ aiosignal==1.3.2
6
+ annotated-types==0.7.0
7
+ anyio==4.7.0
8
+ async-timeout==5.0.1
9
+ attrs==24.3.0
10
+ audioread==3.0.1
11
+ azure-cognitiveservices-speech==1.41.1
12
+ beautifulsoup4==4.12.3
13
+ bitsandbytes==0.43.3
14
+ black==23.12.1
15
+ catalogue==2.0.10
16
+ certifi==2024.12.14
17
+ cffi==1.17.1
18
+ charset-normalizer==3.4.1
19
+ click==8.1.8
20
+ colorama==0.4.6
21
+ coloredlogs==15.0.1
22
+ contourpy==1.3.1
23
+ coverage==7.6.10
24
+ cycler==0.12.1
25
+ datasets==2.14.4
26
+ decorator==5.1.1
27
+ dill==0.3.7
28
+ distro==1.9.0
29
+ docopt==0.6.2
30
+ evaluate==0.4.3
31
+ exceptiongroup==1.2.2
32
+ faiss-cpu==1.9.0.post1
33
+ faiss-gpu==1.7.2
34
+ fast_pytorch_kmeans==0.2.2
35
+ fastapi==0.115.6
36
+ ffmpy==0.5.0
37
+ filelock==3.16.1
38
+ flake8==6.1.0
39
+ flatbuffers==24.12.23
40
+ fonttools==4.55.3
41
+ frozenlist==1.5.0
42
+ fsspec==2024.12.0
43
+ gdown==5.2.0
44
+ gradio==4.44.1
45
+ gradio_client==1.3.0
46
+ h11==0.14.0
47
+ httpcore==1.0.7
48
+ httpx==0.28.1
49
+ huggingface-hub==0.24.7
50
+ humanfriendly==10.0
51
+ idna==3.10
52
+ importlib_resources==6.4.5
53
+ iniconfig==2.0.0
54
+ Jinja2==3.1.5
55
+ jiter==0.8.2
56
+ jiwer==3.0.5
57
+ joblib==1.4.2
58
+ jsonpatch==1.33
59
+ jsonpointer==3.0.0
60
+ kiwisolver==1.4.8
61
+ langchain-core==0.1.53
62
+ langchain-openai==0.1.7
63
+ langsmith==0.1.147
64
+ lazy_loader==0.4
65
+ librosa==0.10.2.post1
66
+ llvmlite==0.43.0
67
+ markdown-it-py==3.0.0
68
+ MarkupSafe==2.1.5
69
+ matplotlib==3.10.0
70
+ mccabe==0.7.0
71
+ mdurl==0.1.2
72
+ mpmath==1.3.0
73
+ msgpack==1.1.0
74
+ multidict==6.1.0
75
+ multiprocess==0.70.15
76
+ mypy-extensions==1.0.0
77
+ networkx==3.4.2
78
+ numba==0.60.0
79
+ numpy==1.26.4
80
+ nvidia-cublas-cu12==12.1.3.1
81
+ nvidia-cuda-cupti-cu12==12.1.105
82
+ nvidia-cuda-nvrtc-cu12==12.1.105
83
+ nvidia-cuda-runtime-cu12==12.1.105
84
+ nvidia-cudnn-cu12==9.1.0.70
85
+ nvidia-cufft-cu12==11.0.2.54
86
+ nvidia-curand-cu12==10.3.2.106
87
+ nvidia-cusolver-cu12==11.4.5.107
88
+ nvidia-cusparse-cu12==12.1.0.106
89
+ nvidia-nccl-cu12==2.20.5
90
+ nvidia-nvjitlink-cu12==12.6.85
91
+ nvidia-nvtx-cu12==12.1.105
92
+ onnx==1.17.0
93
+ onnxconverter-common==1.14.0
94
+ onnxruntime==1.20.1
95
+ openai==1.58.1
96
+ optimum==1.23.3
97
+ orjson==3.10.13
98
+ packaging==23.2
99
+ pandas==2.2.3
100
+ pathspec==0.12.1
101
+ pillow==10.4.0
102
+ platformdirs==4.3.6
103
+ pluggy==1.5.0
104
+ pooch==1.8.2
105
+ propcache==0.2.1
106
+ protobuf==3.20.2
107
+ psutil==6.1.1
108
+ pyarrow==18.1.0
109
+ pycodestyle==2.11.1
110
+ pycparser==2.22
111
+ pydantic==2.10.4
112
+ pydantic_core==2.27.2
113
+ pydub==0.25.1
114
+ pyflakes==3.1.0
115
+ Pygments==2.19.1
116
+ pynvml==11.5.3
117
+ pyparsing==3.2.1
118
+ PySocks==1.7.1
119
+ pytest==7.4.4
120
+ pytest-asyncio==0.21.2
121
+ pytest-cov==4.1.0
122
+ pytest-django==4.9.0
123
+ pytest-testmon==2.1.3
124
+ pytest-watch==4.2.0
125
+ python-dateutil==2.9.0.post0
126
+ python-multipart==0.0.20
127
+ pytz==2024.2
128
+ PyYAML==6.0.2
129
+ RapidFuzz==3.11.0
130
+ regex==2024.11.6
131
+ requests==2.32.3
132
+ requests-toolbelt==1.0.0
133
+ rich==13.9.4
134
+ ruff==0.9.6
135
+ safetensors==0.4.5
136
+ scikit-learn==1.6.0
137
+ scipy==1.14.1
138
+ semantic-version==2.10.0
139
+ sentence-transformers==3.3.1
140
+ shellingham==1.5.4
141
+ six==1.17.0
142
+ sniffio==1.3.1
143
+ soundfile==0.12.1
144
+ soupsieve==2.6
145
+ soxr==0.5.0.post1
146
+ starlette==0.41.3
147
+ sympy==1.13.3
148
+ tenacity==8.5.0
149
+ textwrap3==0.9.2
150
+ threadpoolctl==3.5.0
151
+ tiktoken==0.8.0
152
+ tokenizers==0.21.0
153
+ tomli==2.2.1
154
+ tomlkit==0.12.0
155
+ torch==2.4.1
156
+ tqdm==4.67.1
157
+ transformers==4.47.1
158
+ triton==3.0.0
159
+ typer==0.15.1
160
+ typing_extensions==4.12.2
161
+ tzdata==2024.2
162
+ urllib3==2.3.0
163
+ uvicorn==0.34.0
164
+ watchdog==6.0.0
165
+ websockets==11.0.3
166
+ xxhash==3.5.0
167
+ yarl==1.18.3
src/handlers/__init__.py ADDED
@@ -0,0 +1,84 @@
1
+ from .base_handler import ModelHandler
2
+ from .nlp_models.sequence_classification_handler import SequenceClassificationHandler
3
+ from .nlp_models.question_answering_handler import QuestionAnsweringHandler
4
+ from .nlp_models.token_classification_handler import TokenClassificationHandler
5
+ from .nlp_models.causal_lm_handler import CausalLMHandler
6
+ from .nlp_models.embedding_model_handler import EmbeddingModelHandler
7
+ from .audio_models.whisper_handler import WhisperHandler
8
+ from .nlp_models.masked_lm_handler import MaskedLMHandler
9
+ from .nlp_models.seq2seq_lm_handler import Seq2SeqLMHandler
10
+ from .nlp_models.multiple_choice_handler import MultipleChoiceHandler
11
+ from .img_models.image_classification_handler import ImageClassificationHandler
12
+
13
+ from transformers import (
14
+ AutoModel,
15
+ AutoModelForTokenClassification,
16
+ AutoModelForSequenceClassification,
17
+ AutoModelForQuestionAnswering,
18
+ AutoModelForCausalLM,
19
+ AutoModelForMaskedLM,
20
+ AutoModelForSeq2SeqLM,
21
+ AutoModelForMultipleChoice,
22
+ )
23
+
24
+ TASK_CONFIGS = {
25
+ "embedding": {
26
+ "model_class": AutoModel,
27
+ "handler_class": EmbeddingModelHandler,
28
+ "example_text": "Hey, I am feeling way to good to be true.",
29
+ },
30
+ "ner": {
31
+ "model_class": AutoModelForTokenClassification,
32
+ "handler_class": TokenClassificationHandler,
33
+ "example_text": "John works at Google in New York as a software engineer.",
34
+ },
35
+ "text_classification": {
36
+ "model_class": AutoModelForSequenceClassification,
37
+ "handler_class": SequenceClassificationHandler,
38
+ "example_text": "This movie was great and I loved it.",
39
+ },
40
+ "question_answering": {
41
+ "model_class": AutoModelForQuestionAnswering,
42
+ "handler_class": QuestionAnsweringHandler,
43
+ "example_text": "The pyramids were built in ancient Egypt. QUES: Where were the pyramids built?",
44
+ },
45
+ "causal_lm": {
46
+ "model_class": AutoModelForCausalLM,
47
+ "handler_class": CausalLMHandler,
48
+ "example_text": "Once upon a time, there was ",
49
+ },
50
+ "mask_lm": {
51
+ "model_class": AutoModelForMaskedLM,
52
+ "handler_class": MaskedLMHandler,
53
+ "example_text": "The quick brown [MASK] jumps over the lazy dog.",
54
+ },
55
+ "seq2seq_lm": {
56
+ "model_class": AutoModelForSeq2SeqLM,
57
+ "handler_class": Seq2SeqLMHandler,
58
+ "example_text": "Translate English to French: The house is wonderful.",
59
+ },
60
+ "multiple_choice": {
61
+ "model_class": AutoModelForMultipleChoice,
62
+ "handler_class": MultipleChoiceHandler,
63
+ "example_text": "What is the capital of France? (A) Paris (B) London (C) Berlin (D) Rome",
64
+ },
65
+ "whisper_finetuning": {
66
+ "model_class": None, # Not implemented
67
+ "handler_class": WhisperHandler,
68
+ "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
69
+ },
70
+ "image_classification": {
71
+ "model_class": None, # Not implemented
72
+ "handler_class": ImageClassificationHandler,
73
+ "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
74
+ },
75
+ }
76
+
77
+ def get_model_handler(task: str, model_name: str, quantization_type: str, test_text: str):
78
+ task_config = TASK_CONFIGS.get(task)
79
+ if not task_config:
80
+ raise ValueError(f"No configuration found for task: {task}")
81
+
82
+ handler_class = task_config["handler_class"]
83
+ model_class = task_config["model_class"]
84
+ return handler_class(model_name, model_class, quantization_type, test_text)
src/handlers/audio_models/whisper_handler.py ADDED
@@ -0,0 +1,14 @@
1
+ from ..base_handler import ModelHandler
2
+
3
+ class WhisperHandler(ModelHandler):
4
+ def __init__(self, model_name, model_class, quantization_type, test_text):
5
+ super().__init__(model_name, model_class, quantization_type, test_text)
6
+
7
+ def run_inference(self, model, text):
8
+ raise NotImplementedError("STT is not implemented.")
9
+
10
+ def decode_output(self, outputs):
11
+ raise NotImplementedError("STT is not implemented.")
12
+
13
+ def compare_outputs(self, original_outputs, quantized_outputs):
14
+ raise NotImplementedError("STT is not implemented.")
src/handlers/base_handler.py ADDED
@@ -0,0 +1,148 @@
1
+ from src.optimizations.quantize import ModelQuantizer  # absolute import so the handlers resolve when app.py is run from the repo root
2
+ import torch
3
+ import logging
4
+ import numpy as np
5
+ from dataclasses import dataclass
6
+ from typing import Dict, Any, Optional
7
+ import json
8
+ logger = logging.getLogger(__name__)
9
+
10
+ @dataclass
11
+ class ModelMetrics:
12
+ model_sizes: Dict[str, float]
13
+ inference_times: Dict[str, float]
14
+ comparison_metrics: Dict[str, Any]
15
+
16
+ class ModelHandler:
17
+ """Base class for handling different types of models"""
18
+
19
+ def __init__(self, model_name, model_class, quantization_type, test_text=None):
20
+ self.model_name = model_name
21
+ self.model_class = model_class
22
+ self.quantization_type = quantization_type
23
+ self.test_text = test_text
24
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+ # Load models
27
+ self.original_model = self._load_original_model()
28
+ self.quantized_model = self._load_quantized_model()
29
+ self.metrics: Optional[ModelMetrics] = None
30
+
31
+ def _load_original_model(self):
32
+ """Load the original model"""
33
+ model = self.model_class.from_pretrained(self.model_name)
34
+ return model.to(self.device)
35
+
36
+ def _load_quantized_model(self):
37
+ """Load the quantized model using ModelQuantizer"""
38
+ model = ModelQuantizer.quantize_model(
39
+ self.model_class,
40
+ self.model_name,
41
+ self.quantization_type
42
+ )
43
+ if self.quantization_type not in ["4-bit", "8-bit"]:
44
+ model = model.to(self.device)
45
+ return model
46
+
47
+ @staticmethod
48
+ def _convert_to_serializable(obj):
49
+ """Serialization for metrics"""
50
+ if isinstance(obj, np.generic):
51
+ return obj.item()
52
+ if isinstance(obj, (np.float32, np.float64)):
53
+ return float(obj)
54
+ if isinstance(obj, (np.int32, np.int64)):
55
+ return int(obj)
56
+ if isinstance(obj, np.ndarray):
57
+ return obj.tolist()
58
+ if isinstance(obj, torch.Tensor):
59
+ return obj.cpu().numpy().tolist()
60
+ if isinstance(obj, dict):
61
+ return {k: ModelHandler._convert_to_serializable(v) for k, v in obj.items()}
62
+ if isinstance(obj, list):
63
+ return [ModelHandler._convert_to_serializable(v) for v in obj]
64
+ return obj
65
+
66
+ def _format_metric_value(self, value):
67
+ """Format metric value based on its type"""
68
+ if isinstance(value, (float, np.float32, np.float64)):
69
+ return f"{value:.8f}"
70
+ elif isinstance(value, (int, np.int32, np.int64)):
71
+ return str(value)
72
+ elif isinstance(value, list):
73
+ return "\n" + "\n".join([f" - {item}" for item in value])
74
+ elif isinstance(value, dict):
75
+ return "\n" + "\n".join([f" {k}: {v}" for k, v in value.items()])
76
+ else:
77
+ return str(value)
78
+
79
+ def run_inference(self, model, text):
80
+ """Run model inference - to be implemented by subclasses"""
81
+ raise NotImplementedError
82
+
83
+ def decode_output(self, outputs):
84
+ """Decode model outputs - to be implemented by subclasses"""
85
+ raise NotImplementedError
86
+
87
+ def compare(self):
88
+ """Compare original and quantized models"""
89
+ try:
90
+ if self.test_text is None:
91
+ logger.warning("No test text provided. Skipping inference testing.")
92
+ return self.quantized_model
93
+
94
+ # Run inference
95
+ original_outputs, original_time = self.run_inference(self.original_model, self.test_text)
96
+ quantized_outputs, quantized_time = self.run_inference(self.quantized_model, self.test_text)
97
+
98
+ original_size = ModelQuantizer.get_model_size(self.original_model)
99
+ quantized_size = ModelQuantizer.get_model_size(self.quantized_model)
100
+
101
+ logger.info(f"Original Model Size: {original_size:.2f} MB")
102
+ logger.info(f"Quantized Model Size: {quantized_size:.2f} MB")
103
+ logger.info(f"Original Inference Time: {original_time:.4f} seconds")
104
+ logger.info(f"Quantized Inference Time: {quantized_time:.4f} seconds")
105
+
106
+ # Compare outputs
107
+ comparison_metrics = self.compare_outputs(original_outputs, quantized_outputs) or {}
108
+
109
+ for key, value in comparison_metrics.items():
110
+ comparison_metrics[key] = self._convert_to_serializable(value)
111
+
112
+ self.metrics = {
113
+ "model_sizes": {
114
+ "original": float(original_size),
115
+ "quantized": float(quantized_size)
116
+ },
117
+ "inference_times": {
118
+ "original": float(original_time),
119
+ "quantized": float(quantized_time)
120
+ },
121
+ "comparison_metrics": comparison_metrics
122
+ }
123
+
124
+ return self.quantized_model
125
+ except Exception as e:
126
+ logger.error(f"Quantization and comparison failed: {str(e)}")
127
+ raise e
128
+
129
+ def get_metrics(self) -> Dict[str, Any]:
130
+ """Return the metrics dictionary"""
131
+ if self.metrics is None:
132
+ return {
133
+ "model_sizes": {"original": 0.0, "quantized": 0.0},
134
+ "inference_times": {"original": 0.0, "quantized": 0.0},
135
+ "comparison_metrics": {}
136
+ }
137
+ serializable_metrics = self._convert_to_serializable(self.metrics)
138
+ try:
139
+ json.dumps(serializable_metrics)
140
+ return serializable_metrics
141
+ except (TypeError, ValueError) as e:
142
+ logger.error(f"Error serializing metrics: {str(e)}")
143
+ return {
144
+ "model_sizes": {"original": 0.0, "quantized": 0.0},
145
+ "inference_times": {"original": 0.0, "quantized": 0.0},
146
+ "comparison_metrics": {}
147
+ }
148
+
src/handlers/img_models/image_classification_handler.py ADDED
@@ -0,0 +1,14 @@
1
+ from ..base_handler import ModelHandler
2
+
3
+ class ImageClassificationHandler(ModelHandler):
4
+ def __init__(self, model_name, model_class, quantization_type, test_text):
5
+ super().__init__(model_name, model_class, quantization_type, test_text)
6
+
7
+ def run_inference(self, model, text):
8
+ raise NotImplementedError("Image classification is not implemented.")
9
+
10
+ def decode_output(self, outputs):
11
+ raise NotImplementedError("Image classification is not implemented.")
12
+
13
+ def compare_outputs(self, original_outputs, quantized_outputs):
14
+ raise NotImplementedError("Image classification is not implemented.")
src/handlers/nlp_models/causal_lm_handler.py ADDED
@@ -0,0 +1,46 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+ from scipy.stats import spearmanr
6
+ import numpy as np
7
+
8
+ class CausalLMHandler(ModelHandler):
9
+ def __init__(self, model_name, model_class, quantization_type, test_text):
10
+ super().__init__(model_name, model_class, quantization_type, test_text)
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+
13
+ def run_inference(self, model, text):
14
+ inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
15
+ start_time = time.time()
16
+ with torch.no_grad():
17
+ outputs = model.generate(**inputs, max_length=50)
18
+ end_time = time.time()
19
+ return outputs, end_time - start_time
20
+
21
+ def decode_output(self, outputs):
22
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
23
+
24
+ def compare_outputs(self, original_outputs, quantized_outputs):
25
+ """Compare outputs for causal language models"""
26
+ if original_outputs is None or quantized_outputs is None:
27
+ return None
28
+
29
+ original_tokens = original_outputs[0].cpu().numpy()
30
+ quantized_tokens = quantized_outputs[0].cpu().numpy()
31
+
32
+ metrics = {
33
+ 'sequence_similarity': np.mean(original_tokens == quantized_tokens),
34
+ 'sequence_length_diff': abs(len(original_tokens) - len(quantized_tokens)),
35
+ # minlength pads both count vectors to a common size so spearmanr gets equal-length inputs
36
+ 'vocab_distribution_correlation': spearmanr(
37
+ np.bincount(original_tokens, minlength=int(max(original_tokens.max(), quantized_tokens.max())) + 1),
38
+ np.bincount(quantized_tokens, minlength=int(max(original_tokens.max(), quantized_tokens.max())) + 1)
+ )[0] if len(original_tokens) == len(quantized_tokens) else 0.0
39
+ }
40
+
41
+ original_text = self.decode_output(original_outputs)
42
+ quantized_text = self.decode_output(quantized_outputs)
43
+ metrics['decoded_text_match'] = float(original_text == quantized_text)
44
+ metrics['original_model_text'] = original_text
45
+ metrics['quantized_model_text'] = quantized_text
46
+ return metrics
src/handlers/nlp_models/embedding_model_handler.py ADDED
@@ -0,0 +1,49 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+ import numpy as np
6
+ from scipy.stats import spearmanr
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+
9
+ class EmbeddingModelHandler(ModelHandler):
10
+ def __init__(self, model_name, model_class, quantization_type, test_text):
11
+ super().__init__(model_name, model_class, quantization_type, test_text)
12
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+
14
+ def run_inference(self, model, text):
15
+ inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
16
+ start_time = time.time()
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+ end_time = time.time()
20
+ return outputs, end_time - start_time
21
+
22
+ def decode_output(self, outputs):
23
+ return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
24
+
25
+ def compare_outputs(self, original_outputs, quantized_outputs):
26
+ """Compare outputs for embedding models"""
27
+ if original_outputs is None or quantized_outputs is None:
28
+ return None
29
+
30
+ original_embeds = original_outputs.last_hidden_state.cpu().numpy()
31
+ quantized_embeds = quantized_outputs.last_hidden_state.cpu().numpy()
32
+
33
+ metrics = {
34
+ 'mse': ((original_embeds - quantized_embeds) ** 2).mean(),
35
+ 'cosine_similarity': cosine_similarity(
36
+ original_embeds.reshape(1, -1),
37
+ quantized_embeds.reshape(1, -1)
38
+ )[0][0],
39
+ 'correlation': spearmanr(
40
+ original_embeds.flatten(),
41
+ quantized_embeds.flatten()
42
+ )[0],
43
+ 'norm_difference': np.abs(
44
+ np.linalg.norm(original_embeds) -
45
+ np.linalg.norm(quantized_embeds)
46
+ )
47
+ }
48
+
49
+ return metrics
src/handlers/nlp_models/masked_lm_handler.py ADDED
@@ -0,0 +1,39 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+ import numpy as np
6
+
7
+ class MaskedLMHandler(ModelHandler):
8
+ def __init__(self, model_name, model_class, quantization_type, test_text):
9
+ super().__init__(model_name, model_class, quantization_type, test_text)
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def run_inference(self, model, text):
13
+ inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
14
+ start_time = time.time()
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ end_time = time.time()
18
+ self.last_inputs = inputs  # kept for decode_output; the base class expects (outputs, time)
+ return outputs, end_time - start_time
19
+
20
+ def decode_output(self, outputs, inputs=None):
+ inputs = inputs if inputs is not None else self.last_inputs
21
+ logits = outputs.logits
22
+ masked_index = torch.where(inputs['input_ids'] == self.tokenizer.mask_token_id)[1]
23
+ predicted_token_id = logits[0, masked_index].argmax(axis=-1)
24
+ return self.tokenizer.decode(predicted_token_id)
25
+
26
+ def compare_outputs(self, original_outputs, quantized_outputs):
27
+ if original_outputs is None or quantized_outputs is None:
28
+ return None
29
+
30
+ original_logits = original_outputs.logits.detach().cpu().numpy()
31
+ quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
32
+
33
+ metrics = {
34
+ 'mse': ((original_logits - quantized_logits) ** 2).mean(),
35
+ 'top_1_accuracy': np.mean(
36
+ np.argmax(original_logits, axis=-1) == np.argmax(quantized_logits, axis=-1)
37
+ ),
38
+ }
39
+ return metrics
src/handlers/nlp_models/multiple_choice_handler.py ADDED
@@ -0,0 +1,39 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+ import numpy as np
6
+
7
+ class MultipleChoiceHandler(ModelHandler):
8
+ def __init__(self, model_name, model_class, quantization_type, test_text):
9
+ super().__init__(model_name, model_class, quantization_type, test_text)
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def run_inference(self, model, text):
13
+ choices = [text.split(f"({chr(65 + i)})")[1].strip() for i in range(4)]
14
+ inputs = self.tokenizer(choices, return_tensors='pt', padding=True).to(self.device)
15
+ start_time = time.time()
16
+ with torch.no_grad():
17
+ outputs = model(**inputs)
18
+ end_time = time.time()
19
+ return outputs, end_time - start_time
20
+
21
+ def decode_output(self, outputs):
22
+ logits = outputs.logits
23
+ predicted_choice = chr(65 + logits.argmax().item())
24
+ return f"Predicted choice: {predicted_choice}"
25
+
26
+ def compare_outputs(self, original_outputs, quantized_outputs):
27
+ if original_outputs is None or quantized_outputs is None:
28
+ return None
29
+
30
+ original_logits = original_outputs.logits.detach().cpu().numpy()
31
+ quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
32
+
33
+ metrics = {
34
+ 'mse': ((original_logits - quantized_logits) ** 2).mean(),
35
+ 'top_1_accuracy': np.mean(
36
+ np.argmax(original_logits, axis=-1) == np.argmax(quantized_logits, axis=-1)
37
+ ),
38
+ }
39
+ return metrics
src/handlers/nlp_models/question_answering_handler.py ADDED
@@ -0,0 +1,59 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+
6
+ class QuestionAnsweringHandler(ModelHandler):
7
+ def __init__(self, model_name, model_class, quantization_type, test_text):
8
+ super().__init__(model_name, model_class, quantization_type, test_text)
9
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+
11
+ def run_inference(self, model, text):
12
+ parts = text.split('QUES')
13
+ context = parts[0].strip()
14
+ question = parts[1].strip()
15
+ inputs = self.tokenizer(question, context, return_tensors='pt', truncation=True, padding=True).to(self.device)
16
+ start_time = time.time()
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+ end_time = time.time()
20
+ return outputs, end_time - start_time
21
+
22
+ def decode_output(self, outputs):
23
+ start_logits = outputs.start_logits
24
+ end_logits = outputs.end_logits
25
+ answer_start = torch.argmax(start_logits)
26
+ answer_end = torch.argmax(end_logits) + 1
27
+ input_ids = self.tokenizer.encode(self.test_text)
28
+ answer = self.tokenizer.decode(input_ids[answer_start:answer_end])
29
+ return f"Answer: {answer}"
30
+
31
+ def compare_outputs(self, original_outputs, quantized_outputs):
32
+ """Compare outputs for question answering models"""
33
+ if original_outputs is None or quantized_outputs is None:
34
+ return None
35
+
36
+ orig_start = original_outputs.start_logits.cpu().numpy()
37
+ orig_end = original_outputs.end_logits.cpu().numpy()
38
+ quant_start = quantized_outputs.start_logits.cpu().numpy()
39
+ quant_end = quantized_outputs.end_logits.cpu().numpy()
40
+
41
+ orig_start_pos = orig_start.argmax()
42
+ orig_end_pos = orig_end.argmax()
43
+ quant_start_pos = quant_start.argmax()
44
+ quant_end_pos = quant_end.argmax()
45
+
46
+ input_ids = self.tokenizer.encode(self.test_text)
47
+ original_answer = self.tokenizer.decode(input_ids[orig_start_pos:orig_end_pos + 1])
48
+ quantized_answer = self.tokenizer.decode(input_ids[quant_start_pos:quant_end_pos + 1])
49
+
50
+ metrics = {
51
+ 'original_answer': original_answer,
52
+ 'quantized_answer': quantized_answer,
53
+ 'start_position_match': float(orig_start_pos == quant_start_pos),
54
+ 'end_position_match': float(orig_end_pos == quant_end_pos),
55
+ 'start_logits_mse': ((orig_start - quant_start) ** 2).mean(),
56
+ 'end_logits_mse': ((orig_end - quant_end) ** 2).mean(),
57
+ }
58
+
59
+ return metrics
src/handlers/nlp_models/seq2seq_lm_handler.py ADDED
@@ -0,0 +1,34 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+ import numpy as np
6
+
7
+ class Seq2SeqLMHandler(ModelHandler):
8
+ def __init__(self, model_name, model_class, quantization_type, test_text):
9
+ super().__init__(model_name, model_class, quantization_type, test_text)
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def run_inference(self, model, text):
13
+ inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
14
+ start_time = time.time()
15
+ with torch.no_grad():
16
+ outputs = model.generate(**inputs, max_length=50)
17
+ end_time = time.time()
18
+ return outputs, end_time - start_time
19
+
20
+ def decode_output(self, outputs):
21
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
22
+
23
+ def compare_outputs(self, original_outputs, quantized_outputs):
24
+ if original_outputs is None or quantized_outputs is None:
25
+ return None
26
+
27
+ original_tokens = original_outputs[0].cpu().numpy()
28
+ quantized_tokens = quantized_outputs[0].cpu().numpy()
29
+
30
+ metrics = {
31
+ 'sequence_similarity': np.mean(original_tokens == quantized_tokens),
32
+ 'sequence_length_diff': abs(len(original_tokens) - len(quantized_tokens)),
33
+ }
34
+ return metrics
src/handlers/nlp_models/sequence_classification_handler.py ADDED
@@ -0,0 +1,49 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+ import numpy as np
6
+
7
+ class SequenceClassificationHandler(ModelHandler):
8
+ def __init__(self, model_name, model_class, quantization_type, test_text):
9
+ super().__init__(model_name, model_class, quantization_type, test_text)
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def run_inference(self, model, text):
13
+ inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
14
+ start_time = time.time()
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ end_time = time.time()
18
+ return outputs, end_time - start_time
19
+
20
+ def decode_output(self, outputs):
21
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
22
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
23
+ return f"Predicted class: {predicted_class}"
24
+
25
+ def compare_outputs(self, original_outputs, quantized_outputs):
26
+ """Compare outputs for sequence classification models"""
27
+ if original_outputs is None or quantized_outputs is None:
28
+ return None
29
+
30
+ orig_logits = original_outputs.logits.cpu().numpy()
31
+ quant_logits = quantized_outputs.logits.cpu().numpy()
32
+
33
+ orig_probs = torch.nn.functional.softmax(torch.tensor(orig_logits), dim=-1).numpy()
34
+ quant_probs = torch.nn.functional.softmax(torch.tensor(quant_logits), dim=-1).numpy()
35
+
36
+ orig_pred = orig_probs.argmax(axis=-1)
37
+ quant_pred = quant_probs.argmax(axis=-1)
38
+
39
+ metrics = {
40
+ 'class_match': float(orig_pred == quant_pred),
41
+ 'logits_mse': ((orig_logits - quant_logits) ** 2).mean(),
42
+ 'probability_mse': ((orig_probs - quant_probs) ** 2).mean(),
43
+ 'max_probability_diff': abs(orig_probs.max() - quant_probs.max()),
44
+ 'kl_divergence': float(
45
+ (orig_probs * (np.log(orig_probs + 1e-10) - np.log(quant_probs + 1e-10))).sum()
46
+ )
47
+ }
48
+
49
+ return metrics
src/handlers/nlp_models/token_classification_handler.py ADDED
@@ -0,0 +1,57 @@
1
+ from ..base_handler import ModelHandler
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import time
5
+
6
+ class TokenClassificationHandler(ModelHandler):
7
+ def __init__(self, model_name, model_class, quantization_type, test_text):
8
+ super().__init__(model_name, model_class, quantization_type, test_text)
9
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+
11
+ def run_inference(self, model, text):
12
+ inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
13
+ start_time = time.time()
14
+ with torch.no_grad():
15
+ outputs = model(**inputs)
16
+ end_time = time.time()
17
+ return outputs, end_time - start_time
18
+
19
+ def decode_output(self, model, outputs):
20
+ tokens = self.tokenizer.convert_ids_to_tokens(outputs['input_ids'][0])
21
+ labels = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
22
+ decoded_labels = [model.config.id2label[label] for label in labels]
23
+ return dict(zip(tokens, decoded_labels))
24
+
25
+ def compare_outputs(self, original_outputs, quantized_outputs):
26
+ """Compare outputs for token classification models"""
27
+ if original_outputs is None or quantized_outputs is None:
28
+ return None
29
+
30
+ orig_logits = original_outputs.logits.cpu().numpy()
31
+ quant_logits = quantized_outputs.logits.cpu().numpy()
32
+
33
+ orig_preds = orig_logits.argmax(axis=-1)
34
+ quant_preds = quant_logits.argmax(axis=-1)
35
+
36
+ input_tokens = self.tokenizer.convert_ids_to_tokens(
37
+ self.tokenizer(self.test_text, return_tensors='pt')['input_ids'][0]
38
+ )
39
+
40
+ orig_labels = [self.original_model.config.id2label[p] for p in orig_preds[0]]
41
+ quant_labels = [self.quantized_model.config.id2label[p] for p in quant_preds[0]]
42
+
43
+ original_results = list(zip(input_tokens, orig_labels))
44
+ quantized_results = list(zip(input_tokens, quant_labels))
45
+
46
+ token_matches = sum(o_label == q_label for o_label, q_label in zip(orig_labels, quant_labels))
47
+ total_tokens = len(orig_labels)
48
+
49
+ metrics = {
50
+ 'original_predictions': original_results,
51
+ 'quantized_predictions': quantized_results,
52
+ 'token_level_accuracy': float(token_matches) / total_tokens if total_tokens > 0 else 0.0,
53
+ 'sequence_exact_match': float((orig_preds == quant_preds).all()),
54
+ 'logits_mse': ((orig_logits - quant_logits) ** 2).mean(),
55
+ }
56
+
57
+ return metrics
src/optimizations/onnx_conversion.py ADDED
@@ -0,0 +1,64 @@
1
+ import os
2
+ from transformers import AutoTokenizer, WhisperProcessor, AutoFeatureExtractor
3
+ from optimum.onnxruntime import (
4
+ ORTModelForQuestionAnswering,
5
+ ORTModelForCausalLM,
6
+ ORTModelForSequenceClassification,
7
+ ORTModelForTokenClassification,
8
+ ORTModelForSpeechSeq2Seq,
9
+ ORTModelForFeatureExtraction,
10
+ ORTModelForMaskedLM,
11
+ ORTModelForSeq2SeqLM,
12
+ ORTModelForMultipleChoice,
13
+ ORTModelForImageClassification,
14
+ )
15
+ import logging
16
+
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ TASK_MAPPING = {
21
+ # NLP models
22
+ "ner": (ORTModelForTokenClassification, AutoTokenizer),
23
+ "text_classification": (ORTModelForSequenceClassification, AutoTokenizer),
24
+ "question_answering": (ORTModelForQuestionAnswering, AutoTokenizer),
25
+ "causal_lm": (ORTModelForCausalLM, AutoTokenizer),
26
+ "mask_lm": (ORTModelForMaskedLM, AutoTokenizer),
27
+ "seq2seq_lm": (ORTModelForSeq2SeqLM, AutoTokenizer),
28
+ "multiple_choice": (ORTModelForMultipleChoice, AutoTokenizer),
29
+ # Audio models
30
+ "whisper_finetuning": (ORTModelForSpeechSeq2Seq, WhisperProcessor),
31
+ # Vision models
32
+ "image_classification": (ORTModelForImageClassification, AutoFeatureExtractor),
33
+ }
34
+
35
+ def convert_to_onnx(model_name, task, output_dir):
36
+ """
37
+ Convert model to ONNX format for the specified task.
38
+ """
39
+ logger.info(f"Converting model: {model_name} for task: {task}")
40
+
41
+ os.makedirs(output_dir, exist_ok=True)
42
+
43
+ if task not in TASK_MAPPING:
44
+ logger.error(f"Task {task} is not supported for ONNX conversion in this script.")
45
+ return None
46
+
47
+ ORTModelClass, ProcessorClass = TASK_MAPPING[task]
48
+
49
+ try:
50
+ # every task in TASK_MAPPING (including "embedding") is exported through its ORTModel class
51
+ ort_model = ORTModelClass.from_pretrained(model_name, export=True)
52
+ ort_model.save_pretrained(output_dir)
56
+
57
+ processor = ProcessorClass.from_pretrained(model_name)
58
+ processor.save_pretrained(output_dir)
59
+
60
+ logger.info(f"Conversion complete. Model saved to: {output_dir}")
61
+ return output_dir
62
+ except Exception as e:
63
+ logger.error(f"Conversion failed: {str(e)}")
64
+ return None
src/optimizations/quantize.py ADDED
@@ -0,0 +1,109 @@
1
+ import os
2
+ import torch
3
+ import onnx
4
+ import logging
5
+ from scipy.stats import spearmanr
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from transformers import BitsAndBytesConfig
8
+ from onnxconverter_common import float16
9
+ from onnxruntime.quantization import quantize_dynamic, QuantType
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class ModelQuantizer:
15
+ """Handles model quantization and comparison operations"""
16
+
17
+ @staticmethod
18
+ def quantize_model(model_class, model_name, quantization_type):
19
+ """Quantizes a model based on specified quantization type"""
20
+ try:
21
+ if quantization_type == "4-bit":
22
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
23
+ model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
24
+ elif quantization_type == "8-bit":
25
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
26
+ model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
27
+ elif quantization_type == "16-bit-float":
28
+ model = model_class.from_pretrained(model_name)
29
+ model = model.to(torch.float16)
30
+ else:
31
+ raise ValueError(f"Unsupported quantization type: {quantization_type}")
32
+
33
+ return model
34
+ except Exception as e:
35
+ logger.error(f"Quantization failed: {str(e)}")
36
+ raise
37
+
38
+ @staticmethod
39
+ def get_model_size(model):
40
+ """Calculate model size in MB"""
41
+ try:
42
+ torch.save(model.state_dict(), "temp.pth")
43
+ size = os.path.getsize("temp.pth") / (1024 * 1024)
44
+ os.remove("temp.pth")
45
+ return size
46
+ except Exception as e:
47
+ logger.error(f"Failed to get model size: {str(e)}")
48
+ raise
49
+
50
+ @staticmethod
51
+ def compare_model_outputs(original_outputs, quantized_outputs):
52
+ """Compare outputs between original and quantized models"""
53
+ try:
54
+ if original_outputs is None or quantized_outputs is None:
55
+ return None
56
+
57
+ if hasattr(original_outputs, 'logits') and hasattr(quantized_outputs, 'logits'):
58
+ original_logits = original_outputs.logits.detach().cpu().numpy()
59
+ quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
60
+
61
+ metrics = {
62
+ 'mse': ((original_logits - quantized_logits) ** 2).mean(),
63
+ 'spearman_corr': spearmanr(original_logits.flatten(), quantized_logits.flatten())[0],
64
+ 'cosine_sim': cosine_similarity(original_logits.reshape(1, -1), quantized_logits.reshape(1, -1))[0][0]
65
+ }
66
+ return metrics
67
+ return None
68
+ except Exception as e:
69
+ logger.error(f"Output comparison failed: {str(e)}")
70
+ raise
71
+
72
+ def quantize_onnx_model(model_dir, quantization_type):
73
+ """
74
+ Quantize ONNX model in the specified directory.
75
+ """
76
+ logger.info(f"Quantizing ONNX model in: {model_dir}")
77
+ for filename in os.listdir(model_dir):
78
+ if filename.endswith('.onnx'):
79
+ input_model_path = os.path.join(model_dir, filename)
80
+ output_model_path = os.path.join(model_dir, f"quantized_{filename}")
81
+
82
+ try:
83
+ model = onnx.load(input_model_path)
84
+
85
+ if quantization_type == "16-bit-float":
86
+ model_fp16 = float16.convert_float_to_float16(model)
87
+ onnx.save(model_fp16, output_model_path)
88
+ elif quantization_type in ["8-bit", "16-bit-int"]:
89
+ quant_type_mapping = {
90
+ "8-bit": QuantType.QInt8,
91
+ "16-bit-int": QuantType.QInt16,
92
+ }
93
+ quantize_dynamic(
94
+ model_input=input_model_path,
95
+ model_output=output_model_path,
96
+ weight_type=quant_type_mapping[quantization_type]
97
+ )
98
+ else:
99
+ logger.error(f"Unsupported quantization type: {quantization_type}")
100
+ continue
101
+
102
+ os.remove(input_model_path)
103
+ os.rename(output_model_path, input_model_path)
104
+
105
+ logger.info(f"Quantized ONNX model saved to: {input_model_path}")
106
+ except Exception as e:
107
+ logger.error(f"Error during ONNX quantization: {str(e)}")
108
+ if os.path.exists(output_model_path):
109
+ os.remove(output_model_path)
src/utilities/push_to_hub.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional, Dict, Tuple
5
+ from huggingface_hub import HfApi, create_repo
6
+
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ def push_to_hub(local_path: str, repo_name: str, hf_token: str, commit_message: Optional[str] = None, tags: Optional[list] = None) -> Tuple[Optional[str], str]:
11
+ """
12
+ Pushes a folder containing model files to the HuggingFace Hub.
13
+
14
+ Args:
15
+ local_path (str): Local directory containing the model files to upload.
16
+ repo_name (str): The repository name (not the full username/repo_name).
17
+ hf_token (str): HuggingFace authentication token.
18
+ commit_message (str, optional): Commit message for the upload.
19
+ tags (list, optional): Tags to include in the model card.
20
+
21
+ Returns:
22
+ Tuple[Optional[str], str]: (repository_name, status_message)
23
+ """
24
+ try:
25
+ api = HfApi(token=hf_token)
26
+
27
+ # Validate token
28
+ try:
29
+ user_info = api.whoami()
30
+ username = user_info["name"]
31
+ except Exception as e:
32
+ return None, f"❌ Authentication failed: Invalid token or network error ({str(e)})"
33
+
34
+ # Full repository name with the username
35
+ full_repo_name = f"{username}/{repo_name}"
36
+
37
+ # Create the repo
38
+ try:
39
+ create_repo(full_repo_name, token=hf_token, exist_ok=True)
40
+ logger.info(f"Repository created/verified: {full_repo_name}")
41
+ except Exception as e:
42
+ return None, f"❌ Repository creation failed: {str(e)}"
43
+
44
+ # Create model card
45
+ try:
46
+ tags_list = tags or []
47
+ tags_section = "\n".join(f"- {tag}" for tag in tags_list)
48
+ model_card = f"""---
49
+ tags:
50
+ {tags_section}
51
+ library_name: optimum
52
+ ---
53
+
54
+ # Model - {repo_name}
55
+
56
+ This model has been optimized and uploaded to the HuggingFace Hub.
57
+
58
+ ## Model Details
59
+ - Optimization Tags: {', '.join(tags_list)}
60
+ """
61
+ with open(os.path.join(local_path, "README.md"), "w") as f:
62
+ f.write(model_card)
63
+ except Exception as e:
64
+ logger.warning(f"Model card creation warning: {str(e)}")
65
+
66
+ # Upload the folder
67
+ try:
68
+ api.upload_folder(
69
+ folder_path=local_path,
70
+ repo_id=full_repo_name,
71
+ repo_type="model",
72
+ commit_message=commit_message or "Upload optimized model"
73
+ )
74
+ success_msg = f"✅ Model successfully pushed to: {full_repo_name}"
75
+ logger.info(success_msg)
76
+ return full_repo_name, success_msg
77
+ except Exception as e:
78
+ error_msg = f"❌ Upload failed: {str(e)}"
79
+ logger.error(error_msg)
80
+ return None, error_msg
81
+
82
+ except Exception as e:
83
+ error_msg = f"❌ Unexpected error during push: {str(e)}"
84
+ logger.error(error_msg)
85
+ return None, error_msg
src/utilities/resources.py ADDED
@@ -0,0 +1,54 @@
1
+ import shutil
2
+ import psutil
3
+ import torch
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Optional, Dict
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class ResourceManager:
12
+ def __init__(self, temp_dir: str = "temp"):
13
+ self.temp_dir = Path(temp_dir)
14
+ self.temp_dirs = {
15
+ "onnx": self.temp_dir / "onnx_output",
16
+ "quantized": self.temp_dir / "quantized_models",
17
+ "cache": self.temp_dir / "model_cache"
18
+ }
19
+ self.setup_directories()
20
+
21
+ def setup_directories(self):
22
+ for dir_path in self.temp_dirs.values():
23
+ dir_path.mkdir(parents=True, exist_ok=True)
24
+
25
+ def cleanup_temp_files(self, specific_dir: Optional[str] = None) -> str:
26
+ try:
27
+ if specific_dir:
28
+ if specific_dir in self.temp_dirs:
29
+ shutil.rmtree(self.temp_dirs[specific_dir], ignore_errors=True)
30
+ self.temp_dirs[specific_dir].mkdir(exist_ok=True)
31
+ else:
32
+ shutil.rmtree(self.temp_dir, ignore_errors=True)
33
+ self.setup_directories()
34
+ return "✨ Cleanup successful!"
35
+ except Exception as e:
36
+ logger.error(f"Cleanup failed: {str(e)}")
37
+ return f"❌ Cleanup failed: {str(e)}"
38
+
39
+ def get_memory_info(self) -> Dict[str, float]:
40
+ vm = psutil.virtual_memory()
41
+ memory_info = {
42
+ "total_ram": vm.total / (1024 ** 3),
43
+ "available_ram": vm.available / (1024 ** 3),
44
+ "used_ram": vm.used / (1024 ** 3)
45
+ }
46
+
47
+ if torch.cuda.is_available():
48
+ device = torch.cuda.current_device()
49
+ memory_info.update({
50
+ "gpu_total": torch.cuda.get_device_properties(device).total_memory / (1024 ** 3),
51
+ "gpu_used": torch.cuda.memory_allocated(device) / (1024 ** 3)
52
+ })
53
+
54
+ return memory_info