File size: 2,170 Bytes
9e70bac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModel\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_memory_required(model_name, batch_size=32, sequence_length=512):\n",
    "    \"\"\"Print and return a rough estimate of the memory needed to train a model.\n",
    "\n",
    "    Assumes plain FP32 training with the Adam optimizer: weights, gradients and\n",
    "    Adam's two moment buffers each take 4 bytes per parameter (16 bytes total).\n",
    "\n",
    "    Args:\n",
    "        model_name: Hugging Face Hub id or local path of the model.\n",
    "        batch_size: Number of sequences per training step.\n",
    "        sequence_length: Number of tokens per sequence.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the estimated size in bytes of 'params', 'gradients',\n",
    "        'optimizer', 'activations' and 'total'.\n",
    "    \"\"\"\n",
    "    from transformers import AutoConfig\n",
    "\n",
    "    bytes_per_value = 4  # FP32\n",
    "    gib = 1024 ** 3  # the printed GB values are binary gigabytes (GiB)\n",
    "\n",
    "    # Build the architecture on the meta device (PyTorch >= 2.0): parameters get\n",
    "    # shapes but no storage, so counting them neither downloads the checkpoint\n",
    "    # nor allocates tens of GB of RAM for a 7B-parameter model.\n",
    "    config = AutoConfig.from_pretrained(model_name)\n",
    "    with torch.device('meta'):\n",
    "        model = AutoModel.from_config(config)\n",
    "    total_params = sum(p.numel() for p in model.parameters())\n",
    "\n",
    "    params_memory = total_params * bytes_per_value\n",
    "    gradient_memory = total_params * bytes_per_value  # one FP32 gradient per parameter\n",
    "    optimizer_memory = 2 * total_params * bytes_per_value  # Adam keeps first and second moments\n",
    "\n",
    "    # Very rough lower bound on activations: one (sequence_length x hidden_size)\n",
    "    # hidden state per layer per example. Attention scores and MLP intermediates\n",
    "    # are not counted, so real usage is typically several times larger.\n",
    "    num_layers = getattr(config, 'num_hidden_layers', 1)\n",
    "    activation_memory = (batch_size * num_layers * sequence_length\n",
    "                         * config.hidden_size * bytes_per_value)\n",
    "\n",
    "    total_memory = params_memory + gradient_memory + optimizer_memory + activation_memory\n",
    "\n",
    "    print(f\"Estimated memory for model parameters: {params_memory / gib:.2f} GB\")\n",
    "    print(f\"Estimated memory for gradients: {gradient_memory / gib:.2f} GB\")\n",
    "    print(f\"Estimated memory for optimizer states: {optimizer_memory / gib:.2f} GB\")\n",
    "    print(f\"Estimated memory for activations: {activation_memory / gib:.2f} GB\")\n",
    "    print(f\"Total estimated memory: {total_memory / gib:.2f} GB\")\n",
    "\n",
    "    return {\n",
    "        'params': params_memory,\n",
    "        'gradients': gradient_memory,\n",
    "        'optimizer': optimizer_memory,\n",
    "        'activations': activation_memory,\n",
    "        'total': total_memory,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate training memory for Mistral-7B (model id on the Hugging Face Hub)\n",
    "model_name = 'mistralai/Mistral-7B-v0.1'\n",
    "calculate_memory_required(model_name)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}