Spaces:
Runtime error
Runtime error
File size: 2,170 Bytes
9e70bac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModel\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def calculate_memory_required(model_name, batch_size=32, sequence_length=512):\n",
"    \"\"\"Print a rough estimate of the memory needed to fully fine-tune a model with Adam in FP32.\n",
"\n",
"    Args:\n",
"        model_name: Hugging Face Hub id or local path passed to `AutoConfig.from_pretrained`.\n",
"        batch_size: Examples per training step, used only for the activation estimate.\n",
"        sequence_length: Tokens per example, used only for the activation estimate.\n",
"\n",
"    Returns:\n",
"        None; the breakdown is printed.\n",
"    \"\"\"\n",
"    from transformers import AutoConfig\n",
"\n",
"    bytes_per_value = 4  # every tensor is assumed to be FP32\n",
"\n",
"    # Only parameter shapes are needed, so build the architecture on PyTorch's \"meta\" device:\n",
"    # this fetches config.json instead of the full checkpoint and allocates no weight storage\n",
"    # (materialising a 7B-parameter model in FP32 needs ~28 GB of RAM). Requires torch >= 2.0.\n",
"    config = AutoConfig.from_pretrained(model_name)\n",
"    with torch.device(\"meta\"):\n",
"        model = AutoModel.from_config(config)\n",
"\n",
"    total_params = sum(p.numel() for p in model.parameters())\n",
"\n",
"    # Weights and their gradients each hold one FP32 value per parameter.\n",
"    param_memory = total_params * bytes_per_value\n",
"    gradient_memory = total_params * bytes_per_value\n",
"\n",
"    # Adam keeps two extra FP32 states per parameter (first and second moment estimates).\n",
"    optimizer_memory = total_params * bytes_per_value * 2\n",
"\n",
"    # Very rough activation estimate: a single (batch_size, sequence_length, hidden_size) tensor.\n",
"    # Training keeps activations for every layer, so treat this figure as a lower bound.\n",
"    activation_memory_per_example = sequence_length * config.hidden_size * bytes_per_value\n",
"    total_activation_memory = batch_size * activation_memory_per_example\n",
"\n",
"    total_estimated_memory = param_memory + gradient_memory + optimizer_memory + total_activation_memory\n",
"\n",
"    gib = 1024 ** 3\n",
"    print(f\"Estimated memory for model parameters: {param_memory / gib:.2f} GB\")\n",
"    print(f\"Estimated memory for gradients: {gradient_memory / gib:.2f} GB\")\n",
"    print(f\"Estimated memory for optimizer states: {optimizer_memory / gib:.2f} GB\")\n",
"    print(f\"Estimated memory for activations: {total_activation_memory / gib:.2f} GB\")\n",
"    print(f\"Total estimated memory: {total_estimated_memory / gib:.2f} GB\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Estimate the memory needed to fully fine-tune Mistral-7B\n",
"model_name = 'mistralai/Mistral-7B-v0.1'\n",
"calculate_memory_required(model_name)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|