Spaces:
Running
Running
| import React, { useState, useEffect, useRef } from 'react'; | |
| import * as d3 from 'd3'; | |
| import { useTheme } from '../context/themeContext'; | |
| import MODELS from '../utils/models'; | |
| import DEVICES from '../utils/devices'; | |
/** Numeric precision a model's weights can be stored in. */
type Precision = 'fp32' | 'fp16' | 'int8' | 'int4';

/** One point of a batch-size curve. */
interface LineChartData {
  seqLength: number;
  batchSize: number;
}

/** Memory (GB) left on the device for inputs, one entry per precision. */
type AvailableMemory = Record<Precision, number>;

/** Props for the horizontal memory-footprint bar. */
interface ModelSizeBarChartProps {
  /** Memory occupied by the model, in GB. */
  modelSize: number;
  /** Capacity the whole bar represents, in GB (callers in this file pass the device memory). */
  largestModelSize: number;
  modelPrecision: Precision;
  deviceMemorySet: boolean;
  /** Extra activation memory drawn after the model segment, in GB; the standard calculator omits it. */
  activationMemorySize?: number;
}

/** Props for the batch-size / sequence-length chart. */
interface InferenceRuntimeLineChartProps {
  /** Remaining memory per precision, in GB. */
  availableMemory: AvailableMemory;
  /** Memory consumed per input token, in GB. */
  memoryPerInput: number;
}
/**
 * Chart color associated with a precision.
 *
 * @param precision - precision whose color is wanted
 * @returns the palette hex color, or 'gray' for a value outside the palette
 */
function chooseColor(precision: Precision) {
  const palette: Record<Precision, string> = {
    int4: '#383d95',
    int8: '#71cce9',
    fp16: '#ffc068',
    fp32: '#e45f5b',
  };
  const color = palette[precision];
  return color ? color : 'gray';
}
/**
 * Total memory estimate (GB) used by the prefill-chunking calculator.
 *
 * Sum of three terms:
 *  - a per-input term, 4 × hiddenSize × numLayers bytes (expressed in GB);
 *  - the weights, modelParams × bytes-per-parameter for `precision`;
 *  - an activation term, max(2 × intermediateSize, 4 × hiddenSize) (expressed in GB).
 *
 * @param modelParams - parameter count in billions, so params × bytes yields GB
 * @param hiddenSize - model hidden size
 * @param numLayers - number of layers
 * @param intermediateSize - MLP intermediate size
 * @param precision - storage precision that selects bytes per parameter
 * @returns estimated memory in GB (NaN for a precision outside the table)
 */
function calculateTotalMemory(
  modelParams: number,
  hiddenSize: number,
  numLayers: number,
  intermediateSize: number,
  precision: Precision
) {
  const bytesPerParam: Record<Precision, number> = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 };
  const BYTES_PER_GB = 1_000_000_000;

  const weightsGb = modelParams * bytesPerParam[precision];
  const perInputGb = (4 * hiddenSize * numLayers) / BYTES_PER_GB;
  const activationGb = Math.max(2 * intermediateSize, 4 * hiddenSize) / BYTES_PER_GB;

  // Summed in the same order as before to keep results bit-identical.
  return perInputGb + weightsGb + activationGb;
}
// Bar chart for model footprint (shared by both standard and prefill chunking calculators)
/**
 * Horizontal bar showing how much of `largestModelSize` is in use.
 *
 * Left to right it draws the model segment (in the precision's color), an
 * optional activation segment, and, when `deviceMemorySet` is true, an
 * outlined segment for the remaining capacity. When model + activation memory
 * exceeds `largestModelSize`, a dashed "Out of Memory" box is drawn instead.
 * Nothing is drawn unless both `modelSize` and `largestModelSize` are > 0.
 */
function ModelSizeBarChart({
  modelSize,
  largestModelSize,
  modelPrecision,
  deviceMemorySet,
  activationMemorySize = 0, // default to 0 for standard calculator
}: ModelSizeBarChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const width = 600;
  const height = 50;

  useEffect(() => {
    const svg = d3.select(chartRef.current);
    // Clear before the guard so a previous bar does not linger once the props
    // stop describing a drawable chart (e.g. a size drops to 0).
    svg.selectAll('*').remove();
    if (!(modelSize > 0 && largestModelSize > 0)) {
      return;
    }

    svg.attr('width', width).attr('height', height);
    const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width]);
    const foreground = theme === 'dark' ? '#f9fafb' : '#181f26';
    const usedMemory = modelSize + activationMemorySize;

    if (usedMemory > largestModelSize) {
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', width)
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', foreground)
        .style('stroke-dasharray', '4, 4')
        .style('stroke-width', '2px');
      svg
        .append('text')
        .attr('x', width / 2)
        .attr('y', height / 2)
        .attr('text-anchor', 'middle')
        // `dominant-baseline` is what vertically centres an SVG <text>;
        // Firefox ignores `alignment-baseline` on <text> elements.
        .attr('dominant-baseline', 'middle')
        .attr('fill', foreground)
        .text('Out of Memory');
      return;
    }

    svg
      .append('rect')
      .attr('x', 0)
      .attr('y', 0)
      .attr('width', xScale(modelSize))
      .attr('height', height)
      .attr('fill', chooseColor(modelPrecision));
    if (activationMemorySize > 0) {
      svg
        .append('rect')
        .attr('x', xScale(modelSize))
        .attr('y', 0)
        .attr('width', xScale(activationMemorySize))
        .attr('height', height)
        .attr('fill', '#a4b8e0');
    }
    if (deviceMemorySet) {
      // Outline of the capacity still free after model + activations.
      svg
        .append('rect')
        .attr('x', xScale(usedMemory))
        .attr('y', 0)
        .attr('width', xScale(largestModelSize - usedMemory))
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', chooseColor(modelPrecision))
        .style('stroke-width', '2px');
    }
  }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme]);

  return <svg ref={chartRef}></svg>;
}
// Line chart for inference runtime (shared by both standard and prefill chunking calculators)
/**
 * Plots, per precision, the largest batch size that fits in the remaining
 * memory as a function of sequence length:
 *   batchSize = availableMemory[precision] / (seqLength × memoryPerInput)
 * Only points with 1 < batchSize ≤ 128 and 2 ≤ seqLength < 4096 are drawn.
 *
 * The plot zooms/pans with the wheel and drag; hovering a curve fills the
 * `#tooltip` element rendered next to the chart with the value under the pointer.
 */
function InferenceRuntimeLineChart({
  availableMemory,
  memoryPerInput,
}: InferenceRuntimeLineChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const maxSeqLength = 4096;
  const maxBatchSize = 128;

  useEffect(() => {
    const svgNode = chartRef.current;
    if (!svgNode) {
      return;
    }
    const svg = d3.select(svgNode);
    // Clear before the guard so an old chart does not linger once the inputs
    // stop describing anything drawable.
    svg.selectAll('*').remove();
    const hasFreeMemory = Object.values(availableMemory).some((val) => val > 0);
    if (!(memoryPerInput > 0 && hasFreeMemory)) {
      return;
    }

    const margin = { top: 20, right: 20, bottom: 50, left: 50 };
    const width = 600 - margin.left - margin.right;
    const height = 400 - margin.top - margin.bottom;
    const foreground = theme === 'dark' ? '#f9fafb' : '#181f26';
    const plotTranslate = `translate(${margin.left}, ${margin.top})`;
    const clipId = 'inference-runtime-plot-clip';
    const precisions = [
      { name: 'FP32', color: '#e45f5b' },
      { name: 'FP16', color: '#ffc068' },
      { name: 'INT8', color: '#71cce9' },
      { name: 'INT4', color: '#383d95' },
    ];

    const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width]);
    const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0]);

    svg
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom);

    // Curves are clipped to the plot area so zooming cannot draw over the axes.
    svg
      .append('defs')
      .append('clipPath')
      .attr('id', clipId)
      .append('rect')
      .attr('width', width)
      .attr('height', height);

    const xAxisGroup = svg
      .append('g')
      .attr('class', 'x-axis')
      .attr('transform', `translate(${margin.left}, ${height + margin.top})`)
      .call(d3.axisBottom(xScale));
    const yAxisGroup = svg
      .append('g')
      .attr('class', 'y-axis')
      .attr('transform', plotTranslate)
      .call(d3.axisLeft(yScale));

    svg
      .append('text')
      .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`)
      .style('text-anchor', 'middle')
      .attr('fill', foreground)
      .text('Sequence Length');
    svg
      .append('text')
      .attr('transform', `rotate(-90)`)
      .attr('y', 0)
      .attr('x', 0 - height / 2 - margin.top)
      .attr('dy', '1em')
      .style('text-anchor', 'middle')
      .attr('fill', foreground)
      .text('Batch Size');

    // Legend for precisions
    const legend = svg
      .append('g')
      .attr('class', 'legend')
      .attr('transform', `translate(${width - 20}, 20)`);
    precisions.forEach((precision, index) => {
      const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`);
      legendItem
        .append('rect')
        .attr('x', 10)
        .attr('y', 10)
        .attr('width', 10)
        .attr('height', 10)
        .style('fill', precision.color);
      legendItem
        .append('text')
        .attr('x', 30)
        .attr('y', 16)
        .text(precision.name)
        .style('font-size', '16px')
        .style('fill', foreground)
        .attr('alignment-baseline', 'middle');
    });
    legend
      .append('rect')
      .attr('class', 'legend-box')
      .attr('width', 80)
      .attr('height', precisions.length * 30)
      .style('fill', 'none')
      .style('stroke-width', '1px')
      .style('stroke', foreground);

    const tooltip = d3.select('#tooltip');
    // One generator serves every precision; it only depends on the scales.
    const lineOf = d3
      .line<LineChartData>()
      .x((d) => xScale(d.seqLength))
      .y((d) => yScale(d.batchSize))
      .curve(d3.curveBasis);

    for (const [precision, memory] of Object.entries(availableMemory)) {
      const points: LineChartData[] = d3
        .range(2, maxSeqLength)
        .map((seqLength) => ({
          seqLength,
          batchSize: memory / (seqLength * memoryPerInput),
        }))
        .filter((d) => d.batchSize <= maxBatchSize && d.batchSize > 1);

      svg
        .append('g')
        .attr('transform', plotTranslate)
        .attr('clip-path', `url(#${clipId})`)
        .append('path')
        .datum(points)
        .attr('class', 'line-path')
        .attr('fill', 'none')
        .attr('stroke', chooseColor(precision as Precision))
        .attr('stroke-width', 4)
        .attr('d', lineOf)
        .on('mouseover', () => {
          tooltip
            .style('opacity', 1)
            .style('background-color', theme === 'dark' ? '#181f26' : '#f9fafb');
        })
        .on('mousemove', (event) => {
          // Pointer is relative to the path's own (pre-zoom) coordinates, so
          // the un-zoomed scales invert it to data values.
          const [x, y] = d3.pointer(event);
          tooltip
            .html(
              `Sequence Length: ${xScale.invert(x).toFixed(0)}<br/>Batch Size: ${yScale
                .invert(y)
                .toFixed(0)}`
            )
            // The tooltip is an HTML element: text color is `color`, not `fill`.
            .style('color', foreground)
            .style('left', `${event.pageX + 10}px`)
            .style('top', `${event.pageY + 10}px`);
        })
        .on('mouseout', () => {
          tooltip.style('opacity', 0);
        });
    }

    const zoom = d3
      .zoom<SVGSVGElement, unknown>()
      .scaleExtent([0.5, 10])
      .translateExtent([
        [-width, -height],
        [2 * width, 2 * height],
      ])
      .on('zoom', (event: d3.D3ZoomEvent<SVGSVGElement, unknown>) => {
        const { transform } = event;
        xAxisGroup.call(d3.axisBottom(transform.rescaleX(xScale)));
        yAxisGroup.call(d3.axisLeft(transform.rescaleY(yScale)));
        // Only the data curves move; the axes' domain lines are <path>s too.
        svg.selectAll('.line-path').attr('transform', transform.toString());
      });
    // Bind zoom to the <svg> itself (an empty <g> has no geometry and never
    // receives wheel/pointer events), then reset any transform left on the
    // node by a previous render so axes and curves start in sync.
    svg.call(zoom).call(zoom.transform, d3.zoomIdentity);
  }, [availableMemory, memoryPerInput, theme]);

  return (
    <>
      <div id='tooltip'></div>
      <svg ref={chartRef} width={600} height={400} />
    </>
  );
}
// Prefill Chunking Calculator: footprint includes per-input and activation terms
/**
 * Calculator variant whose per-precision footprint comes from
 * `calculateTotalMemory` (per-input term + weights + activation term).
 *
 * Renders nothing until every input is a non-zero number.
 *
 * @param deviceMemory - device RAM in GB
 * @param modelParams - parameter count in billions
 * @param hiddenSize - model hidden size
 * @param numLayers - number of layers
 * @param intermediateSize - MLP intermediate size
 */
function PrefillChunkingCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
  intermediateSize,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
  intermediateSize: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers || !intermediateSize) {
    return null;
  }

  const precisions: Precision[] = ['fp32', 'fp16', 'int8', 'int4'];
  const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000; // GB
  // Same activation term that calculateTotalMemory folds into every total.
  const activationMemory = Math.max(2 * intermediateSize, 4 * hiddenSize) / 1_000_000_000; // GB
  const totalFor = (precision: Precision) =>
    calculateTotalMemory(modelParams, hiddenSize, numLayers, intermediateSize, precision);
  // Computed once per precision and shared by the bars and the line chart.
  const totalMemory: Record<Precision, number> = {
    fp32: totalFor('fp32'),
    fp16: totalFor('fp16'),
    int8: totalFor('int8'),
    int4: totalFor('int4'),
  };

  return (
    <>
      {/* Model Footprint with Prefill Chunking */}
      <div className='chart'>
        <div className='text-2xl text-center mb-4'>Model Footprint with Prefill Chunking</div>
        <div className='space-y-8'>
          {precisions.map((precision) => (
            <div key={precision} className='chart-row'>
              <div className='chart-row-title'>{precision.toUpperCase()}</div>
              {/* The total already contains the activation term, so the first
                  segment gets total − activation; otherwise activations would be
                  drawn (and checked against device memory) twice. */}
              <ModelSizeBarChart
                modelSize={totalMemory[precision] - activationMemory}
                largestModelSize={deviceMemory}
                modelPrecision={precision}
                deviceMemorySet={deviceMemory > 0}
                activationMemorySize={activationMemory}
              />
              <div className='chart-row-size ml-8'>
                {totalMemory[precision].toFixed(2)} / {deviceMemory} GB
              </div>
            </div>
          ))}
        </div>
      </div>
      {/* Inference Runtime with Prefill Chunking */}
      <div className='chart'>
        <div className='text-2xl text-center mb-4'>
          Maximum Batch Size / Sequence Length with Prefill Chunking
        </div>
        <InferenceRuntimeLineChart
          availableMemory={{
            int4: deviceMemory - totalMemory.int4,
            int8: deviceMemory - totalMemory.int8,
            fp16: deviceMemory - totalMemory.fp16,
            fp32: deviceMemory - totalMemory.fp32,
          }}
          memoryPerInput={memoryPerInput}
        />
      </div>
    </>
  );
}
// Standard Model Memory Calculator: weights-only footprint plus batch/sequence curves
/**
 * Calculator variant where the model footprint is just the weights:
 * modelParams × bytes-per-parameter for each precision.
 *
 * Renders nothing until every input is a non-zero number.
 *
 * @param deviceMemory - device RAM in GB
 * @param modelParams - parameter count in billions
 * @param hiddenSize - model hidden size
 * @param numLayers - number of layers
 */
function StandardCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers) {
    return null;
  }

  const precisions: Precision[] = ['fp32', 'fp16', 'int8', 'int4'];
  const bytesPerParam: Record<Precision, number> = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 };
  /** Weight memory in GB (params are in billions, so params × bytes = GB). */
  const calculateMemory = (params: number, precision: Precision) =>
    params * bytesPerParam[precision];
  // 4 × hiddenSize × numLayers bytes per input token, expressed in GB.
  const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000;

  return (
    <>
      {/* Model Footprint */}
      <div className='chart mb-8'>
        <div className='text-2xl text-center mb-4'>Model Footprint</div>
        <div className='space-y-8'>
          {precisions.map((precision) => {
            const weights = calculateMemory(modelParams, precision);
            return (
              <div key={precision} className='chart-row'>
                <div className='chart-row-title'>{precision.toUpperCase()}</div>
                <ModelSizeBarChart
                  modelSize={weights}
                  largestModelSize={deviceMemory}
                  modelPrecision={precision}
                  deviceMemorySet={deviceMemory > 0}
                />
                <div className='chart-row-size ml-8'>
                  {weights.toFixed(2)} / {deviceMemory} GB
                </div>
              </div>
            );
          })}
        </div>
      </div>
      {/* Maximum Batch Size / Sequence Length */}
      <div className='chart'>
        <div className='text-2xl text-center mb-4'>
          Maximum Batch Size / Sequence Length
        </div>
        <InferenceRuntimeLineChart
          availableMemory={{
            int4: deviceMemory - calculateMemory(modelParams, 'int4'),
            int8: deviceMemory - calculateMemory(modelParams, 'int8'),
            fp16: deviceMemory - calculateMemory(modelParams, 'fp16'),
            fp32: deviceMemory - calculateMemory(modelParams, 'fp32'),
          }}
          memoryPerInput={memoryPerInput}
        />
      </div>
    </>
  );
}
// Main Calculator Page
/**
 * Model Memory Calculator page.
 *
 * Collects the model architecture (preset from MODELS or custom inputs) and
 * the device memory (preset from DEVICES or a custom value), then renders
 * either the standard or the prefill-chunking calculator for them.
 */
const Calculator = () => {
  const [modelParams, setModelParams] = useState<number | null>(null);
  const [hiddenSize, setHiddenSize] = useState<number | null>(null);
  const [numLayers, setNumLayers] = useState<number | null>(null);
  const [intermediateSize, setIntermediateSize] = useState<number | null>(null);
  const [deviceMemory, setDeviceMemory] = useState<number | null>(null);
  const [isPrefillChunking, setIsPrefillChunking] = useState<boolean>(false);
  const [modelSelectionTab, setModelSelectionTab] = useState<boolean>(true);
  const [deviceSelectionTab, setDeviceSelectionTab] = useState<boolean>(true);

  const tabClass = (active: boolean) =>
    active ? 'calculator-input-tab-active' : 'calculator-input-tab';

  // Preset options are keyed by model name (already the React key, so unique).
  // Keying by parameter count made `find` return the first model of that size,
  // loading the wrong architecture for models that share a parameter count.
  const selectModel = (name: string) => {
    const selectedModel = MODELS.find((model) => model.name === name);
    // "None selected" (or an unknown name) clears the architecture fields.
    setModelParams(selectedModel?.params ?? null);
    setHiddenSize(selectedModel?.hidden_size ?? null);
    setNumLayers(selectedModel?.num_hidden_layers ?? null);
    setIntermediateSize(selectedModel?.intermediate_size ?? null);
  };

  // Switching device tabs discards the previous device choice.
  const selectDeviceTab = (useSelection: boolean) => {
    setDeviceSelectionTab(useSelection);
    setDeviceMemory(null);
  };

  return (
    <div className='flex flex-col items-center justify-center min-h-screen px-4'>
      {/* Toggle Between Standard and Prefill Chunking */}
      <div className='mb-4 flex space-x-4'>
        <button className={tabClass(!isPrefillChunking)} onClick={() => setIsPrefillChunking(false)}>
          Standard Calculator
        </button>
        <button className={tabClass(isPrefillChunking)} onClick={() => setIsPrefillChunking(true)}>
          Calculator with Prefill Chunking
        </button>
      </div>
      {/* Model and Device Selection */}
      <div className='w-full max-w-4xl'>
        <div className='text-4xl mb-4 text-center'>Model Memory Calculator</div>
        <div className='mb-6 text-center'>
          Use our Model Memory Calculator to help you estimate the memory footprint of your model
          and the maximum batch size/sequence length combination you can run on your device.
        </div>
        <div className='grid grid-cols-1 sm:grid-cols-2 gap-4 mb-6'>
          {/* Model Selection */}
          <div className='calculator-input-box'>
            <div className='text-2xl calculator-input-title'>Model</div>
            <div className='calculator-input-content'>
              <div className='mb-2'>
                <button className={tabClass(modelSelectionTab)} onClick={() => setModelSelectionTab(true)}>
                  Model Selection
                </button>
                <button className={tabClass(!modelSelectionTab)} onClick={() => setModelSelectionTab(false)}>
                  Custom Model
                </button>
              </div>
              <div>
                {modelSelectionTab ? (
                  <>
                    <label htmlFor='model'>Select a Model</label>
                    <select
                      id='model'
                      className='calculator-select'
                      onChange={(e) => selectModel(e.target.value)}
                    >
                      <option value=''>None selected</option>
                      {MODELS.map((model) => (
                        <option key={model.name} value={model.name}>
                          {model.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor='modelParams'>Model Parameters (in billions)</label>
                    <input
                      type='number'
                      id='modelParams'
                      className='calculator-input mb-2'
                      placeholder='e.g. 7 (for LLaMA-7B)'
                      value={modelParams || ''}
                      min={0}
                      onChange={(e) => setModelParams(Number(e.target.value))}
                    />
                    <label htmlFor='hiddenSize'>Hidden Size</label>
                    <input
                      type='number'
                      id='hiddenSize'
                      className='calculator-input mb-2'
                      placeholder='e.g. 4096 (for LLaMA-7B)'
                      value={hiddenSize || ''}
                      min={1}
                      onChange={(e) => setHiddenSize(Number(e.target.value))}
                    />
                    <label htmlFor='numLayers'>Number of Layers</label>
                    <input
                      type='number'
                      id='numLayers'
                      className='calculator-input'
                      placeholder='e.g. 32 (for LLaMA-7B)'
                      value={numLayers || ''}
                      min={1}
                      onChange={(e) => setNumLayers(Number(e.target.value))}
                    />
                    {isPrefillChunking && (
                      <>
                        <label htmlFor='intermediateSize'>Intermediate Size</label>
                        <input
                          type='number'
                          id='intermediateSize'
                          className='calculator-input'
                          placeholder='e.g. 11008 (for LLaMA-7B)'
                          value={intermediateSize || ''}
                          min={1}
                          onChange={(e) => setIntermediateSize(Number(e.target.value))}
                        />
                      </>
                    )}
                  </>
                )}
              </div>
            </div>
          </div>
          {/* Device Selection */}
          <div className='calculator-input-box'>
            <div className='text-2xl calculator-input-title'>Device</div>
            <div className='calculator-input-content'>
              <div className='mb-2'>
                <button className={tabClass(deviceSelectionTab)} onClick={() => selectDeviceTab(true)}>
                  Device Selection
                </button>
                <button className={tabClass(!deviceSelectionTab)} onClick={() => selectDeviceTab(false)}>
                  Custom Device
                </button>
              </div>
              <div>
                {deviceSelectionTab ? (
                  <>
                    <label htmlFor='device'>Select a Device</label>
                    <select
                      id='device'
                      className='calculator-select'
                      onChange={(e) => setDeviceMemory(Number(e.target.value))}
                    >
                      <option value=''>None selected</option>
                      {DEVICES.map((device) => (
                        <option key={device.name} value={device.size}>
                          {device.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor='deviceMemory'>Device RAM (in GB)</label>
                    <input
                      type='number'
                      id='deviceMemory'
                      className='calculator-input'
                      placeholder='e.g. 24'
                      value={deviceMemory || ''}
                      min={0}
                      onChange={(e) => setDeviceMemory(Number(e.target.value))}
                    />
                  </>
                )}
              </div>
            </div>
          </div>
        </div>
        {/* Render Appropriate Calculator Based on Toggle.
            Missing values are passed as 0, which both calculators treat as
            "not set yet" and render nothing for (instead of smuggling null
            through number-typed props with `!`). */}
        {isPrefillChunking ? (
          <PrefillChunkingCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
            intermediateSize={intermediateSize ?? 0}
          />
        ) : (
          <StandardCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
          />
        )}
      </div>
    </div>
  );
};
export default Calculator;