import io
import re
import struct
from enum import IntEnum
from math import floor
import requests
import gradio as gr
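
# Metadata value type IDs, matching the GGUF specification.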
class GGUFValueType(IntEnum):
UINT8 = 0
INT8 = 1
UINT16 = 2
INT16 = 3
UINT32 = 4
INT32 = 5
FLOAT32 = 6
BOOL = 7
STRING = 8
ARRAY = 9
UINT64 = 10
INT64 = 11
FLOAT64 = 12
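
# struct format strings for the fixed-width scalar types (all little-endian).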
_simple_value_packing = {
GGUFValueType.UINT8: "<B",
GGUFValueType.INT8: "<b",
GGUFValueType.UINT16: "<H",
GGUFValueType.INT16: "<h",
GGUFValueType.UINT32: "<I",
GGUFValueType.INT32: "<i",
GGUFValueType.FLOAT32: "<f",
GGUFValueType.UINT64: "<Q",
GGUFValueType.INT64: "<q",
GGUFValueType.FLOAT64: "<d",
GGUFValueType.BOOL: "?",
}
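
# Size in bytes of each fixed-width scalar type.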
value_type_info = {
GGUFValueType.UINT8: 1,
GGUFValueType.INT8: 1,
GGUFValueType.UINT16: 2,
GGUFValueType.INT16: 2,
GGUFValueType.UINT32: 4,
GGUFValueType.INT32: 4,
GGUFValueType.FLOAT32: 4,
GGUFValueType.UINT64: 8,
GGUFValueType.INT64: 8,
GGUFValueType.FLOAT64: 8,
GGUFValueType.BOOL: 1,
}
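
# Read one metadata value of the given type from `file`. Strings are
# length-prefixed (uint64) and UTF-8 encoded.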
def get_single(value_type, file):
if value_type == GGUFValueType.STRING:
value_length = struct.unpack("<Q", file.read(8))[0]
value = file.read(value_length)
        try:
            value = value.decode('utf-8')
        except UnicodeDecodeError:
            pass  # keep the raw bytes if the string is not valid UTF-8
    else:
        type_str = _simple_value_packing.get(value_type)
        bytes_length = value_type_info.get(value_type)
        if type_str is None or bytes_length is None:
            raise ValueError(f"Unsupported GGUF value type: {value_type}")
        value = struct.unpack(type_str, file.read(bytes_length))[0]
return value
def load_metadata_from_file(file_obj):
"""Load metadata from a file-like object"""
metadata = {}
    GGUF_MAGIC = struct.unpack("<I", file_obj.read(4))[0]
    if GGUF_MAGIC != 0x46554747:  # the bytes b"GGUF", read as a little-endian uint32
        raise ValueError("Not a GGUF file (bad magic number).")
    GGUF_VERSION = struct.unpack("<I", file_obj.read(4))[0]
    if GGUF_VERSION == 1:
        raise ValueError("GGUF v1 is no longer supported; please download a newer conversion of this model.")
    ti_data_count = struct.unpack("<Q", file_obj.read(8))[0]  # tensor count (not used here)
    kv_data_count = struct.unpack("<Q", file_obj.read(8))[0]  # number of metadata key/value pairs
for i in range(kv_data_count):
key_length = struct.unpack("<Q", file_obj.read(8))[0]
key = file_obj.read(key_length)
value_type = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
if value_type == GGUFValueType.ARRAY:
ltype = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
length = struct.unpack("<Q", file_obj.read(8))[0]
arr = [get_single(ltype, file_obj) for _ in range(length)]
metadata[key.decode()] = arr
else:
value = get_single(value_type, file_obj)
metadata[key.decode()] = value
# Extract specific fields needed for VRAM calculation
extracted_fields = {}
for key, value in metadata.items():
if key.endswith('.block_count'):
extracted_fields['n_layers'] = value
elif key.endswith('.attention.head_count_kv'):
extracted_fields['n_kv_heads'] = max(value) if isinstance(value, list) else value
elif key.endswith('.embedding_length'):
extracted_fields['embedding_dim'] = value
elif key.endswith('.context_length'):
extracted_fields['context_length'] = value
elif key.endswith('.feed_forward_length'):
extracted_fields['feed_forward_dim'] = value
# Add extracted fields to metadata for easy access
metadata.update(extracted_fields)
return metadata
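
# A minimal usage sketch (with a hypothetical local file):
#     with open("model.gguf", "rb") as f:
#         meta = load_metadata_from_file(f)
#     print(meta.get("n_layers"), meta.get("context_length"))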
def download_gguf_partial(url, max_bytes=25 * 1024 * 1024):
"""Download the first max_bytes from a GGUF URL"""
try:
# Set up headers for partial content request
headers = {'Range': f'bytes=0-{max_bytes-1}'}
# Make the request
        response = requests.get(url, headers=headers, stream=True, timeout=60)
response.raise_for_status()
# Read the content
content = response.content
# Convert to BytesIO for file-like interface
return io.BytesIO(content)
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to download GGUF file: {e}") from e
def load_metadata(model_url, current_metadata):
"""Load metadata from model URL and return updated metadata dict"""
    if not model_url or not model_url.strip():
        return {}, gr.update(), "Please enter a model URL"
try:
# Get model size first
model_size_mb = get_model_size_mb_from_url(model_url)
# Normalize URL for downloading
normalized_url = normalize_huggingface_url(model_url)
# Download the first 25MB of the file
file_obj = download_gguf_partial(normalized_url)
# Parse the metadata
metadata = load_metadata_from_file(file_obj)
# Extract filename from URL
gguf_filename = model_url.split('/')[-1].split('?')[0] # Remove query parameters if any
# Extract model name from URL if it's a Hugging Face URL
model_name = model_url
if "huggingface.co/" in model_url:
try:
# Extract model name from URL like https://huggingface.co/user/model
parts = model_url.split("huggingface.co/")[1].split("/")
if len(parts) >= 2:
model_name = f"{parts[0]}/{parts[1]}"
            except IndexError:
                model_name = model_url  # fall back to the full URL
# Add URL, model name, and size to metadata
metadata['url'] = model_url
metadata['model_name'] = model_name
metadata['model_size_mb'] = model_size_mb
metadata['loaded'] = True
return metadata, gr.update(value=metadata["n_layers"], maximum=metadata["n_layers"]), f"Metadata loaded successfully for: {gguf_filename}"
except Exception as e:
error_msg = f"Error loading metadata: {str(e)}"
return {}, gr.update(), error_msg
def normalize_huggingface_url(url: str) -> str:
"""Normalize HuggingFace URL to resolve format for direct access"""
if 'huggingface.co' not in url:
return url
# Remove query parameters first
base_url = url.split('?')[0]
# Convert blob URL to resolve URL
if '/blob/' in base_url:
base_url = base_url.replace('/blob/', '/resolve/')
return base_url
def get_model_size_mb_from_url(model_url: str) -> float:
"""Get model size in MB from URL without downloading, handling multi-part files"""
try:
# Normalize the URL for direct access
normalized_url = normalize_huggingface_url(model_url)
# Get size of the main file
        response = requests.head(normalized_url, allow_redirects=True, timeout=30)
response.raise_for_status()
main_file_size = int(response.headers.get('content-length', 0))
# Extract filename from original URL
filename = normalized_url.split('/')[-1]
# Check for multipart pattern (e.g., model-00001-of-00002.gguf)
match = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', filename)
if match:
base_pattern = match.group(1)
total_parts = int(match.group(3))
total_size = 0
base_url = '/'.join(normalized_url.split('/')[:-1]) + '/'
# Get size of all parts
for part_num in range(1, total_parts + 1):
part_filename = f"{base_pattern}-{part_num:05d}-of-{total_parts:05d}.gguf"
part_url = base_url + part_filename
try:
                    part_response = requests.head(part_url, allow_redirects=True, timeout=30)
part_response.raise_for_status()
part_size = int(part_response.headers.get('content-length', 0))
total_size += part_size
                except requests.RequestException:
                    print(f"Warning: Could not get size of {part_filename}, estimating...")
# If we can't get some parts, estimate based on what we have
if total_size > 0:
avg_size = total_size / (part_num - 1)
remaining_parts = total_parts - (part_num - 1)
total_size += avg_size * remaining_parts
else:
# Fallback to main file size * total parts
total_size = main_file_size * total_parts
break
return total_size / (1024 ** 2)
else:
# Single part file
return main_file_size / (1024 ** 2)
except Exception as e:
print(f"Error getting model size: {e}")
return 0.0
def estimate_vram(metadata, gpu_layers, ctx_size, cache_type):
"""Calculate VRAM usage using the actual formula"""
try:
# Extract required values from metadata
n_layers = metadata.get('n_layers')
n_kv_heads = metadata.get('n_kv_heads')
embedding_dim = metadata.get('embedding_dim')
context_length = metadata.get('context_length')
feed_forward_dim = metadata.get('feed_forward_dim')
size_in_mb = metadata.get('model_size_mb', 0)
# Check if we have all required fields
required_fields = [n_layers, n_kv_heads, embedding_dim, context_length, feed_forward_dim]
if any(field is None for field in required_fields):
missing = [name for name, field in zip(
['n_layers', 'n_kv_heads', 'embedding_dim', 'context_length', 'feed_forward_dim'],
required_fields) if field is None]
raise ValueError(f"Missing required metadata fields: {missing}")
# Ensure gpu_layers doesn't exceed total layers
if gpu_layers > n_layers:
gpu_layers = n_layers
        # Convert cache_type to the number of bits per KV cache element
if cache_type == 'q4_0':
cache_type = 4
elif cache_type == 'q8_0':
cache_type = 8
else:
cache_type = 16
# Derived features
size_per_layer = size_in_mb / max(n_layers, 1e-6)
kv_cache_factor = n_kv_heads * cache_type * ctx_size
embedding_per_context = embedding_dim / ctx_size
# Calculate VRAM using the model
# Details: https://oobabooga.github.io/blog/posts/gguf-vram-formula/
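        # The constants below are coefficients fitted empirically in that post,
        # not values derived from first principles.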
vram = (
(size_per_layer - 17.99552795246051 + 3.148552680382576e-05 * kv_cache_factor)
* (gpu_layers + max(0.9690636483914102, cache_type - (floor(50.77817218646521 * embedding_per_context) + 9.987899908205632)))
+ 1516.522943869404
)
return vram
except Exception as e:
print(f"Error in VRAM calculation: {e}")
raise
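
# A quick sanity-check sketch with made-up metadata (hypothetical model):
#     meta = {'n_layers': 32, 'n_kv_heads': 8, 'embedding_dim': 4096,
#             'context_length': 32768, 'feed_forward_dim': 14336,
#             'model_size_mb': 4500.0}
#     estimate_vram(meta, gpu_layers=32, ctx_size=8192, cache_type='fp16')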
def estimate_vram_wrapper(model_metadata, gpu_layers, ctx_size, cache_type):
"""Wrapper function to estimate VRAM usage"""
if not model_metadata or 'model_name' not in model_metadata:
return "<div id=\"vram-info\">Estimated VRAM to load the model:</div>"
# Use cache_type directly (it's already a string from the radio button)
try:
result = estimate_vram(model_metadata, gpu_layers, ctx_size, cache_type)
        conservative = result + 577  # fixed headroom for the 95th-percentile "safe" estimate
return f"""<div id="vram-info">
<div>Expected VRAM usage: <span class="value">{result:.0f} MiB</span></div>
<div>Safe estimate: <span class="value">{conservative:.0f} MiB</span> - 95% chance the VRAM is at most this.</div>
</div>"""
except Exception as e:
return f"<div id=\"vram-info\">Estimated VRAM to load the model: <span class=\"value\">Error: {str(e)}</span></div>"
def create_ui():
"""Create the simplified UI"""
# Custom CSS to limit max width and center the content
css = """
body {
max-width: 810px !important;
margin: 0 auto !important;
}
#vram-info {
padding: 10px;
border-radius: 4px;
background-color: var(--background-fill-secondary);
}
#vram-info .value {
font-weight: bold;
color: var(--primary-500);
}
"""
with gr.Blocks(css=css) as demo:
# State to hold model metadata
model_metadata = gr.State(value={})
gr.Markdown("# Accurate GGUF VRAM Calculator\n\nCalculate VRAM for GGUF models from GPU layers and context length using an accurate formula.\n\nFor an explanation about how this works, consult this blog post: https://oobabooga.github.io/blog/posts/gguf-vram-formula/")
with gr.Row():
with gr.Column():
# Model URL input
model_url = gr.Textbox(
label="GGUF Model URL",
placeholder="https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/blob/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf",
value=""
)
# Load metadata button
load_metadata_btn = gr.Button("Load metadata", elem_classes='refresh-button')
# GPU layers slider
gpu_layers = gr.Slider(
label="GPU Layers",
minimum=0,
maximum=256,
value=256,
info='`--gpu-layers` in llama.cpp.'
)
# Context size slider
ctx_size = gr.Slider(
label='Context Length',
minimum=512,
maximum=131072,
step=256,
value=8192,
info='`--ctx-size` in llama.cpp.'
)
                # Cache type radio buttons
cache_type = gr.Radio(
choices=['fp16', 'q8_0', 'q4_0'],
value='fp16',
label="Cache Type",
info='Cache quantization.'
)
# VRAM info display
vram_info = gr.HTML(
value="<div id=\"vram-info\">Estimated VRAM to load the model:</div>"
)
# Status display
status = gr.Textbox(
label="Status",
value="No model loaded",
interactive=False
)
# Event handlers
load_metadata_btn.click(
load_metadata,
inputs=[model_url, model_metadata],
outputs=[model_metadata, gpu_layers, status],
show_progress=True
).then(
estimate_vram_wrapper,
inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
outputs=[vram_info],
show_progress=False
)
# Update VRAM estimate when any parameter changes
for component in [gpu_layers, ctx_size, cache_type]:
component.change(
estimate_vram_wrapper,
inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
outputs=[vram_info],
show_progress=False
)
# Also update when model_metadata state changes
model_metadata.change(
estimate_vram_wrapper,
inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
outputs=[vram_info],
show_progress=False
)
return demo
if __name__ == "__main__":
# Create and launch the app
demo = create_ui()
demo.launch()