import io
import re
import struct
from enum import IntEnum
from math import floor

import requests
import gradio as gr


class GGUFValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12


_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}

value_type_info = {
    GGUFValueType.UINT8: 1,
    GGUFValueType.INT8: 1,
    GGUFValueType.UINT16: 2,
    GGUFValueType.INT16: 2,
    GGUFValueType.UINT32: 4,
    GGUFValueType.INT32: 4,
    GGUFValueType.FLOAT32: 4,
    GGUFValueType.UINT64: 8,
    GGUFValueType.INT64: 8,
    GGUFValueType.FLOAT64: 8,
    GGUFValueType.BOOL: 1,
}


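# The readers below assume the standard GGUF layout (all values little-endian):
# a header of magic, version, tensor count and metadata key/value count,
# followed by one (key length, key, value type, value) record per metadata
# entry; arrays additionally store an element type and an element count.
# _simple_value_packing maps each fixed-size value type to its struct format
# string, and value_type_info gives its size in bytes.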
def get_single(value_type, file):
    """Read a single value of the given type from the file object."""
    if value_type == GGUFValueType.STRING:
        value_length = struct.unpack("<Q", file.read(8))[0]
        value = file.read(value_length)
        try:
            value = value.decode('utf-8')
        except UnicodeDecodeError:
            # Keep the raw bytes if the string is not valid UTF-8.
            pass
    else:
        type_str = _simple_value_packing.get(value_type)
        bytes_length = value_type_info.get(value_type)
        value = struct.unpack(type_str, file.read(bytes_length))[0]

    return value


def load_metadata_from_file(file_obj):
    """Load metadata from a file-like object containing the start of a GGUF file."""
    metadata = {}

    # GGUF header: magic, version, tensor count, metadata key/value count.
    GGUF_MAGIC = struct.unpack("<I", file_obj.read(4))[0]
    GGUF_VERSION = struct.unpack("<I", file_obj.read(4))[0]
    ti_data_count = struct.unpack("<Q", file_obj.read(8))[0]  # tensor count (unused, but must be read to advance the stream)
    kv_data_count = struct.unpack("<Q", file_obj.read(8))[0]

    if GGUF_MAGIC != 0x46554747:  # b"GGUF" interpreted as a little-endian uint32
        raise Exception('Not a valid GGUF file (magic number mismatch).')

    if GGUF_VERSION == 1:
        raise Exception('You are using an outdated GGUF, please download a new one.')

    for _ in range(kv_data_count):
        key_length = struct.unpack("<Q", file_obj.read(8))[0]
        key = file_obj.read(key_length)

        value_type = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
        if value_type == GGUFValueType.ARRAY:
            ltype = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
            length = struct.unpack("<Q", file_obj.read(8))[0]
            arr = [get_single(ltype, file_obj) for _ in range(length)]
            metadata[key.decode()] = arr
        else:
            value = get_single(value_type, file_obj)
            metadata[key.decode()] = value

    # Pull out the architecture-agnostic fields that the VRAM formula needs.
    extracted_fields = {}
    for key, value in metadata.items():
        if key.endswith('.block_count'):
            extracted_fields['n_layers'] = value
        elif key.endswith('.attention.head_count_kv'):
            extracted_fields['n_kv_heads'] = max(value) if isinstance(value, list) else value
        elif key.endswith('.embedding_length'):
            extracted_fields['embedding_dim'] = value
        elif key.endswith('.context_length'):
            extracted_fields['context_length'] = value
        elif key.endswith('.feed_forward_length'):
            extracted_fields['feed_forward_dim'] = value

    metadata.update(extracted_fields)
    return metadata


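# A minimal usage sketch for load_metadata_from_file with a local file
# (the path is hypothetical):
#
#     with open("/path/to/model.gguf", "rb") as f:
#         meta = load_metadata_from_file(f)
#         print(meta.get("n_layers"), meta.get("context_length"))

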
def download_gguf_partial(url, max_bytes=25 * 1024 * 1024):
    """Download the first max_bytes from a GGUF URL"""
    try:
        headers = {'Range': f'bytes=0-{max_bytes - 1}'}
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()
        content = response.content
        return io.BytesIO(content)
    except Exception as e:
        raise Exception(f"Failed to download GGUF file: {str(e)}")


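# Note on download_gguf_partial: GGUF places all metadata at the start of the
# file, before the tensor data, so a small ranged download is normally enough.
# A server that ignores the Range header returns the full file, which requests
# would then buffer entirely into memory here.

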
def load_metadata(model_url, current_metadata):
    """Load metadata from model URL and return updated metadata dict"""
    if not model_url or model_url.strip() == "":
        # Must return three values to match the outputs wired up in the UI:
        # metadata state, slider update, status text.
        return {}, gr.update(), "Please enter a model URL"

    try:
        model_size_mb = get_model_size_mb_from_url(model_url)
        normalized_url = normalize_huggingface_url(model_url)
        file_obj = download_gguf_partial(normalized_url)
        metadata = load_metadata_from_file(file_obj)

        gguf_filename = model_url.split('/')[-1].split('?')[0]

        # Derive a short "org/repo" name for Hugging Face URLs; fall back to the raw URL.
        model_name = model_url
        if "huggingface.co/" in model_url:
            try:
                parts = model_url.split("huggingface.co/")[1].split("/")
                if len(parts) >= 2:
                    model_name = f"{parts[0]}/{parts[1]}"
            except Exception:
                model_name = model_url

        metadata['url'] = model_url
        metadata['model_name'] = model_name
        metadata['model_size_mb'] = model_size_mb
        metadata['loaded'] = True

        return metadata, gr.update(value=metadata["n_layers"], maximum=metadata["n_layers"]), f"Metadata loaded successfully for: {gguf_filename}"

    except Exception as e:
        error_msg = f"Error loading metadata: {str(e)}"
        return {}, gr.update(), error_msg


def normalize_huggingface_url(url: str) -> str:
    """Normalize HuggingFace URL to resolve format for direct access"""
    if 'huggingface.co' not in url:
        return url

    base_url = url.split('?')[0]

    if '/blob/' in base_url:
        base_url = base_url.replace('/blob/', '/resolve/')

    return base_url


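# normalize_huggingface_url example (org/repo/model.gguf are placeholders):
#     https://huggingface.co/org/repo/blob/main/model.gguf
#  -> https://huggingface.co/org/repo/resolve/main/model.gguf

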
def get_model_size_mb_from_url(model_url: str) -> float:
    """Get model size in MB from URL without downloading, handling multi-part files"""
    try:
        normalized_url = normalize_huggingface_url(model_url)

        response = requests.head(normalized_url, allow_redirects=True)
        response.raise_for_status()
        main_file_size = int(response.headers.get('content-length', 0))

        filename = normalized_url.split('/')[-1]

        # Multi-part GGUF files follow the "<name>-XXXXX-of-YYYYY.gguf" convention.
        match = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', filename)

        if match:
            base_pattern = match.group(1)
            total_parts = int(match.group(3))

            total_size = 0
            base_url = '/'.join(normalized_url.split('/')[:-1]) + '/'

            # Sum the sizes of all parts via HEAD requests (part numbers are
            # assumed to be zero-padded to five digits).
            for part_num in range(1, total_parts + 1):
                part_filename = f"{base_pattern}-{part_num:05d}-of-{total_parts:05d}.gguf"
                part_url = base_url + part_filename

                try:
                    part_response = requests.head(part_url, allow_redirects=True)
                    part_response.raise_for_status()
                    part_size = int(part_response.headers.get('content-length', 0))
                    total_size += part_size
                except requests.RequestException as e:
                    print(f"Warning: Could not get size of {part_filename} ({e}), estimating...")

                    if total_size > 0:
                        # Estimate the remaining parts from the average size of the parts seen so far.
                        avg_size = total_size / (part_num - 1)
                        remaining_parts = total_parts - (part_num - 1)
                        total_size += avg_size * remaining_parts
                    else:
                        total_size = main_file_size * total_parts
                    break

            return total_size / (1024 ** 2)
        else:
            return main_file_size / (1024 ** 2)

    except Exception as e:
        print(f"Error getting model size: {e}")
        return 0.0


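# get_model_size_mb_from_url example: for
# "Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf" the regex captures the base
# name and "00002" total parts, and the reported size is the sum of the
# content-length returned by a HEAD request for each of the two parts.

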
def estimate_vram(metadata, gpu_layers, ctx_size, cache_type):
    """Calculate VRAM usage using the actual formula"""
    try:
        n_layers = metadata.get('n_layers')
        n_kv_heads = metadata.get('n_kv_heads')
        embedding_dim = metadata.get('embedding_dim')
        context_length = metadata.get('context_length')
        feed_forward_dim = metadata.get('feed_forward_dim')
        size_in_mb = metadata.get('model_size_mb', 0)

        required_fields = [n_layers, n_kv_heads, embedding_dim, context_length, feed_forward_dim]
        if any(field is None for field in required_fields):
            missing = [name for name, field in zip(
                ['n_layers', 'n_kv_heads', 'embedding_dim', 'context_length', 'feed_forward_dim'],
                required_fields) if field is None]
            raise ValueError(f"Missing required metadata fields: {missing}")

        # Cannot offload more layers than the model has.
        if gpu_layers > n_layers:
            gpu_layers = n_layers

        # Convert the cache type to its bit width.
        if cache_type == 'q4_0':
            cache_type = 4
        elif cache_type == 'q8_0':
            cache_type = 8
        else:
            cache_type = 16

        # Intermediate terms of the formula.
        size_per_layer = size_in_mb / max(n_layers, 1e-6)
        kv_cache_factor = n_kv_heads * cache_type * ctx_size
        embedding_per_context = embedding_dim / ctx_size

        # The numeric constants below are empirical; see the blog post linked
        # in the UI for how the formula was derived.
        vram = (
            (size_per_layer - 17.99552795246051 + 3.148552680382576e-05 * kv_cache_factor)
            * (gpu_layers + max(0.9690636483914102, cache_type - (floor(50.77817218646521 * embedding_per_context) + 9.987899908205632)))
            + 1516.522943869404
        )

        return vram

    except Exception as e:
        print(f"Error in VRAM calculation: {e}")
        raise


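# A minimal sketch of calling estimate_vram directly; the metadata values are
# illustrative, not taken from a real model:
#
#     meta = {'n_layers': 32, 'n_kv_heads': 8, 'embedding_dim': 4096,
#             'context_length': 32768, 'feed_forward_dim': 14336,
#             'model_size_mb': 4500.0}
#     print(estimate_vram(meta, gpu_layers=32, ctx_size=8192, cache_type='fp16'))

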
def estimate_vram_wrapper(model_metadata, gpu_layers, ctx_size, cache_type):
    """Wrapper function to estimate VRAM usage"""
    if not model_metadata or 'model_name' not in model_metadata:
        return '<div id="vram-info">Estimated VRAM to load the model:</div>'

    try:
        result = estimate_vram(model_metadata, gpu_layers, ctx_size, cache_type)
        conservative = result + 577
        return f"""<div id="vram-info">
<div>Expected VRAM usage: <span class="value">{result:.0f} MiB</span></div>
<div>Safe estimate: <span class="value">{conservative:.0f} MiB</span> - 95% chance the actual usage is at most this.</div>
</div>"""
    except Exception as e:
        return f'<div id="vram-info">Estimated VRAM to load the model: <span class="value">Error: {str(e)}</span></div>'


def create_ui():
    """Create the simplified UI"""
    css = """
    body {
        max-width: 810px !important;
        margin: 0 auto !important;
    }

    #vram-info {
        padding: 10px;
        border-radius: 4px;
        background-color: var(--background-fill-secondary);
    }

    #vram-info .value {
        font-weight: bold;
        color: var(--primary-500);
    }
    """

    with gr.Blocks(css=css) as demo:
        model_metadata = gr.State(value={})

        gr.Markdown("# Accurate GGUF VRAM Calculator\n\nCalculate VRAM for GGUF models from GPU layers and context length using an accurate formula.\n\nFor an explanation about how this works, consult this blog post: https://oobabooga.github.io/blog/posts/gguf-vram-formula/")

        with gr.Row():
            with gr.Column():
                model_url = gr.Textbox(
                    label="GGUF Model URL",
                    placeholder="https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/blob/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf",
                    value=""
                )

                load_metadata_btn = gr.Button("Load metadata", elem_classes='refresh-button')

                gpu_layers = gr.Slider(
                    label="GPU Layers",
                    minimum=0,
                    maximum=256,
                    value=256,
                    info='`--gpu-layers` in llama.cpp.'
                )

                ctx_size = gr.Slider(
                    label='Context Length',
                    minimum=512,
                    maximum=131072,
                    step=256,
                    value=8192,
                    info='`--ctx-size` in llama.cpp.'
                )

                cache_type = gr.Radio(
                    choices=['fp16', 'q8_0', 'q4_0'],
                    value='fp16',
                    label="Cache Type",
                    info='Cache quantization.'
                )

                vram_info = gr.HTML(
                    value='<div id="vram-info">Estimated VRAM to load the model:</div>'
                )

                status = gr.Textbox(
                    label="Status",
                    value="No model loaded",
                    interactive=False
                )

        load_metadata_btn.click(
            load_metadata,
            inputs=[model_url, model_metadata],
            outputs=[model_metadata, gpu_layers, status],
            show_progress=True
        ).then(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

        for component in [gpu_layers, ctx_size, cache_type]:
            component.change(
                estimate_vram_wrapper,
                inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
                outputs=[vram_info],
                show_progress=False
            )

        model_metadata.change(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

    return demo


if __name__ == "__main__":
    demo = create_ui()
    demo.launch()