import io
import re
import struct
from enum import IntEnum
from math import floor

import requests
import gradio as gr

class GGUFValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12

# Little-endian struct format strings for each fixed-width value type
_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}

# Size in bytes of each fixed-width value type
value_type_info = {
    GGUFValueType.UINT8: 1,
    GGUFValueType.INT8: 1,
    GGUFValueType.UINT16: 2,
    GGUFValueType.INT16: 2,
    GGUFValueType.UINT32: 4,
    GGUFValueType.INT32: 4,
    GGUFValueType.FLOAT32: 4,
    GGUFValueType.UINT64: 8,
    GGUFValueType.INT64: 8,
    GGUFValueType.FLOAT64: 8,
    GGUFValueType.BOOL: 1,
}

def get_single(value_type, file):
    """Read a single value of the given type from a binary file object."""
    if value_type == GGUFValueType.STRING:
        value_length = struct.unpack("<Q", file.read(8))[0]
        value = file.read(value_length)
        try:
            value = value.decode('utf-8')
        except UnicodeDecodeError:
            pass  # Keep the raw bytes if the string is not valid UTF-8
    else:
        type_str = _simple_value_packing.get(value_type)
        bytes_length = value_type_info.get(value_type)
        value = struct.unpack(type_str, file.read(bytes_length))[0]
    return value

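# A minimal sketch of how get_single() consumes bytes, using an in-memory
# buffer (the values below are illustrative, not from a real model):
#
#     buf = io.BytesIO(
#         struct.pack("<Q", 4) + b"gguf"   # STRING: u64 length prefix, then bytes
#         + struct.pack("<I", 32000)       # UINT32 scalar, little-endian
#     )
#     get_single(GGUFValueType.STRING, buf)  # -> "gguf"
#     get_single(GGUFValueType.UINT32, buf)  # -> 32000
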
def load_metadata_from_file(file_obj):
    """Load metadata from a file-like object containing a GGUF header."""
    metadata = {}
    GGUF_MAGIC = struct.unpack("<I", file_obj.read(4))[0]
    GGUF_VERSION = struct.unpack("<I", file_obj.read(4))[0]
    ti_data_count = struct.unpack("<Q", file_obj.read(8))[0]  # tensor count; read to advance past the header
    kv_data_count = struct.unpack("<Q", file_obj.read(8))[0]

    if GGUF_MAGIC != 0x46554747:  # b"GGUF" read as a little-endian uint32
        raise Exception('Not a valid GGUF file (magic number mismatch).')
    if GGUF_VERSION == 1:
        raise Exception('You are using an outdated GGUF, please download a new one.')

    # Read the key/value metadata section
    for _ in range(kv_data_count):
        key_length = struct.unpack("<Q", file_obj.read(8))[0]
        key = file_obj.read(key_length)
        value_type = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
        if value_type == GGUFValueType.ARRAY:
            ltype = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
            length = struct.unpack("<Q", file_obj.read(8))[0]
            arr = [get_single(ltype, file_obj) for _ in range(length)]
            metadata[key.decode()] = arr
        else:
            value = get_single(value_type, file_obj)
            metadata[key.decode()] = value

    # Extract specific fields needed for VRAM calculation
    extracted_fields = {}
    for key, value in metadata.items():
        if key.endswith('.block_count'):
            extracted_fields['n_layers'] = value
        elif key.endswith('.attention.head_count_kv'):
            # May be stored as a per-layer list; use the maximum in that case
            extracted_fields['n_kv_heads'] = max(value) if isinstance(value, list) else value
        elif key.endswith('.embedding_length'):
            extracted_fields['embedding_dim'] = value
        elif key.endswith('.context_length'):
            extracted_fields['context_length'] = value
        elif key.endswith('.feed_forward_length'):
            extracted_fields['feed_forward_dim'] = value

    # Add extracted fields to metadata for easy access
    metadata.update(extracted_fields)
    return metadata

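# Hedged usage sketch: the same parser works on a local GGUF file instead of a
# URL (assumes a "model.gguf" exists on disk; only the header is consumed):
#
#     with open("model.gguf", "rb") as f:
#         meta = load_metadata_from_file(f)
#     print(meta.get("n_layers"), meta.get("context_length"))
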
def download_gguf_partial(url, max_bytes=25 * 1024 * 1024):
    """Download the first max_bytes of a GGUF file from a URL."""
    try:
        # Request only the first max_bytes via an HTTP Range header
        headers = {'Range': f'bytes=0-{max_bytes - 1}'}
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()

        # Wrap the bytes in a BytesIO for a file-like interface
        return io.BytesIO(response.content)
    except Exception as e:
        raise Exception(f"Failed to download GGUF file: {str(e)}")

def load_metadata(model_url, current_metadata):
    """Load metadata from a model URL and return the updated metadata dict."""
    if not model_url or model_url.strip() == "":
        return {}, gr.update(), "Please enter a model URL"

    try:
        # Get the total model size first (handles multi-part files)
        model_size_mb = get_model_size_mb_from_url(model_url)

        # Normalize the URL for downloading
        normalized_url = normalize_huggingface_url(model_url)

        # Download the first 25 MB of the file and parse the metadata
        file_obj = download_gguf_partial(normalized_url)
        metadata = load_metadata_from_file(file_obj)

        # Extract the filename from the URL, dropping any query parameters
        gguf_filename = model_url.split('/')[-1].split('?')[0]

        # Extract "user/model" from Hugging Face URLs like
        # https://huggingface.co/user/model/...
        model_name = model_url
        if "huggingface.co/" in model_url:
            try:
                parts = model_url.split("huggingface.co/")[1].split("/")
                if len(parts) >= 2:
                    model_name = f"{parts[0]}/{parts[1]}"
            except IndexError:
                model_name = model_url

        # Add URL, model name, and size to the metadata
        metadata['url'] = model_url
        metadata['model_name'] = model_name
        metadata['model_size_mb'] = model_size_mb
        metadata['loaded'] = True

        return metadata, gr.update(value=metadata["n_layers"], maximum=metadata["n_layers"]), f"Metadata loaded successfully for: {gguf_filename}"
    except Exception as e:
        error_msg = f"Error loading metadata: {str(e)}"
        return {}, gr.update(), error_msg

def normalize_huggingface_url(url: str) -> str:
    """Normalize a Hugging Face URL to the resolve format for direct access."""
    if 'huggingface.co' not in url:
        return url

    # Remove query parameters first
    base_url = url.split('?')[0]

    # Convert a blob URL (web viewer) to a resolve URL (raw file)
    if '/blob/' in base_url:
        base_url = base_url.replace('/blob/', '/resolve/')

    return base_url

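# Example of the transformation (illustrative URL):
#
#     normalize_huggingface_url(
#         "https://huggingface.co/user/model/blob/main/model.gguf?download=true"
#     )
#     # -> "https://huggingface.co/user/model/resolve/main/model.gguf"
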
def get_model_size_mb_from_url(model_url: str) -> float:
    """Get the model size in MB from a URL without downloading, handling multi-part files."""
    try:
        # Normalize the URL for direct access
        normalized_url = normalize_huggingface_url(model_url)

        # Get the size of the main file from the Content-Length header
        response = requests.head(normalized_url, allow_redirects=True, timeout=30)
        response.raise_for_status()
        main_file_size = int(response.headers.get('content-length', 0))

        # Extract the filename from the normalized URL
        filename = normalized_url.split('/')[-1]

        # Check for the multi-part pattern (e.g., model-00001-of-00002.gguf)
        match = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', filename)
        if match:
            base_pattern = match.group(1)
            total_parts = int(match.group(3))
            total_size = 0
            base_url = '/'.join(normalized_url.split('/')[:-1]) + '/'

            # Sum the sizes of all parts
            for part_num in range(1, total_parts + 1):
                part_filename = f"{base_pattern}-{part_num:05d}-of-{total_parts:05d}.gguf"
                part_url = base_url + part_filename
                try:
                    part_response = requests.head(part_url, allow_redirects=True, timeout=30)
                    part_response.raise_for_status()
                    total_size += int(part_response.headers.get('content-length', 0))
                except requests.RequestException:
                    print(f"Warning: Could not get size of {part_filename}, estimating...")
                    # Estimate missing parts from the average size of the parts seen so far
                    if total_size > 0:
                        avg_size = total_size / (part_num - 1)
                        remaining_parts = total_parts - (part_num - 1)
                        total_size += avg_size * remaining_parts
                    else:
                        # Fall back to the main file size times the number of parts
                        total_size = main_file_size * total_parts
                    break

            return total_size / (1024 ** 2)
        else:
            # Single-file model
            return main_file_size / (1024 ** 2)
    except Exception as e:
        print(f"Error getting model size: {e}")
        return 0.0

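# The multi-part regex in action (illustrative filenames):
#
#     re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$',
#              "Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf").groups()
#     # -> ('Qwen3-235B-A22B-UD-Q2_K_XL', '00001', '00002')
#     re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', "model-Q4_K_M.gguf")  # -> None
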
def estimate_vram(metadata, gpu_layers, ctx_size, cache_type):
    """Calculate VRAM usage using the fitted formula."""
    try:
        # Extract required values from metadata
        n_layers = metadata.get('n_layers')
        n_kv_heads = metadata.get('n_kv_heads')
        embedding_dim = metadata.get('embedding_dim')
        context_length = metadata.get('context_length')
        feed_forward_dim = metadata.get('feed_forward_dim')
        size_in_mb = metadata.get('model_size_mb', 0)

        # Check that all required fields are present
        required_fields = [n_layers, n_kv_heads, embedding_dim, context_length, feed_forward_dim]
        if any(field is None for field in required_fields):
            missing = [name for name, field in zip(
                ['n_layers', 'n_kv_heads', 'embedding_dim', 'context_length', 'feed_forward_dim'],
                required_fields) if field is None]
            raise ValueError(f"Missing required metadata fields: {missing}")

        # Ensure gpu_layers doesn't exceed the total layer count
        if gpu_layers > n_layers:
            gpu_layers = n_layers

        # Convert cache_type to a numeric value (bits per cached value)
        if cache_type == 'q4_0':
            cache_type = 4
        elif cache_type == 'q8_0':
            cache_type = 8
        else:
            cache_type = 16

        # Derived features
        size_per_layer = size_in_mb / max(n_layers, 1e-6)
        kv_cache_factor = n_kv_heads * cache_type * ctx_size
        embedding_per_context = embedding_dim / ctx_size

        # Calculate VRAM using the fitted model
        # Details: https://oobabooga.github.io/blog/posts/gguf-vram-formula/
        vram = (
            (size_per_layer - 17.99552795246051 + 3.148552680382576e-05 * kv_cache_factor)
            * (gpu_layers + max(0.9690636483914102, cache_type - (floor(50.77817218646521 * embedding_per_context) + 9.987899908205632)))
            + 1516.522943869404
        )

        return vram
    except Exception as e:
        print(f"Error in VRAM calculation: {e}")
        raise

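# Hedged usage sketch with made-up metadata (the numbers below are illustrative,
# not taken from any real model; the return value is in MiB):
#
#     fake_meta = {
#         'n_layers': 32, 'n_kv_heads': 8, 'embedding_dim': 4096,
#         'context_length': 32768, 'feed_forward_dim': 14336,
#         'model_size_mb': 4500.0,
#     }
#     estimate_vram(fake_meta, gpu_layers=32, ctx_size=8192, cache_type='fp16')
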
def estimate_vram_wrapper(model_metadata, gpu_layers, ctx_size, cache_type):
    """Wrapper function to estimate VRAM usage for the UI."""
    if not model_metadata or 'model_name' not in model_metadata:
        return "<div id=\"vram-info\">Estimated VRAM to load the model:</div>"

    # cache_type is already a string from the radio button
    try:
        result = estimate_vram(model_metadata, gpu_layers, ctx_size, cache_type)
        conservative = result + 577
        return f"""<div id="vram-info">
<div>Expected VRAM usage: <span class="value">{result:.0f} MiB</span></div>
<div>Safe estimate: <span class="value">{conservative:.0f} MiB</span> - 95% chance the VRAM is at most this.</div>
</div>"""
    except Exception as e:
        return f"<div id=\"vram-info\">Estimated VRAM to load the model: <span class=\"value\">Error: {str(e)}</span></div>"

def create_ui():
    """Create the simplified UI."""
    # Custom CSS to limit max width and center the content
    css = """
    body {
        max-width: 810px !important;
        margin: 0 auto !important;
    }
    #vram-info {
        padding: 10px;
        border-radius: 4px;
        background-color: var(--background-fill-secondary);
    }
    #vram-info .value {
        font-weight: bold;
        color: var(--primary-500);
    }
    """

    with gr.Blocks(css=css) as demo:
        # State to hold model metadata
        model_metadata = gr.State(value={})

        gr.Markdown("# Accurate GGUF VRAM Calculator\n\nCalculate VRAM for GGUF models from GPU layers and context length using an accurate formula.\n\nFor an explanation of how this works, consult this blog post: https://oobabooga.github.io/blog/posts/gguf-vram-formula/")

        with gr.Row():
            with gr.Column():
                # Model URL input
                model_url = gr.Textbox(
                    label="GGUF Model URL",
                    value="https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/blob/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf"
                )

                # Load metadata button
                load_metadata_btn = gr.Button("Load metadata", elem_classes='refresh-button')

                # GPU layers slider
                gpu_layers = gr.Slider(
                    label="GPU Layers",
                    minimum=0,
                    maximum=256,
                    value=256,
                    info='`--gpu-layers` in llama.cpp.'
                )

                # Context size slider
                ctx_size = gr.Slider(
                    label='Context Length',
                    minimum=512,
                    maximum=131072,
                    step=256,
                    value=8192,
                    info='`--ctx-size` in llama.cpp.'
                )

                # Cache type radio buttons
                cache_type = gr.Radio(
                    choices=['fp16', 'q8_0', 'q4_0'],
                    value='fp16',
                    label="Cache Type",
                    info='Cache quantization.'
                )

                # VRAM info display
                vram_info = gr.HTML(
                    value="<div id=\"vram-info\">Estimated VRAM to load the model:</div>"
                )

                # Status display
                status = gr.Textbox(
                    label="Status",
                    value="No model loaded",
                    interactive=False
                )

        # Event handlers
        load_metadata_btn.click(
            load_metadata,
            inputs=[model_url, model_metadata],
            outputs=[model_metadata, gpu_layers, status],
            show_progress=True
        ).then(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

        # Update the VRAM estimate when any parameter changes
        for component in [gpu_layers, ctx_size, cache_type]:
            component.change(
                estimate_vram_wrapper,
                inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
                outputs=[vram_info],
                show_progress=False
            )

        # Also update when the model_metadata state changes
        model_metadata.change(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

    return demo

if __name__ == "__main__":
    # Create and launch the app
    demo = create_ui()
    demo.launch()