;
+ `,_=E=>`
+ ${E.registerUniform("reduceSize","u32").declareVariables(p,u)}
+ ${g}
+ fn DIV_CEIL(a : u32, b : u32) -> u32 {
+ return ((a - 1u) / b + 1u);
+ }
+ ${E.mainStart(h)}
+
+ let outputIndex = global_idx / ${h};
+ let offset = outputIndex * uniforms.reduceSize;
+
+ var bestValue = f32(${Hf[s]});
+ let Length = uniforms.reduceSize;
+ for (var k = local_idx; k < Length; k = k + ${h}) {
+ let candidate = f32(${p.getByOffset("offset + k")});
+ bestValue = ${Gf[s]};
+ }
+ aBestValues[local_idx] = bestValue;
+ workgroupBarrier();
+
+ var reduceSize = min(Length, ${h}u);
+ for (var currentSize = reduceSize / 2u; reduceSize > 1u;
+ currentSize = reduceSize / 2u) {
+ let interval = DIV_CEIL(reduceSize, 2u);
+ if (local_idx < currentSize) {
+ let candidate = aBestValues[local_idx + interval];
+ bestValue = ${Kf[s]};
+ aBestValues[local_idx] = bestValue;
+ }
+ reduceSize = interval;
+ workgroupBarrier();
+ }
+
+ if (local_idx == 0u) {
+ ${u.setByOffset("outputIndex",`${s==="mean"?`${u.type.storage}(bestValue / f32(uniforms.reduceSize))`:`${u.type.storage}(${qf[s]})`}`)};
+ }
+ }`;return{name:e,shaderCache:{hint:`${r};${h}`,inputDependencies:["type"]},getShaderSource:_,getRunData:()=>({outputs:[{dims:n,dataType:o}],dispatchGroup:{x:l},programUniforms:[{type:12,data:c}]})}},ps=(e,r,t,s)=>{let o=e.inputs.length===1?t:qc(e.inputs,t),n=o.axes;n.length===0&&!o.noopWithEmptyAxes&&(n=e.inputs[0].dims.map((g,_)=>_));let i=xe.normalizeAxes(n,e.inputs[0].dims.length),a=i,l=e.inputs[0],c=Zf(a,e.inputs[0].dims.length);c.length>0&&(l=e.compute(Wr(e.inputs[0],c),{inputs:[0],outputs:[-1]})[0],a=Qf(a.length,l.dims.length));let[p,u]=Xf(l.dims,a),h=p;o.keepDims&&(h=Jf(p,i)),e.compute(e_(r,o.cacheKey,[l],s,e.inputs[0].dataType,h,u),{inputs:[l]})},zy=(e,r)=>{ps(e,"ReduceMeanShared",r,"mean")},By=(e,r)=>{ps(e,"ReduceL1Shared",r,"l1")},Ry=(e,r)=>{ps(e,"ReduceL2Shared",r,"l2")},Ny=(e,r)=>{ps(e,"ReduceLogSumExpShared",r,"logSumExp")},jy=(e,r)=>{ps(e,"ReduceMaxShared",r,"max")},Vy=(e,r)=>{ps(e,"ReduceMinShared",r,"min")},Uy=(e,r)=>{ps(e,"ReduceProdShared",r,"prod")},Wy=(e,r)=>{ps(e,"ReduceSumShared",r,"sum")},Gy=(e,r)=>{ps(e,"ReduceSumSquareShared",r,"sumSquare")},Ky=(e,r)=>{ps(e,"ReduceLogSumShared",r,"logSum")}}),hs,t_,id,qc,ms,r_,s_,n_,o_,i_,a_,l_,d_,c_,u_,fs,Hy,qy,Qy,Xy,Jy,Yy,Zy,eM,tM,rM,Mu=Ve(()=>{mt(),bt(),tr(),xt(),Qv(),hs=e=>{if(!e||e.length===0||e.length>2)throw new Error("Reduce op requires 1 or 2 inputs.");if(e.length===2&&e[1].dims.length!==1)throw new Error("Invalid axes input dims.")},t_=e=>["","",`var value = ${e.getByIndices("input_indices")};`,""],id=(e,r,t,s,o,n,i=!1,a=!1)=>{let l=[],c=t[0].dims,p=c.length,u=xe.normalizeAxes(o,p),h=!a&&u.length===0;c.forEach((E,I)=>{h||u.indexOf(I)>=0?i&&l.push(1):l.push(E)});let g=l.length,_=xe.size(l);return{name:e,shaderCache:r,getShaderSource:E=>{let I=[],M=$e("_A",t[0].dataType,p),y=tt("output",n,g),$=s(M,y,u),P=$[2];for(let b=0,w=0;b=0?(i&&w++,P=`for(var j${b}: u32 = 0; j${b} < ${c[b]}; j${b}++) {
+ ${$[2].includes("last_index")?`let last_index = j${b};`:""}
+ ${M.indicesSet("input_indices",b,`j${b}`)}
+ ${P}
+ }`):(I.push(`${M.indicesSet("input_indices",b,y.indicesGet("output_indices",w))};`),w++);return`
+
+ ${E.registerUniform("output_size","u32").declareVariables(M,y)}
+
+ ${E.mainStart()}
+ ${E.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ var input_indices: ${M.type.indices};
+ let output_indices = ${y.offsetToIndices("global_idx")};
+
+ ${I.join(`
+`)}
+ ${$[0]} // init ops for reduce max/min
+ ${$[1]}
+ ${P}
+ ${$[3]}
+ ${$.length===4?y.setByOffset("global_idx","value"):$.slice(4).join(`
+`)}
+ }`},getRunData:()=>({outputs:[{dims:l,dataType:n}],dispatchGroup:{x:Math.ceil(_/64)},programUniforms:[{type:12,data:_},...nt(c,l)]})}},qc=(e,r)=>{let t=[];return e[1].dims[0]>0&&e[1].getBigInt64Array().forEach(s=>t.push(Number(s))),Lt({axes:t,keepDims:r.keepDims,noopWithEmptyAxes:r.noopWithEmptyAxes})},ms=(e,r,t,s)=>{let o=e.inputs,n=o.length===1?t:qc(o,t);e.compute(id(r,{hint:n.cacheKey,inputDependencies:["rank"]},[o[0]],n.noopWithEmptyAxes&&n.axes.length===0?t_:s,n.axes,o[0].dataType,n.keepDims,n.noopWithEmptyAxes),{inputs:[0]})},r_=(e,r)=>{hs(e.inputs),ms(e,"ReduceLogSum",r,(t,s)=>[`var value = ${s.type.storage}(0);`,"",`value += ${t.getByIndices("input_indices")};`,"value = log(value);"])},s_=(e,r)=>{hs(e.inputs),ms(e,"ReduceL1",r,(t,s)=>[`var value = ${s.type.storage}(0);`,"",`value += abs(${t.getByIndices("input_indices")});`,""])},n_=(e,r)=>{hs(e.inputs),ms(e,"ReduceL2",r,(t,s)=>[`var t = ${s.type.value}(0); var value = ${s.type.value}(0);`,"",`t = ${t.getByIndices("input_indices")}; value += (t * t);`,"value = sqrt(value);"])},o_=(e,r)=>{hs(e.inputs),ms(e,"ReduceLogSumExp",r,(t,s)=>[`var value = ${s.type.storage}(0);`,"",`value += exp(${t.getByIndices("input_indices")});`,"value = log(value);"])},i_=(e,r)=>{hs(e.inputs),ms(e,"ReduceMax",r,(t,s,o)=>{let n=[];for(let i=0;i=0||o.length===0)&&n.push(t.indicesSet("input_indices",i,0));return[`${n.join(`
+`)}`,`var value = ${t.getByIndices("input_indices")};`,`value = max(value, ${t.getByIndices("input_indices")});`,""]})},a_=(e,r)=>{hs(e.inputs),ms(e,"ReduceMean",r,(t,s,o)=>{let n=1;for(let i=0;i=0||o.length===0)&&(n*=e.inputs[0].dims[i]);return["var sum = f32(0);","",`sum += f32(${t.getByIndices("input_indices")});`,`let value = ${s.type.value}(sum / ${n});`]})},l_=(e,r)=>{hs(e.inputs),ms(e,"ReduceMin",r,(t,s,o)=>{let n=[];for(let i=0;i=0||o.length===0)&&n.push(`input_indices[${i}] = 0;`);return[`${n.join(`
+`)}`,`var value = ${t.getByIndices("input_indices")};`,`value = min(value, ${t.getByIndices("input_indices")});`,""]})},d_=(e,r)=>{hs(e.inputs),ms(e,"ReduceProd",r,(t,s)=>[`var value = ${s.type.storage}(1);`,"",`value *= ${t.getByIndices("input_indices")};`,""])},c_=(e,r)=>{hs(e.inputs),ms(e,"ReduceSum",r,(t,s)=>[`var value = ${s.type.storage}(0);`,"",`value += ${t.getByIndices("input_indices")};`,""])},u_=(e,r)=>{hs(e.inputs),ms(e,"ReduceSumSquare",r,(t,s)=>[`var t = ${s.type.value}(0); var value = ${s.type.value}(0);`,"",`t = ${t.getByIndices("input_indices")}; value += t * t;`,""])},fs=(e,r,t)=>{if(r.length===0)return t;let s=1,o=1;for(let n=0;n1024},Hy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?a_(e,r):zy(e,r)},qy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?s_(e,r):By(e,r)},Qy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?n_(e,r):Ry(e,r)},Xy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?o_(e,r):Ny(e,r)},Jy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?i_(e,r):jy(e,r)},Yy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?l_(e,r):Vy(e,r)},Zy=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?d_(e,r):Uy(e,r)},eM=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?c_(e,r):Wy(e,r)},tM=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?u_(e,r):Gy(e,r)},rM=(e,r)=>{fs(e.inputs[0].dims,r.axes,r.noopWithEmptyAxes)?r_(e,r):Ky(e,r)}}),dc,sM,nM,Qc,Xv=Ve(()=>{mt(),tr(),Mu(),dc=e=>{if(!e||e.length===0||e.length>2)throw new Error("ArgMinMaxOp op requires 1 or 2 inputs.");if(e[0].dataType!==1)throw new Error("Invalid input type.")},sM=(e,r)=>{dc(e.inputs);let t=(s,o,n)=>{let i=[];for(let a=0;a=0||n.length===0)&&i.push(`input_indices[${a}] = 0;`);return[`${i.join(`
+`)}`,`var value = ${s.getByIndices("input_indices")};
+var best_index : i32 = 0;`,`if (${s.getByIndices("input_indices")} ${r.selectLastIndex>0?"<=":"<"} value) {
+ value = ${s.getByIndices("input_indices")};
+ best_index = i32(last_index);
+ }`,"",o.setByOffset("global_idx","best_index")]};e.compute(id("ArgMin",{hint:r.cacheKey,inputDependencies:["rank"]},[e.inputs[0]],t,[r.axis],7,r.keepDims),{inputs:[0]})},nM=(e,r)=>{dc(e.inputs);let t=(s,o,n)=>{let i=[];for(let a=0;a=0||n.length===0)&&i.push(`input_indices[${a}] = 0;`);return[`${i.join(`
+`)}`,`var value = ${s.getByIndices("input_indices")};
+var best_index : i32 = 0;`,`if (${s.getByIndices("input_indices")} ${r.selectLastIndex>0?">=":">"} value) {
+ value = ${s.getByIndices("input_indices")};
+ best_index = i32(last_index);
+ }`,"",o.setByOffset("global_idx","best_index")]};e.compute(id("argMax",{hint:r.cacheKey,inputDependencies:["rank"]},[e.inputs[0]],t,[r.axis],7,r.keepDims),{inputs:[0]})},Qc=e=>Lt(e)}),p_,Wl,h_,m_,f_,ha,__,oM,bu=Ve(()=>{mt(),bt(),wu(),xt(),p_=(e,r)=>{let t=e[0],s=e[1],o=e[2],n=e[3],i=e[4],a=e[5];if(i&&a)throw new Error("Attention cannot have both past and attention_bias");if(t.dims.length!==3)throw new Error('Input "input" must have 3 dimensions');let l=t.dims[0],c=t.dims[1],p=t.dims[2];if(o.dims.length!==1)throw new Error('Input "bias" is expected to have 1 dimensions');if(s.dims.length!==2)throw new Error('Input "weights" is expected to have 2 dimensions');if(s.dims[0]!==p)throw new Error("Input 1 dimension 0 should have same length as dimension 2 of input 0");if(o.dims[0]!==s.dims[1])throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"');let u=o.dims[0]/3,h=u,g=h;if(r.qkvHiddenSizes.length>0){if(r.qkvHiddenSizes.length!==3)throw new Error("qkv_hidden_sizes attribute should have 3 elements");for(let $ of r.qkvHiddenSizes)if($%r.numHeads!==0)throw new Error("qkv_hidden_sizes should be divisible by num_heads");u=r.qkvHiddenSizes[0],h=r.qkvHiddenSizes[1],g=r.qkvHiddenSizes[2]}let _=c;if(u!==h)throw new Error("qkv_hidden_sizes first element should be same as the second");if(o.dims[0]!==u+h+g)throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes');let E=0;if(i){if(h!==g)throw new Error('Input "past" expect k_hidden_size == v_hidden_size');if(i.dims.length!==5)throw new Error('Input "past" must have 5 dimensions');if(i.dims[0]!==2)throw new Error('Input "past" first dimension must be 2');if(i.dims[1]!==l)throw new Error('Input "past" second dimension must be batch_size');if(i.dims[2]!==r.numHeads)throw new Error('Input "past" third dimension must be num_heads');if(i.dims[4]!==h/r.numHeads)throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads');r.pastPresentShareBuffer||(E=i.dims[3])}let I=_+E,M=-1,y=0;if(n)throw new Error("Mask not supported");if(i)throw new Error("past is not supported");if(a){if(a.dims.length!==4)throw new Error('Input "attention_bias" must have 4 dimensions');if(a.dims[0]!==l||a.dims[1]!==r.numHeads||a.dims[2]!==c||a.dims[3]!==I)throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)')}return{batchSize:l,sequenceLength:c,pastSequenceLength:E,kvSequenceLength:_,totalSequenceLength:I,maxSequenceLength:M,inputHiddenSize:p,hiddenSize:u,vHiddenSize:g,headSize:Math.floor(u/r.numHeads),vHeadSize:Math.floor(g/r.numHeads),numHeads:r.numHeads,isUnidirectional:!1,pastPresentShareBuffer:!1,maskFilterValue:r.maskFilterValue,maskType:y,scale:r.scale,broadcastResPosBias:!1,passPastInKv:!1,qkvFormat:1}},Wl=(e,r,t)=>r&&e?`
+ let total_sequence_length_input = u32(${r.getByOffset("0")});
+ let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);
+ let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;
+ let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;
+ total_sequence_length = u32(${e==null?void 0:e.getByOffset("batchIdx")}) + 1;
+ var past_sequence_length: u32 = 0;
+ if (is_first_prompt == false) {
+ past_sequence_length = total_sequence_length - sequence_length;
+ }
+ `:`
+ ${t?"let past_sequence_length = uniforms.past_sequence_length":""};
+ let present_sequence_length = total_sequence_length;
+ `,h_=(e,r,t,s,o,n,i,a)=>{let l=Jt(i?1:n),c=64,p=n/l;p{let y=tt("x",e.dataType,e.dims,l),$=[y],P=i?$e("seq_lens",i.dataType,i.dims):void 0;P&&$.push(P);let b=a?$e("total_sequence_length_input",a.dataType,a.dims):void 0;b&&$.push(b);let w=Cr(e.dataType),T=[{name:"batch_size",type:"u32"},{name:"num_heads",type:"u32"},{name:"past_sequence_length",type:"u32"},{name:"sequence_length",type:"u32"},{name:"total_sequence_length",type:"u32"},{name:"elements_per_thread",type:"u32"}];return`
+ var thread_max: array;
+ var thread_sum: array;
+ ${M.registerUniforms(T).declareVariables(...$)}
+ ${M.mainStart([c,1,1])}
+ let batchIdx = workgroup_id.z / uniforms.num_heads;
+ let headIdx = workgroup_id.z % uniforms.num_heads;
+ let sequence_length = uniforms.sequence_length;
+ var total_sequence_length = uniforms.total_sequence_length;
+ ${Wl(P,b,!1)}
+ let local_offset = local_idx * uniforms.elements_per_thread;
+ let offset = (global_idx / ${c}) * uniforms.total_sequence_length + local_offset;
+ let seq_causal_length = ${i?"u32(past_sequence_length + workgroup_id.y + 1)":"total_sequence_length"};
+ var thread_max_vector = ${_}(-3.402823e+38f);
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
+ thread_max_vector = max(${_}(x[offset + i]), thread_max_vector);
+ }
+ thread_max[local_idx] = ${(()=>{switch(l){case 1:return"thread_max_vector";case 2:return"max(thread_max_vector.x, thread_max_vector.y)";case 4:return"max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))";default:throw new Error(`Unsupported components: ${l}`)}})()};
+ workgroupBarrier();
+
+ var max_value = f32(-3.402823e+38f);
+ for (var i = 0u; i < ${c}; i++) {
+ max_value = max(thread_max[i], max_value);
+ }
+
+ var sum_vector = ${_}(0);
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
+ sum_vector += exp(${_}(x[offset + i]) - max_value);
+ }
+ thread_sum[local_idx] = ${(()=>{switch(l){case 1:return"sum_vector";case 2:return"sum_vector.x + sum_vector.y";case 4:return"sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w";default:throw new Error(`Unsupported components: ${l}`)}})()};
+ workgroupBarrier();
+
+ var sum: f32 = 0;
+ for (var i = 0u; i < ${c}; i++) {
+ sum += thread_sum[i];
+ }
+
+ if (sum == 0) {
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
+ x[offset + i] = ${y.type.value}(${w}(1.0) / ${w}(seq_causal_length));
+ }
+ } else {
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
+ var f32input = ${_}(x[offset + i]);
+ x[offset + i] = ${y.type.value}(exp(f32input - max_value) / sum);
+ }
+ }
+ ${i?`
+ for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {
+ x[offset + total_seq_id] = ${y.type.value}(${w}(0));
+ }`:""};
+ }`};return{name:"AttentionProbsSoftmax",shaderCache:{hint:`${c};${g};${l}`,inputDependencies:E},getShaderSource:I,getRunData:()=>({outputs:[],dispatchGroup:{x:1,y:o,z:r*t},programUniforms:h})}},m_=(e,r,t,s,o,n,i,a,l)=>{let c=i+n.kvSequenceLength,p=[n.batchSize,n.numHeads,n.sequenceLength,c],u=e>1&&s,h=n.kvNumHeads?n.kvNumHeads:n.numHeads,g=u?[n.batchSize,h,c,n.headSize]:void 0,_=n.nReps?n.nReps:1,E=n.scale===0?1/Math.sqrt(n.headSize):n.scale,I=Jt(n.headSize),M=n.headSize/I,y=12,$={x:Math.ceil(c/y),y:Math.ceil(n.sequenceLength/y),z:n.batchSize*n.numHeads},P=[{type:12,data:n.sequenceLength},{type:12,data:M},{type:12,data:c},{type:12,data:n.numHeads},{type:12,data:n.headSize},{type:1,data:E},{type:12,data:i},{type:12,data:n.kvSequenceLength},{type:12,data:_}],b=u&&s&&xe.size(s.dims)>0,w=["type","type"];b&&w.push("type"),o&&w.push("type"),a&&w.push("type"),l&&w.push("type");let T=[{dims:p,dataType:r.dataType,gpuDataType:0}];u&&T.push({dims:g,dataType:r.dataType,gpuDataType:0});let k=z=>{let R=$e("q",r.dataType,r.dims,I),Q=$e("key",t.dataType,t.dims,I),q=[R,Q];if(b){let ce=$e("past_key",s.dataType,s.dims,I);q.push(ce)}o&&q.push($e("attention_bias",o.dataType,o.dims));let U=a?$e("seq_lens",a.dataType,a.dims):void 0;U&&q.push(U);let Z=l?$e("total_sequence_length_input",l.dataType,l.dims):void 0;Z&&q.push(Z);let H=tt("output",r.dataType,p),J=[H];u&&J.push(tt("present_key",r.dataType,g,I));let oe=Cr(1,I),ae=[{name:"M",type:"u32"},{name:"K",type:"u32"},{name:"N",type:"u32"},{name:"num_heads",type:"u32"},{name:"head_size",type:"u32"},{name:"alpha",type:"f32"},{name:"past_sequence_length",type:"u32"},{name:"kv_sequence_length",type:"u32"},{name:"n_reps",type:"u32"}];return`
+ const TILE_SIZE = ${y}u;
+
+ var tileQ: array<${R.type.storage}, ${y*y}>;
+ var tileK: array<${R.type.storage}, ${y*y}>;
+ ${z.registerUniforms(ae).declareVariables(...q,...J)}
+ ${z.mainStart([y,y,1])}
+ // x holds the N and y holds the M
+ let headIdx = workgroup_id.z % uniforms.num_heads;
+ let kvHeadIdx = ${_===1?"headIdx":"headIdx / uniforms.n_reps"};
+ let kv_num_heads = ${_===1?"uniforms.num_heads":"uniforms.num_heads / uniforms.n_reps"};
+ let batchIdx = workgroup_id.z / uniforms.num_heads;
+ let m = workgroup_id.y * TILE_SIZE;
+ let n = workgroup_id.x * TILE_SIZE;
+ let sequence_length = uniforms.M;
+ var total_sequence_length = uniforms.N;
+ ${Wl(U,Z,!0)}
+ let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
+ let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
+ ${b&&u?"let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;":""};
+ let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
+ ${u?"let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;":""}
+ var value = ${oe}(0);
+ for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
+ if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
+ tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
+ }
+ if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
+ var idx = TILE_SIZE * local_id.y + local_id.x;
+ ${b&&u?`
+ if (n + local_id.y < past_sequence_length) {
+ tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
+ } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
+ tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
+ }`:`
+ if (n + local_id.y < uniforms.kv_sequence_length) {
+ tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
+ }`}
+ ${u?`if (n + local_id.y < present_sequence_length) {
+ present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];
+ }`:""}
+ }
+ workgroupBarrier();
+
+ for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
+ value += ${oe}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
+ }
+
+ workgroupBarrier();
+ }
+
+ if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {
+ let headOffset = workgroup_id.z * uniforms.M * uniforms.N;
+ let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
+ var sum: f32 = ${(()=>{switch(I){case 1:return"value";case 2:return"value.x + value.y";case 4:return"value.x + value.y + value.z + value.w";default:throw new Error(`Unsupported components: ${I}`)}})()};
+ output[outputIdx] = ${H.type.value} (sum * uniforms.alpha) + ${o?"attention_bias[outputIdx]":"0.0"};
+ }
+ }`};return{name:"AttentionProbs",shaderCache:{hint:`${I};${o!==void 0};${s!==void 0};${e}`,inputDependencies:w},getRunData:()=>({outputs:T,dispatchGroup:$,programUniforms:P}),getShaderSource:k}},f_=(e,r,t,s,o,n,i=void 0,a=void 0)=>{let l=n+o.kvSequenceLength,c=o.nReps?o.nReps:1,p=o.vHiddenSize*c,u=e>1&&s,h=o.kvNumHeads?o.kvNumHeads:o.numHeads,g=u?[o.batchSize,h,l,o.headSize]:void 0,_=[o.batchSize,o.sequenceLength,p],E=12,I={x:Math.ceil(o.vHeadSize/E),y:Math.ceil(o.sequenceLength/E),z:o.batchSize*o.numHeads},M=[{type:12,data:o.sequenceLength},{type:12,data:l},{type:12,data:o.vHeadSize},{type:12,data:o.numHeads},{type:12,data:o.headSize},{type:12,data:p},{type:12,data:n},{type:12,data:o.kvSequenceLength},{type:12,data:c}],y=u&&s&&xe.size(s.dims)>0,$=["type","type"];y&&$.push("type"),i&&$.push("type"),a&&$.push("type");let P=[{dims:_,dataType:r.dataType,gpuDataType:0}];u&&P.push({dims:g,dataType:r.dataType,gpuDataType:0});let b=w=>{let T=$e("probs",r.dataType,r.dims),k=$e("v",t.dataType,t.dims),z=[T,k];y&&z.push($e("past_value",s.dataType,s.dims));let R=i?$e("seq_lens",i.dataType,i.dims):void 0;i&&z.push(R);let Q=a?$e("total_sequence_length_input",a.dataType,a.dims):void 0;a&&z.push(Q);let q=[tt("output",r.dataType,_)];u&&q.push(tt("present_value",r.dataType,g));let U=[{name:"M",type:"u32"},{name:"K",type:"u32"},{name:"N",type:"u32"},{name:"num_heads",type:"u32"},{name:"head_size",type:"u32"},{name:"v_hidden_size",type:"u32"},{name:"past_sequence_length",type:"u32"},{name:"kv_sequence_length",type:"u32"},{name:"n_reps",type:"u32"}];return`
+ const TILE_SIZE = ${E}u;
+ var tileQ: array<${T.type.value}, ${E*E}>;
+ var tileV: array<${T.type.value}, ${E*E}>;
+ ${w.registerUniforms(U).declareVariables(...z,...q)}
+ ${w.mainStart([E,E,1])}
+ let headIdx = workgroup_id.z % uniforms.num_heads;
+ let batchIdx = workgroup_id.z / uniforms.num_heads;
+ let kvHeadIdx = ${c===1?"headIdx":"headIdx / uniforms.n_reps"};
+ let kv_num_heads = ${c===1?"uniforms.num_heads":"uniforms.num_heads / uniforms.n_reps"};
+ let m = global_id.y;
+ let n = global_id.x;
+ let sequence_length = uniforms.M;
+ var total_sequence_length = uniforms.K;
+ ${Wl(R,Q,!0)}
+ let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
+ let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
+ ${y&&u?"let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;":""};
+ let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
+ ${u?"let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;":""}
+ var value = ${T.type.storage}(0);
+ for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
+ if (m < uniforms.M && w + local_id.x < uniforms.K) {
+ tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
+ }
+ if (n < uniforms.N && w + local_id.y < uniforms.K) {
+ var idx = TILE_SIZE * local_id.y + local_id.x;
+ ${y&&u?`
+ if (w + local_id.y < past_sequence_length) {
+ tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
+ } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
+ tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];
+ }
+ `:`
+ if (w + local_id.y < uniforms.kv_sequence_length) {
+ tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N];
+ }`}
+ ${u?`
+ if (w + local_id.y < present_sequence_length) {
+ present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx];
+ }`:""}
+ }
+ workgroupBarrier();
+ for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {
+ value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x];
+ }
+ workgroupBarrier();
+ }
+
+ // we need to transpose output from BNSH_v to BSND_v
+ if (m < uniforms.M && n < uniforms.N) {
+ let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
+ + headIdx * uniforms.N + n;
+ output[outputIdx] = value;
+ }
+ }`};return{name:"AttentionScore",shaderCache:{hint:`${s!==void 0};${e}`,inputDependencies:$},getRunData:()=>({outputs:P,dispatchGroup:I,programUniforms:M}),getShaderSource:b}},ha=(e,r,t,s,o,n,i,a,l,c,p=void 0,u=void 0)=>{let h=Math.min(e.outputCount,1+(i?1:0)+(a?1:0)),g=h>1?c.pastSequenceLength:0,_=g+c.kvSequenceLength,E=l&&xe.size(l.dims)>0?l:void 0,I=[r,t];h>1&&i&&xe.size(i.dims)>0&&I.push(i),E&&I.push(E),p&&I.push(p),u&&I.push(u);let M=e.compute(m_(h,r,t,i,E,c,g,p,u),{inputs:I,outputs:h>1?[-1,1]:[-1]})[0];e.compute(h_(M,c.batchSize,c.numHeads,g,c.sequenceLength,_,p,u),{inputs:p&&u?[M,p,u]:[M],outputs:[]});let y=[M,s];h>1&&a&&xe.size(a.dims)>0&&y.push(a),p&&y.push(p),u&&y.push(u),e.compute(f_(h,M,s,a,c,g,p,u),{inputs:y,outputs:h>1?[0,2]:[0]})},__=(e,r)=>{let t=[r.batchSize,r.numHeads,r.sequenceLength,r.headSize],s=r.sequenceLength,o=r.inputHiddenSize,n=r.headSize,i=12,a={x:Math.ceil(r.headSize/i),y:Math.ceil(r.sequenceLength/i),z:r.batchSize*r.numHeads},l=[e.inputs[0],e.inputs[1],e.inputs[2]],c=[{type:12,data:s},{type:12,data:o},{type:12,data:n},{type:12,data:r.numHeads},{type:12,data:r.headSize},{type:12,data:r.hiddenSize},{type:12,data:r.hiddenSize+r.hiddenSize+r.vHiddenSize}],p=u=>{let h=tt("output_q",l[0].dataType,t),g=tt("output_k",l[0].dataType,t),_=tt("output_v",l[0].dataType,t),E=$e("input",l[0].dataType,l[0].dims),I=$e("weight",l[1].dataType,l[1].dims),M=$e("bias",l[2].dataType,l[2].dims),y=E.type.storage,$=[{name:"M",type:"u32"},{name:"K",type:"u32"},{name:"N",type:"u32"},{name:"num_heads",type:"u32"},{name:"head_size",type:"u32"},{name:"hidden_size",type:"u32"},{name:"ldb",type:"u32"}];return`
+ const TILE_SIZE = ${i}u;
+ var tileInput: array<${y}, ${i*i}>;
+ var tileWeightQ: array<${y}, ${i*i}>;
+ var tileWeightK: array<${y}, ${i*i}>;
+ var tileWeightV: array<${y}, ${i*i}>;
+ ${u.registerUniforms($).declareVariables(E,I,M,h,g,_)}
+ ${u.mainStart([i,i,1])}
+ let batchIndex = workgroup_id.z / uniforms.num_heads;
+ let headNumber = workgroup_id.z % uniforms.num_heads;
+ let m = global_id.y;
+ let n = global_id.x;
+
+ let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;
+ let biasOffsetQ = headNumber * uniforms.head_size;
+ let biasOffsetK = uniforms.hidden_size + biasOffsetQ;
+ let biasOffsetV = uniforms.hidden_size + biasOffsetK;
+
+ var valueQ = ${y}(0);
+ var valueK = ${y}(0);
+ var valueV = ${y}(0);
+ for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
+ if (m < uniforms.M && w + local_id.x < uniforms.K) {
+ tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
+ }
+ if (n < uniforms.N && w + local_id.y < uniforms.K) {
+ let offset = n + (w + local_id.y) * uniforms.ldb;
+ tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
+ tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
+ tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
+ }
+ workgroupBarrier();
+ for (var k: u32 = 0u; k({outputs:[{dims:t,dataType:e.inputs[0].dataType,gpuDataType:0},{dims:t,dataType:e.inputs[0].dataType,gpuDataType:0},{dims:t,dataType:e.inputs[0].dataType,gpuDataType:0}],dispatchGroup:a,programUniforms:c}),getShaderSource:p},{inputs:l,outputs:[-1,-1,-1]})},oM=(e,r)=>{let t=p_(e.inputs,r),[s,o,n]=__(e,t);return ha(e,s,o,n,e.inputs[4],void 0,void 0,void 0,e.inputs[5],t)}}),g_,w_,y_,iM,Jv=Ve(()=>{Ms(),mt(),bt(),tr(),xt(),g_=(e,r)=>{if(!e||e.length!==5)throw new Error("BatchNormalization requires 5 inputs");let t=(s,o,n)=>{let i=o.length;if(i!==s.length)throw new Error(`${n}: num dimensions != ${i}`);o.forEach((a,l)=>{if(a!==s[l])throw new Error(`${n}: dim[${l}] do not match`)})};if(e[0].dims.length>1){let s=r.format==="NHWC"?r.spatial?e[0].dims.slice(-1):e[0].dims.slice(-1).concat(e[0].dims.slice(1,e[0].dims.length-1)):e[0].dims.slice(1,r.spatial?2:void 0);t(e[1].dims,s,"Invalid input scale"),t(e[2].dims,s,"Invalid input B"),t(e[3].dims,s,"Invalid input mean"),t(e[4].dims,s,"Invalid input var")}else t(e[1].dims,[1],"Invalid input scale"),t(e[2].dims,[1],"Invalid input B"),t(e[3].dims,[1],"Invalid input mean"),t(e[4].dims,[1],"Invalid input var")},w_=(e,r)=>{let{epsilon:t,spatial:s,format:o}=r,n=e[0].dims,i=s?Jt(n[n.length-1]):1,a=o==="NHWC"&&n.length>1?i:1,l=xe.size(n)/i,c=s,p=c?n.length:n,u=$e("x",e[0].dataType,e[0].dims,i),h=$e("scale",e[1].dataType,e[1].dims,a),g=$e("bias",e[2].dataType,e[2].dims,a),_=$e("inputMean",e[3].dataType,e[3].dims,a),E=$e("inputVar",e[4].dataType,e[4].dims,a),I=tt("y",e[0].dataType,p,i),M=()=>{let $="";if(s)$=`let cOffset = ${n.length===1?"0u":o==="NHWC"?`outputIndices[${n.length-1}] / ${i}`:"outputIndices[1]"};`;else if(o==="NCHW")$=`
+ ${I.indicesSet("outputIndices","0","0")}
+ let cOffset = ${I.indicesToOffset("outputIndices")};`;else{$=`var cIndices = ${h.type.indices}(0);
+ cIndices[0] = outputIndices[${n.length-1}];`;for(let P=1;P`
+ const epsilon = ${t};
+ ${$.registerUniform("outputSize","u32").declareVariables(u,h,g,_,E,I)}
+ ${$.mainStart()}
+ ${$.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ var outputIndices = ${I.offsetToIndices(`global_idx * ${i}`)};
+ ${M()}
+ let scale = ${h.getByOffset("cOffset")};
+ let bias = ${g.getByOffset("cOffset")};
+ let inputMean = ${_.getByOffset("cOffset")};
+ let inputVar = ${E.getByOffset("cOffset")};
+ let x = ${u.getByOffset("global_idx")};
+ let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias;
+ ${I.setByOffset("global_idx","value")}
+ }`;return{name:"BatchNormalization",shaderCache:{hint:`${r.epsilon}_${r.format}_${s}_${i}`,inputDependencies:c?["rank","type","type","type","type"]:void 0},getShaderSource:y,getRunData:()=>({outputs:[{dims:e[0].dims,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:c?[{type:12,data:l},...nt(n)]:[{type:12,data:l}]})}},y_=e=>Lt(e),iM=(e,r)=>{let{inputs:t,outputCount:s}=e,o=y_({...r,outputCount:s});if(Kt.webgpu.validateInputContent&&g_(t,o),r.trainingMode)throw new Error("BatchNormalization trainingMode is not supported yet.");e.compute(w_(t,o))}}),M_,b_,aM,Yv=Ve(()=>{bt(),xt(),M_=e=>{if(e[0].dims.length!==3)throw new Error("input should have 3 dimensions");if(![320,640,1280].includes(e[0].dims[2]))throw new Error("number of channels should be 320, 640 or 1280");if(e[1].dims.length!==1)throw new Error("bias is expected to have 1 dimensions");if(e[0].dims[2]!==e[1].dims[0])throw new Error("last dimension of input and bias are not the same")},b_=e=>{let r=e[0].dims,t=e[0].dims[2],s=xe.size(r)/4,o=e[0].dataType,n=$e("input",o,r,4),i=$e("bias",o,[t],4),a=$e("residual",o,r,4),l=tt("output",o,r,4);return{name:"BiasAdd",getRunData:()=>({outputs:[{dims:r,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(s/64)}}),getShaderSource:c=>`
+ const channels = ${t}u / 4;
+ ${c.declareVariables(n,i,a,l)}
+
+ ${c.mainStart()}
+ ${c.guardAgainstOutOfBoundsWorkgroupSizes(s)}
+ let value = ${n.getByOffset("global_idx")}
+ + ${i.getByOffset("global_idx % channels")} + ${a.getByOffset("global_idx")};
+ ${l.setByOffset("global_idx","value")}
+ }`}},aM=e=>{M_(e.inputs),e.compute(b_(e.inputs))}}),v_,It,lM,dM,cM,uM,pM,hM,mM,fM,_M,x_,gM,wM,yM,MM,da,bM,td,vM,xM,TM,EM,PM,CM,SM,$M,kM,IM,AM,FM,OM,DM,LM,zM,cc,BM,Xc,Jc,RM,NM,jM,T_,E_,VM,vu=Ve(()=>{mt(),bt(),tr(),xt(),v_=(e,r,t,s,o,n,i)=>{let a=Math.ceil(r/4),l="";typeof o=="string"?l=`${o}(a)`:l=o("a");let c=$e("inputData",t,[a],4),p=tt("outputData",s,[a],4),u=[{name:"vec_size",type:"u32"}];return i&&u.push(...i),`
+ ${e.registerUniforms(u).declareVariables(c,p)}
+
+ ${n??""}
+
+ ${e.mainStart()}
+ ${e.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")}
+
+ let a = ${c.getByOffset("global_idx")};
+ ${p.setByOffset("global_idx",l)}
+ }`},It=(e,r,t,s,o,n=e.dataType,i,a)=>{let l=[{type:12,data:Math.ceil(xe.size(e.dims)/4)}];return i&&l.push(...i),{name:r,shaderCache:{hint:o,inputDependencies:["type"]},getShaderSource:c=>v_(c,xe.size(e.dims),e.dataType,n,t,s,a),getRunData:c=>({outputs:[{dims:e.dims,dataType:n}],dispatchGroup:{x:Math.ceil(xe.size(c[0].dims)/64/4)},programUniforms:l})}},lM=e=>{e.compute(It(e.inputs[0],"Abs","abs"))},dM=e=>{e.compute(It(e.inputs[0],"Acos","acos"))},cM=e=>{e.compute(It(e.inputs[0],"Acosh","acosh"))},uM=e=>{e.compute(It(e.inputs[0],"Asin","asin"))},pM=e=>{e.compute(It(e.inputs[0],"Asinh","asinh"))},hM=e=>{e.compute(It(e.inputs[0],"Atan","atan"))},mM=e=>{e.compute(It(e.inputs[0],"Atanh","atanh"))},fM=e=>Lt(e),_M=(e,r)=>{let t;switch(r.to){case 10:t="vec4";break;case 1:t="vec4";break;case 12:t="vec4";break;case 6:t="vec4";break;case 9:t="vec4";break;default:throw new RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${r.to}`)}e.compute(It(e.inputs[0],"Cast",t,void 0,r.cacheKey,r.to))},x_=e=>{let r,t,s=e.length>=2&&e[1].data!==0,o=e.length>=3&&e[2].data!==0;switch(e[0].dataType){case 1:r=s?e[1].getFloat32Array()[0]:-34028234663852886e22,t=o?e[2].getFloat32Array()[0]:34028234663852886e22;break;case 10:r=s?e[1].getUint16Array()[0]:64511,t=o?e[2].getUint16Array()[0]:31743;break;default:throw new Error("Unsupport data type")}return Lt({min:r,max:t})},gM=(e,r)=>{let t=r||x_(e.inputs),s=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"Clip",o=>`clamp(${o}, vec4<${s}>(uniforms.min), vec4<${s}>(uniforms.max))`,void 0,t.cacheKey,void 0,[{type:e.inputs[0].dataType,data:t.min},{type:e.inputs[0].dataType,data:t.max}],[{name:"min",type:s},{name:"max",type:s}]),{inputs:[0]})},wM=e=>{e.compute(It(e.inputs[0],"Ceil","ceil"))},yM=e=>{e.compute(It(e.inputs[0],"Cos","cos"))},MM=e=>{e.compute(It(e.inputs[0],"Cosh","cosh"))},da=e=>Lt(e),bM=(e,r)=>{let t=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"Elu",s=>`elu_vf32(${s})`,`
+ const elu_alpha_ = ${t}(${r.alpha});
+
+ fn elu_f32(a: ${t}) -> ${t} {
+ return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
+ }
+
+ fn elu_vf32(v: vec4<${t}>) -> vec4<${t}> {
+ return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
+ }`,r.cacheKey))},td=(e="f32")=>`
+const r0: ${e} = 0.3275911;
+const r1: ${e} = 0.254829592;
+const r2: ${e} = -0.284496736;
+const r3: ${e} = 1.421413741;
+const r4: ${e} = -1.453152027;
+const r5: ${e} = 1.061405429;
+
+fn erf_vf32(v: vec4<${e}>) -> vec4<${e}> {
+ let absv = abs(v);
+ let x = 1.0 / (1.0 + r0 * absv);
+ return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
+}`,vM=e=>{let r=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"Erf",t=>`erf_vf32(${t})`,td(r)))},xM=e=>{e.compute(It(e.inputs[0],"Exp","exp"))},TM=e=>{e.compute(It(e.inputs[0],"Floor","floor"))},EM=e=>{let r=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"Gelu",t=>`0.5 * ${t} * (1.0 + erf_vf32(${t} * 0.7071067811865475))`,td(r)))},PM=(e,r)=>{let t=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"LeakyRelu",s=>`select(leaky_relu_alpha_ * ${s}, ${s}, ${s} >= vec4<${t}>(0.0))`,`const leaky_relu_alpha_ = ${t}(${r.alpha});`,r.cacheKey))},CM=e=>{e.compute(It(e.inputs[0],"Not",r=>`!${r}`))},SM=e=>{e.compute(It(e.inputs[0],"Neg",r=>`-${r}`))},$M=e=>{e.compute(It(e.inputs[0],"Reciprocal",r=>`1.0/${r}`))},kM=e=>{let r=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"Relu",t=>`select(vec4<${r}>(0.0), ${t}, ${t} > vec4<${r}>(0.0))`))},IM=e=>{e.compute(It(e.inputs[0],"Sigmoid",r=>`(1.0 / (1.0 + exp(-${r})))`))},AM=e=>Lt(e),FM=(e,r)=>{let t=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"HardSigmoid",s=>`max(vec4<${t}>(0.0), min(vec4<${t}>(1.0), ${r.alpha} * ${s} + vec4<${t}>(${r.beta})))`,void 0,r.cacheKey))},OM=e=>{e.compute(It(e.inputs[0],"Sin","sin"))},DM=e=>{e.compute(It(e.inputs[0],"Sinh","sinh"))},LM=e=>{e.compute(It(e.inputs[0],"Sqrt","sqrt"))},zM=e=>{e.compute(It(e.inputs[0],"Tan","tan"))},cc=e=>`sign(${e}) * (1 - exp(-2 * abs(${e}))) / (1 + exp(-2 * abs(${e})))`,BM=e=>{e.compute(It(e.inputs[0],"Tanh",cc))},Xc=(e="f32")=>`
+const fast_gelu_a: ${e} = 0.5;
+const fast_gelu_b: ${e} = 0.7978845608028654;
+const fast_gelu_c: ${e} = 0.035677408136300125;
+
+fn tanh_v(v: vec4<${e}>) -> vec4<${e}> {
+ return ${cc("v")};
+}
+`,Jc=e=>`(fast_gelu_a + fast_gelu_a * tanh_v(${e} * (fast_gelu_c * ${e} * ${e} + fast_gelu_b))) * ${e}`,RM=e=>{let r=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"FastGelu",Jc,Xc(r),void 0,e.inputs[0].dataType))},NM=(e,r)=>{let t=Cr(e.inputs[0].dataType);return e.compute(It(e.inputs[0],"ThresholdedRelu",s=>`select(vec4<${t}>(0.0), ${s}, ${s} > thresholded_relu_alpha_)`,`const thresholded_relu_alpha_ = vec4<${t}>(${r.alpha});`,r.cacheKey)),0},jM=e=>{e.compute(It(e.inputs[0],"Log","log"))},T_=(e,r)=>`
+const alpha = vec4<${e}>(${r});
+const one = ${e}(1.0);
+const zero = ${e}(0.0);
+
+fn quick_gelu_impl(x: vec4<${e}>) -> vec4<${e}> {
+ let v = x *alpha;
+ var x1 : vec4<${e}>;
+ for (var i = 0; i < 4; i = i + 1) {
+ if (v[i] >= zero) {
+ x1[i] = one / (one + exp(-v[i]));
+ } else {
+ x1[i] = one - one / (one + exp(v[i]));
+ }
+ }
+ return x * x1;
+}
+`,E_=e=>`quick_gelu_impl(${e})`,VM=(e,r)=>{let t=Cr(e.inputs[0].dataType);e.compute(It(e.inputs[0],"QuickGelu",E_,T_(t,r.alpha),r.cacheKey,e.inputs[0].dataType))}}),P_,C_,UM,Zv=Ve(()=>{bt(),xt(),vu(),P_=e=>{if(e[0].dims.length!==3)throw new Error("input should have 3 dimensions");if(![2560,5120,10240].includes(e[0].dims[2]))throw new Error("hidden state should be 2560, 5120 or 10240");if(e[1].dims.length!==1)throw new Error("bias is expected to have 1 dimensions");if(e[0].dims[2]!==e[1].dims[0])throw new Error("last dimension of input and bias are not the same")},C_=e=>{let r=e[0].dims.slice();r[2]=r[2]/2;let t=$e("input",e[0].dataType,e[0].dims,4),s=$e("bias",e[0].dataType,[e[0].dims[2]],4),o=tt("output",e[0].dataType,r,4),n=xe.size(r)/4,i=pr(e[0].dataType);return{name:"BiasSplitGelu",getRunData:()=>({outputs:[{dims:r,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(n/64)}}),getShaderSource:a=>`
+ const M_SQRT2 = sqrt(2.0);
+ const halfChannels = ${e[0].dims[2]/4/2}u;
+
+ ${a.declareVariables(t,s,o)}
+
+ ${td(i)}
+
+ ${a.mainStart()}
+ ${a.guardAgainstOutOfBoundsWorkgroupSizes(n)}
+ let biasIdx = global_idx % halfChannels;
+ let batchIndex = global_idx / halfChannels;
+ let inputOffset = biasIdx + batchIndex * halfChannels * 2;
+ let valueLeft = input[inputOffset] + bias[biasIdx];
+ let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels];
+ let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1);
+
+ ${o.setByOffset("global_idx","valueLeft * geluRight")}
+ }`}},UM=e=>{P_(e.inputs),e.compute(C_(e.inputs))}}),S_,$_,_s,WM,GM,KM,HM,qM,QM,XM,JM,YM,ZM,ex=Ve(()=>{mt(),bt(),xt(),S_=(e,r,t,s,o,n,i,a,l,c,p,u)=>{let h,g;typeof a=="string"?h=g=(y,$)=>`${a}((${y}),(${$}))`:typeof a=="function"?h=g=a:(h=a.scalar,g=a.vector);let _=tt("outputData",p,s.length,4),E=$e("aData",l,r.length,4),I=$e("bData",c,t.length,4),M;if(o)if(n){let y=xe.size(r)===1,$=xe.size(t)===1,P=r.length>0&&r[r.length-1]%4===0,b=t.length>0&&t[t.length-1]%4===0;y||$?M=_.setByOffset("global_idx",g(y?`${E.type.value}(${E.getByOffset("0")}.x)`:E.getByOffset("global_idx"),$?`${I.type.value}(${I.getByOffset("0")}.x)`:I.getByOffset("global_idx"))):M=`
+ let outputIndices = ${_.offsetToIndices("global_idx * 4u")};
+ let offsetA = ${E.broadcastedIndicesToOffset("outputIndices",_)};
+ let offsetB = ${I.broadcastedIndicesToOffset("outputIndices",_)};
+ ${_.setByOffset("global_idx",g(i||P?E.getByOffset("offsetA / 4u"):`${E.type.value}(${E.getByOffset("offsetA / 4u")}[offsetA % 4u])`,i||b?I.getByOffset("offsetB / 4u"):`${I.type.value}(${I.getByOffset("offsetB / 4u")}[offsetB % 4u])`))}
+ `}else M=_.setByOffset("global_idx",g(E.getByOffset("global_idx"),I.getByOffset("global_idx")));else{if(!n)throw new Error("no necessary to use scalar implementation for element-wise binary op implementation.");let y=($,P,b="")=>{let w=`aData[indexA${P}][componentA${P}]`,T=`bData[indexB${P}][componentB${P}]`;return`
+ let outputIndices${P} = ${_.offsetToIndices(`global_idx * 4u + ${P}u`)};
+ let offsetA${P} = ${E.broadcastedIndicesToOffset(`outputIndices${P}`,_)};
+ let offsetB${P} = ${I.broadcastedIndicesToOffset(`outputIndices${P}`,_)};
+ let indexA${P} = offsetA${P} / 4u;
+ let indexB${P} = offsetB${P} / 4u;
+ let componentA${P} = offsetA${P} % 4u;
+ let componentB${P} = offsetB${P} % 4u;
+ ${$}[${P}] = ${b}(${h(w,T)});
+ `};p===9?M=`
+ var data = vec4(0);
+ ${y("data",0,"u32")}
+ ${y("data",1,"u32")}
+ ${y("data",2,"u32")}
+ ${y("data",3,"u32")}
+ outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`:M=`
+ ${y("outputData[global_idx]",0)}
+ ${y("outputData[global_idx]",1)}
+ ${y("outputData[global_idx]",2)}
+ ${y("outputData[global_idx]",3)}
+ `}return`
+ ${e.registerUniform("vec_size","u32").declareVariables(E,I,_)}
+
+ ${u??""}
+
+ ${e.mainStart()}
+ ${e.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")}
+ ${M}
+ }`},$_=(e,r,t,s,o,n,i=t.dataType)=>{let a=t.dims.map(E=>Number(E)??1),l=s.dims.map(E=>Number(E)??1),c=!xe.areEqual(a,l),p=a,u=xe.size(a),h=!1,g=!1,_=[c];if(c){let E=So.calcShape(a,l,!1);if(!E)throw new Error("Can't perform binary op on the given tensors");p=E.slice(),u=xe.size(p);let I=xe.size(a)===1,M=xe.size(l)===1,y=a.length>0&&a[a.length-1]%4===0,$=l.length>0&&l[l.length-1]%4===0;_.push(I),_.push(M),_.push(y),_.push($);let P=1;for(let b=1;bE.toString()).join("_"),inputDependencies:["rank","rank"]},getShaderSource:E=>S_(E,a,l,p,h,c,g,o,t.dataType,s.dataType,i,n),getRunData:()=>({outputs:[{dims:p,dataType:i}],dispatchGroup:{x:Math.ceil(u/64/4)},programUniforms:[{type:12,data:Math.ceil(xe.size(p)/4)},...nt(a,l,p)]})}},_s=(e,r,t,s,o,n)=>{e.compute($_(r,o??"",e.inputs[0],e.inputs[1],t,s,n))},WM=e=>{_s(e,"Add",(r,t)=>`${r}+${t}`)},GM=e=>{_s(e,"Div",(r,t)=>`${r}/${t}`)},KM=e=>{_s(e,"Equal",{scalar:(r,t)=>`u32(${r}==${t})`,vector:(r,t)=>`vec4(${r}==${t})`},void 0,void 0,9)},HM=e=>{_s(e,"Mul",(r,t)=>`${r}*${t}`)},qM=e=>{let r=$e("input",e.inputs[0].dataType,e.inputs[0].dims).type.value;_s(e,"Pow",{scalar:(t,s)=>`pow_custom(${t},${s})`,vector:(t,s)=>`pow_vector_custom(${t},${s})`},`
+ fn pow_custom(a : ${r}, b : ${r}) -> ${r} {
+ if (b == ${r}(0.0)) {
+ return ${r}(1.0);
+ } else if (a < ${r}(0.0) && f32(b) != floor(f32(b))) {
+ return ${r}(pow(f32(a), f32(b))); // NaN
+ }
+ return select(sign(a), ${r}(1.0), round(f32(abs(b) % ${r}(2.0))) != 1.0) * ${r}(${r==="i32"?"round":""}(pow(f32(abs(a)), f32(b))));
+ }
+ fn pow_vector_custom(a : vec4<${r}>, b : vec4<${r}>) -> vec4<${r}> {
+ // TODO: implement vectorized pow
+ return vec4<${r}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w));
+ }
+ `)},QM=e=>{_s(e,"Sub",(r,t)=>`${r}-${t}`)},XM=e=>{_s(e,"Greater",{scalar:(r,t)=>`u32(${r}>${t})`,vector:(r,t)=>`vec4(${r}>${t})`},void 0,void 0,9)},JM=e=>{_s(e,"Less",{scalar:(r,t)=>`u32(${r}<${t})`,vector:(r,t)=>`vec4(${r}<${t})`},void 0,void 0,9)},YM=e=>{_s(e,"GreaterOrEqual",{scalar:(r,t)=>`u32(${r}>=${t})`,vector:(r,t)=>`vec4(${r}>=${t})`},void 0,void 0,9)},ZM=e=>{_s(e,"LessOrEqual",{scalar:(r,t)=>`u32(${r}<=${t})`,vector:(r,t)=>`vec4(${r}<=${t})`},void 0,void 0,9)}}),k_,I_,A_,F_,e0,t0,tx=Ve(()=>{mt(),bt(),tr(),xt(),k_=(e,r)=>{if(!e||e.length<1)throw new Error("too few inputs");let t=0,s=e[t],o=s.dataType,n=s.dims.length;e.forEach((i,a)=>{if(a!==t){if(i.dataType!==o)throw new Error("input tensors should be one type");if(i.dims.length!==n)throw new Error("input tensors should have the same shape");i.dims.forEach((l,c)=>{if(c!==r&&l!==s.dims[c])throw new Error("non concat dimensions must match")})}})},I_=(e,r)=>`
+ fn calculateInputIndex(index: u32) -> u32 {
+ let sizeInConcatAxis = array(${r});
+ for (var i: u32 = 0u; i < ${e}; i += 1u ) {
+ if (index < sizeInConcatAxis[i]) {
+ return i;
+ }
+ }
+ return ${e}u;
+ }`,A_=(e,r)=>{let t=e.length,s=[];for(let o=0;o{let o=xe.size(t),n=new Array(e.length),i=new Array(e.length),a=0,l=[],c=[],p=[{type:12,data:o}];for(let E=0;E`uniforms.sizeInConcatAxis${E}`).join(","),_=E=>`
+
+ ${(()=>{E.registerUniform("outputSize","u32");for(let I=0;I(${g});
+ ${h} -= sizeInConcatAxis[inputIndex - 1u];
+ }
+
+ ${A_(i,u)}
+ }`;return{name:"Concat",shaderCache:{hint:`${r}`,inputDependencies:l},getRunData:()=>({outputs:[{dims:t,dataType:s}],dispatchGroup:{x:Math.ceil(o/64)},programUniforms:p}),getShaderSource:_}},e0=(e,r)=>{let t=e.inputs,s=t[0].dims,o=xe.normalizeAxis(r.axis,s.length);k_(t,o);let n=s.slice();n[o]=t.reduce((a,l)=>a+(l.dims.length>o?l.dims[o]:0),0);let i=t.filter(a=>xe.size(a.dims)>0);e.compute(F_(i,o,n,t[0].dataType),{inputs:i})},t0=e=>Lt({axis:e.axis})}),On,Dn,Ln,xu,Bn=Ve(()=>{mt(),bt(),On=(e,r,t="f32")=>{switch(e.activation){case"Relu":return`value = max(value, ${r}(0.0));`;case"Sigmoid":return`value = (${r}(1.0) / (${r}(1.0) + exp(-value)));`;case"Clip":return`value = clamp(value, ${r}(${t}(uniforms.clip_min)), ${r}(${t}(uniforms.clip_max)));`;case"HardSigmoid":return`value = max(${r}(0.0), min(${r}(1.0), ${t}(uniforms.alpha) * value + ${t}(uniforms.beta)));`;case"LeakyRelu":return`value = select(${t}(uniforms.alpha) * value, value, value >= ${r}(0.0));`;case"Tanh":return`let e2x = exp(-2.0 * abs(value));
+ value = sign(value) * (1.0 - e2x) / (1.0 + e2x);
+ `;case"":return"";default:throw new Error(`Unsupported activation ${e.activation}`)}},Dn=(e,r)=>{e.activation==="Clip"?r.push({type:1,data:e.clipMax},{type:1,data:e.clipMin}):e.activation==="HardSigmoid"?r.push({type:1,data:e.alpha},{type:1,data:e.beta}):e.activation==="LeakyRelu"&&r.push({type:1,data:e.alpha})},Ln=(e,r)=>{e.activation==="Clip"?r.push({name:"clip_max",type:"f32"},{name:"clip_min",type:"f32"}):e.activation==="HardSigmoid"?r.push({name:"alpha",type:"f32"},{name:"beta",type:"f32"}):e.activation==="LeakyRelu"&&r.push({name:"alpha",type:"f32"})},xu=e=>{let r=(e==null?void 0:e.activation)||"";if(r==="HardSigmoid"){let[t,s]=(e==null?void 0:e.activation_params)||[.2,.5];return{activation:r,alpha:t,beta:s}}else if(r==="Clip"){let[t,s]=(e==null?void 0:e.activation_params)||[Iy,Ay];return{activation:r,clipMax:s,clipMin:t}}else if(r==="LeakyRelu"){let[t]=(e==null?void 0:e.activation_params)||[.01];return{activation:r,alpha:t}}return{activation:r}}}),yr,r0,Tu=Ve(()=>{yr=(e,r)=>{switch(e){case 1:return r;case 2:return`vec2<${r}>`;case 3:return`vec3<${r}>`;case 4:return`vec4<${r}>`;default:throw new Error(`${e}-component is not supported.`)}},r0=e=>`
+ ${e?"value = value + getBiasByOutputCoords(coords);":""}
+ `}),s0,rx=Ve(()=>{s0=e=>`
+fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 {
+ return dot(coords, vec4(
+ shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
+}
+fn getOutputIndexFromCoords(coords : vec4) -> i32 {
+ return dot(coords, vec4(
+ i32(${e}.x), i32(${e}.y), i32(${e}.z), 1));
+}
+`}),ua,Eu,Pu=Ve(()=>{mt(),bt(),xt(),Bn(),ua=(e,r,t,s,o)=>{let n=s-t;return`
+ ${Array.from({length:t}).map((i,a)=>`
+ if (${rt(r.shape,a,r.rank)} != 1) {
+ ${r.indicesSet(e,a,rt(o,a+n,s))}
+ } else {
+ ${r.indicesSet(e,a,0)}
+ }`).join("")}
+`},Eu=(e,r,t,s,o=!1,n)=>{let i=e[0].dims,a=e[1].dims,l=i[i.length-2],c=a[a.length-1],p=i[i.length-1],u=Jt(c),h=Jt(p),g=Jt(l),_=xe.size(t)/u/g,E=e.length>2,I=s?s.slice(0,-2):t.slice(0,-2),M=[xe.size(I),l,c],y=[{type:12,data:_},{type:12,data:l},{type:12,data:c},{type:12,data:p}];Dn(r,y),y.push(...nt(I,i,a)),E&&y.push(...nt(e[2].dims)),y.push(...nt(M));let $=P=>{let b=yu("batch_dims",e[0].dataType,I.length),w=$e("a",e[0].dataType,i.length,h),T=$e("b",e[1].dataType,a.length,u),k=tt("output",e[0].dataType,M.length,u),z=pr(k.type.tensor),R=On(r,k.type.value,z),Q=[w,T],q="";if(E){let H=o?u:1;Q.push($e("bias",e[2].dataType,e[2].dims.length,H)),q=`${o?`value += bias[col / ${H}];`:`value += ${k.type.value}(bias[row + i]);`}`}let U=[{name:"output_size",type:"u32"},{name:"M",type:"u32"},{name:"N",type:"u32"},{name:"K",type:"u32"}];Ln(r,U);let Z=()=>{let H=`var a_data: ${w.type.value};`;for(let J=0;J;
+ for (var k: u32 = 0u; k < uniforms.K; k = k + ${h}) {
+ ${Z()}
+ }
+ for (var i = 0u; i < ${g}u; i++) {
+ var value = values[i];
+ ${q}
+ ${R}
+ let cur_indices = ${k.type.indices}(batch, row + i, col);
+ let offset = ${k.indicesToOffset("cur_indices")};
+ ${k.setByOffset(`offset / ${u}`,"value")};
+ }
+ }
+ `};return{name:"MatMulNaive",shaderCache:{hint:`${r.activation};${u};${h};${g};${o}`,inputDependencies:E?["rank","rank","rank"]:["rank","rank"]},getRunData:()=>({outputs:[{dims:n?n(t):t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(_/64)},programUniforms:y}),getShaderSource:$}}}),O_,D_,Yc,uc,L_,Zc,z_,ad,Cu=Ve(()=>{mt(),bt(),xt(),Bn(),Pu(),Tu(),O_=(e,r)=>e?`
+ mm_Asub[inputRow][inputCol] = mm_readA(batch,
+ kStart + inputRow,
+ globalRowStart / innerElementSize + inputCol${r?", batchIndices":""});
+ `:`
+ mm_Asub[inputRow][inputCol] = mm_readA(batch,
+ globalRow + innerRow,
+ kStart / innerElementSize + inputCol${r?", batchIndices":""});
+ `,D_=(e,r)=>e?`
+ let ACached0 = mm_Asub[k * innerElementSize][localRow];
+ let ACached1 = mm_Asub[k * innerElementSize + 1][localRow];
+ let ACached2 = mm_Asub[k * innerElementSize + 2][localRow];
+ ${r===3?"":"let ACached3 = mm_Asub[k * innerElementSize + 3][localRow];"}
+ for (var i = 0; i < rowPerThread; i = i + 1) {
+ acc[i] = BCached0 * ACached0[i] + acc[i];
+ acc[i] = BCached1 * ACached1[i] + acc[i];
+ acc[i] = BCached2 * ACached2[i] + acc[i];
+ ${r===3?"":"acc[i] = BCached3 * ACached3[i] + acc[i];"}
+ }`:`
+ for (var i = 0; i < rowPerThread; i = i + 1) {
+ let ACached = mm_Asub[tileRow + i][k];
+ acc[i] = BCached0 * ACached.x + acc[i];
+ acc[i] = BCached1 * ACached.y + acc[i];
+ acc[i] = BCached2 * ACached.z + acc[i];
+ ${r===3?"":"acc[i] = BCached3 * ACached.w + acc[i];"}
+ }`,Yc=(e,r,t="f32",s,o=!1,n=32,i=!1,a=32)=>{let l=r[1]*e[1],c=r[0]*e[0],p=o?l:n,u=o?n:l,h=p/r[0],g=n/r[1];if(!((o&&h===4&&e[1]===4||!o&&(h===3||h===4))&&p%r[0]===0&&n%r[1]===0&&e[0]===4))throw new Error(`If transposeA ${o} is true, innerElementSize ${h} and workPerThread[1] ${e[1]} must be 4.
+ Otherwise, innerElementSize ${h} must be 3 or 4.
+ tileAWidth ${p} must be divisible by workgroupSize[0]${r[0]}. tileInner ${n} must be divisible by workgroupSize[1] ${r[1]}. colPerThread ${e[0]} must be 4.`);return`
+var mm_Asub: array, ${p/h}>, ${u}>;
+var mm_Bsub: array, ${c/e[0]}>, ${n}>;
+
+const rowPerThread = ${e[1]};
+const colPerThread = ${e[0]};
+const innerElementSize = ${h};
+const tileInner = ${n};
+
+@compute @workgroup_size(${r[0]}, ${r[1]}, ${r[2]})
+fn main(@builtin(local_invocation_id) localId : vec3,
+ @builtin(global_invocation_id) globalId : vec3,
+ @builtin(workgroup_id) workgroupId : vec3) {
+ let localRow = i32(localId.y);
+ let tileRow = localRow * rowPerThread;
+ let tileCol = i32(localId.x);
+
+ let globalRow =i32(globalId.y) * rowPerThread;
+ let globalCol = i32(globalId.x);
+ let batch = ${i?"0":"i32(globalId.z)"};
+ ${s?`let batchIndices = ${s.offsetToIndices("u32(batch)")};`:""}
+ let globalRowStart = i32(workgroupId.y) * ${l};
+
+ let num_tiles = ${i?`${Math.ceil(a/n)}`:"(uniforms.dim_inner - 1) / tileInner + 1"};
+ var kStart = ${i?`i32(globalId.z) * ${a}`:"0"};
+
+ var acc: array, rowPerThread>;
+
+ // Loop over shared dimension.
+ let tileRowB = localRow * ${g};
+ for (var t = 0; t < num_tiles; t = t + 1) {
+ // Load one tile of A into local memory.
+ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
+ let inputRow = tileRow + innerRow;
+ let inputCol = tileCol;
+ ${O_(o,s)}
+ }
+
+ // Load one tile of B into local memory.
+ for (var innerRow = 0; innerRow < ${g}; innerRow = innerRow + 1) {
+ let inputRow = tileRowB + innerRow;
+ let inputCol = tileCol;
+ mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol${s?", batchIndices":""});
+ }
+ kStart = kStart + tileInner;
+ workgroupBarrier();
+
+ // Compute acc values for a single thread.
+ for (var k = 0; k < tileInner / innerElementSize; k = k + 1) {
+ let BCached0 = mm_Bsub[k * innerElementSize][tileCol];
+ let BCached1 = mm_Bsub[k * innerElementSize + 1][tileCol];
+ let BCached2 = mm_Bsub[k * innerElementSize + 2][tileCol];
+ ${h===3?"":"let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];"}
+
+ ${D_(o,h)}
+ }
+
+ workgroupBarrier();
+ }
+
+ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
+ mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]);
+ }
+}`},uc=(e,r)=>e?`
+ mm_Asub[inputRow][inputCol] = mm_readA(batch,
+ kStart + inputRow,
+ globalRowStart + inputCol${r?", batchIndices":""});
+ `:`
+ mm_Asub[inputRow][inputCol] = mm_readA(batch,
+ globalRowStart + inputRow,
+ kStart + inputCol${r?", batchIndices":""});
+ `,L_=e=>e?"let ACached = mm_Asub[k][tileRow + innerRow];":"let ACached = mm_Asub[tileRow + innerRow][k];",Zc=(e,r,t="f32",s,o=!1,n=32,i=!1,a=32,l=!1)=>{let c=e[1]*r[1],p=e[0]*r[0],u=o?c:n,h=o?n:c;if(!(h%r[1]===0&&u%r[0]===0&&n%r[1]===0))throw new Error(`tileAHight ${h} must be divisible by workgroupSize[1]${r[1]}, tileAWidth ${u} must be divisible by workgroupSize[0]${r[0]}, tileInner ${n} must be divisible by workgroupSize[1]${r[1]}`);let g=h/r[1],_=u/r[0],E=n/r[1],I=l?`
+ let localRow = i32(localId.y);
+ let localCol = i32(localId.x);
+ let globalRowStart = i32(workgroupId.y) * ${c};
+ let globalColStart = i32(workgroupId.x) * ${p};
+
+ // Loop over shared dimension.
+ for (var t = 0; t < num_tiles; t = t + 1) {
+ // Load one tile of A into local memory.
+ for (var inputRow = localRow; inputRow < ${h}; inputRow = inputRow + ${r[1]}) {
+ for (var inputCol = localCol; inputCol < ${u}; inputCol = inputCol + ${r[0]}) {
+ ${uc(o,s)}
+ }
+ }
+ // Load one tile of B into local memory.
+ for (var inputRow = localRow; inputRow < ${n}; inputRow = inputRow + ${r[1]}) {
+ for (var inputCol = localCol; inputCol < ${p}; inputCol = inputCol + ${r[0]}) {
+ mm_Bsub[inputRow][inputCol] = mm_readB(batch,
+ kStart + inputRow,
+ globalColStart + inputCol${s?", batchIndices":""});
+ }
+ }
+ kStart = kStart + tileInner;
+ workgroupBarrier();
+
+ // Compute acc values for a single thread.
+ var BCached : array<${t}, colPerThread>;
+ for (var k = 0; k < tileInner; k = k + 1) {
+ for (var inner = 0; inner < colPerThread; inner = inner + 1) {
+ BCached[inner] = mm_Bsub[k][localCol + inner * ${r[0]}];
+ }
+ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
+ let ACached = ${o?`mm_Asub[k][localRow + innerRow * ${r[1]}];`:`mm_Asub[localRow + innerRow * ${r[1]}][k];`}
+ for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {
+ acc[innerRow][innerCol] = acc[innerRow][innerCol] +
+ ACached * BCached[innerCol];
+ }
+ }
+ }
+ workgroupBarrier();
+ }
+ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
+ let gRow = globalRowStart + localRow + innerRow * ${r[1]};
+ for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {
+ let gCol = globalColStart + localCol + innerCol * ${r[0]};
+ mm_write(batch, gRow, gCol, acc[innerRow][innerCol]);
+ }
+ }
+ `:`
+let tileRow = i32(localId.y) * rowPerThread;
+let tileCol = i32(localId.x) * colPerThread;
+
+let globalRow = i32(globalId.y) * rowPerThread;
+let globalCol = i32(globalId.x) * colPerThread;
+let globalRowStart = i32(workgroupId.y) * ${c};
+
+let tileRowA = i32(localId.y) * ${g};
+let tileColA = i32(localId.x) * ${_};
+let tileRowB = i32(localId.y) * ${E};
+// Loop over shared dimension.
+for (var t = 0; t < num_tiles; t = t + 1) {
+ // Load one tile of A into local memory.
+ for (var innerRow = 0; innerRow < ${g}; innerRow = innerRow + 1) {
+ for (var innerCol = 0; innerCol < ${_}; innerCol = innerCol + 1) {
+ let inputRow = tileRowA + innerRow;
+ let inputCol = tileColA + innerCol;
+ ${uc(o,s)}
+ }
+ }
+
+ // Load one tile of B into local memory.
+ for (var innerRow = 0; innerRow < ${E}; innerRow = innerRow + 1) {
+ for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {
+ let inputRow = tileRowB + innerRow;
+ let inputCol = tileCol + innerCol;
+ mm_Bsub[inputRow][inputCol] = mm_readB(batch,
+ kStart + inputRow,
+ globalCol + innerCol${s?", batchIndices":""});
+ }
+ }
+ kStart = kStart + tileInner;
+ workgroupBarrier();
+
+ // Compute acc values for a single thread.
+ var BCached : array<${t}, colPerThread>;
+ for (var k = 0; k < tileInner; k = k + 1) {
+ for (var inner = 0; inner < colPerThread; inner = inner + 1) {
+ BCached[inner] = mm_Bsub[k][tileCol + inner];
+ }
+
+ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
+ ${L_(o)}
+ for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {
+ acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];
+ }
+ }
+ }
+
+ workgroupBarrier();
+}
+
+for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
+ for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {
+ mm_write(batch, globalRow + innerRow, globalCol + innerCol,
+ acc[innerRow][innerCol]);
+ }
+}
+`;return`
+ var mm_Asub : array, ${h}>;
+ var mm_Bsub : array, ${n}>;
+ const rowPerThread = ${e[1]};
+ const colPerThread = ${e[0]};
+ const tileInner = ${n};
+
+@compute @workgroup_size(${r[0]}, ${r[1]}, ${r[2]})
+fn main(@builtin(local_invocation_id) localId : vec3,
+ @builtin(global_invocation_id) globalId : vec3,
+ @builtin(workgroup_id) workgroupId : vec3) {
+ let batch = ${i?"0":"i32(globalId.z)"};
+ ${s?`let batchIndices = ${s.offsetToIndices("u32(batch)")};`:""}
+ let num_tiles = ${i?`${Math.ceil(a/n)}`:"(uniforms.dim_inner - 1) / tileInner + 1"};
+ var kStart = ${i?`i32(globalId.z) * ${a}`:"0"};
+
+ var acc : array, rowPerThread>;
+ ${I}
+ }
+`},z_=(e,r,t,s,o=!1)=>{let[n,i,a,l]=s,c=pr(s[0].type.tensor);return`
+ fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${n.type.indices}) -> ${yr(e,c)} {
+ var value = ${yr(e,c)}(0.0);
+ let col = colIn * ${e};
+ if(row < uniforms.dim_a_outer && col < uniforms.dim_inner)
+ {
+ var aIndices: ${i.type.indices};
+ ${ua("aIndices",i,i.rank-2,n.rank,"batchIndices")}
+ ${i.indicesSet("aIndices",i.rank-2,"u32(row)")}
+ ${i.indicesSet("aIndices",i.rank-1,"u32(colIn)")}
+ value = ${i.getByIndices("aIndices")};
+ }
+ return value;
+ }
+
+ fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${n.type.indices}) -> ${yr(e,c)} {
+ var value = ${yr(e,c)}(0.0);
+ let col = colIn * ${e};
+ if(row < uniforms.dim_inner && col < uniforms.dim_b_outer)
+ {
+ var bIndices: ${a.type.indices};
+ ${ua("bIndices",a,a.rank-2,n.rank,"batchIndices")}
+ ${a.indicesSet("bIndices",a.rank-2,"u32(row)")}
+ ${a.indicesSet("bIndices",a.rank-1,"u32(colIn)")}
+ value = ${a.getByIndices("bIndices")};
+ }
+ return value;
+ }
+
+ fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${yr(e,c)}) {
+ let col = colIn * ${e};
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {
+ var value = valueIn;
+ let coords = vec3(batch, row, colIn);
+ ${r?`value = value + ${o?"bias[colIn]":`${yr(e,c)}(bias[row])`};`:""}
+ ${t}
+ ${l.setByIndices("vec3(coords)","value")}
+ }
+ }
+ `},ad=(e,r,t,s,o=!1,n)=>{let i=e[0].dims,a=e[1].dims,l=i.slice(0,-2),c=a.slice(0,-2),p=s?s.slice(0,-2):t.slice(0,-2),u=xe.size(p),h=i[i.length-2],g=i[i.length-1],_=a[a.length-1],E=g%4===0&&_%4===0,I=h<=8?[4,1,1]:[4,4,1],M=[8,8,1],y=[Math.ceil(_/M[0]/I[0]),Math.ceil(h/M[1]/I[1]),Math.ceil(u/M[2]/I[2])],$=E?4:1,P=[...l,h,g/$],b=P.length,w=[...c,g,_/$],T=w.length,k=[u,h,_/$],z=[{type:6,data:h},{type:6,data:_},{type:6,data:g}];Dn(r,z),z.push(...nt(p,P,w));let R=["rank","rank"],Q=e.length>2;Q&&(z.push(...nt(e[2].dims)),R.push("rank")),z.push(...nt(k));let q=U=>{let Z=p.length,H=yu("batchDims",e[0].dataType,Z,1),J=pr(e[0].dataType),oe=$e("a",e[0].dataType,b,$),ae=$e("b",e[1].dataType,T,$),ce=tt("result",e[0].dataType,k.length,$),he=[oe,ae];if(Q){let X=o?$:1;he.push($e("bias",e[2].dataType,e[2].dims.length,X))}let N=[{name:"dim_a_outer",type:"i32"},{name:"dim_b_outer",type:"i32"},{name:"dim_inner",type:"i32"}];Ln(r,N);let O=pr(ce.type.tensor),G=On(r,ce.type.value,O),se=z_($,Q,G,[H,oe,ae,ce],o);return`
+ ${U.registerUniforms(N).registerInternalVariables(H).declareVariables(...he,ce)}
+ ${se}
+ ${E?Yc(I,M,J,H):Zc(I,M,J,H)}
+ `};return{name:"MatMul",shaderCache:{hint:`${I};${r.activation};${E};${o}`,inputDependencies:R},getRunData:()=>({outputs:[{dims:n?n(t):t,dataType:e[0].dataType}],dispatchGroup:{x:y[0],y:y[1],z:y[2]},programUniforms:z}),getShaderSource:q}}}),B_,n0,sx=Ve(()=>{mt(),Us(),xt(),Bn(),Tu(),rx(),Cu(),B_=(e,r,t,s,o=!1,n,i=4,a=4,l=4,c="f32")=>{let p=z=>{switch(z){case 1:return"resData = x[xIndex];";case 3:return`resData = vec3<${c}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`;case 4:return"resData = x[xIndex / 4];";default:throw new Error(`innerElementSize ${z} is not supported.`)}},u=z=>{switch(z){case 1:return"return w[row * i32(uniforms.w_shape[3]) + colIn];";case 4:return"return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];";default:throw new Error(`innerElementSize ${z} is not supported.`)}},h=e?`
+ let coord = vec4(batch, xRow, xCol, xCh);
+ `:`
+ let coord = vec4(batch, xCh, xRow, xCol);
+ `,g=e?`
+ let coords = vec4(
+ batch,
+ row / outWidth,
+ row % outWidth,
+ col);
+ `:`
+ let coords = vec4(
+ batch,
+ row,
+ col / outWidth,
+ col % outWidth);
+ `,_=e?"i32(uniforms.x_shape[1])":"i32(uniforms.x_shape[2])",E=e?"i32(uniforms.x_shape[2])":"i32(uniforms.x_shape[3])",I=e?"row":"col",M=e?"col":"row",y=`
+ let inChannels = i32(uniforms.w_shape[2]);
+ let outWidth = ${e?"i32(uniforms.result_shape[2])":"i32(uniforms.result_shape[3])"};
+ let outRow = ${I} / outWidth;
+ let outCol = ${I} % outWidth;
+
+ let WRow = ${M} / (i32(uniforms.w_shape[1]) * inChannels);
+ let WCol = ${M} / inChannels % i32(uniforms.w_shape[1]);
+ let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
+ let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
+ let xCh = ${M} % inChannels;
+ var resData = ${yr(i,c)}(0.0);
+ // The bounds checking is always needed since we use it to pad zero for
+ // the 'same' padding type.
+ if (xRow >= 0 && xRow < ${_} && xCol >= 0 && xCol < ${E}) {
+ ${h}
+ let xIndex = getIndexFromCoords4D(coord, vec4(uniforms.x_shape));
+ ${p(i)}
+ }
+ return resData;`,$=e?r&&s?`
+ let col = colIn * ${i};
+ ${y}`:`
+ let col = colIn * ${i};
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
+ ${y}
+ }
+ return ${yr(i,c)}(0.0);`:s&&t?`
+ let col = colIn * ${i};
+ ${y}`:`
+ let col = colIn * ${i};
+ if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
+ ${y}
+ }
+ return ${yr(i,c)}(0.0);`,P=e?s&&t?u(a):`
+ let col = colIn * ${a};
+ if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
+ ${u(a)}
+ }
+ return ${yr(a,c)}(0.0);`:`
+ let col = colIn * ${a};
+ if (row < uniforms.dim_inner && col < uniforms.dim_a_outer) {
+ ${u(a)}
+ }
+ return ${yr(a,c)}(0.0);`,b=yr(l,c),w=yr(e?i:a,c),T=yr(e?a:i,c),k=On(n,b,c);return`
+ fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${w} {
+ ${e?$:P}
+ }
+
+ fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${T} {
+ ${e?P:$}
+ }
+
+ fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${b}) {
+ let col = colIn * ${l};
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer)
+ {
+ var value = valueIn;
+ let outWidth = ${e?"i32(uniforms.result_shape[2])":"i32(uniforms.result_shape[3])"};
+ ${g}
+ ${r0(o)}
+ ${k}
+ setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value);
+ }
+ }`},n0=(e,r,t,s,o,n,i,a,l)=>{let c=r.format==="NHWC",p=c?e[0].dims[3]:e[0].dims[1],u=t[0],h=c?t[2]:t[3],g=c?t[1]:t[2],_=c?t[3]:t[1],E=c&&(p%4===0||p%3===0)&&_%4===0,I=c?_:h*g,M=c?h*g:_,y=[8,8,1],$=s<=8?[4,1,1]:[4,4,1],P=[Math.ceil(I/y[0]/$[0]),Math.ceil(M/y[1]/$[1]),Math.ceil(u/y[2]/$[2])];St("verbose",()=>`[conv2d_mm_webgpu] dispatch = ${P}`);let b=E?c&&p%4!==0?3:4:1,w=y[1]*$[1],T=y[0]*$[0],k=Math.max(y[0]*b,y[1]),z=s%w===0,R=o%T===0,Q=n%k===0,q=E?[b,4,4]:[1,1,1],U=[{type:6,data:s},{type:6,data:o},{type:6,data:n},{type:6,data:[r.pads[0],r.pads[1]]},{type:6,data:r.strides},{type:6,data:r.dilations}];Dn(r,U),U.push(...nt(e[0].dims,e[1].dims));let Z=["rank","rank"];i&&(U.push(...nt(e[2].dims)),Z.push("rank")),U.push(...nt(t));let H=J=>{let oe=[{name:"dim_a_outer",type:"i32"},{name:"dim_b_outer",type:"i32"},{name:"dim_inner",type:"i32"},{name:"pad",type:"i32",length:2},{name:"stride",type:"i32",length:2},{name:"dilation",type:"i32",length:2}];Ln(r,oe);let ae=E?4:1,ce=pr(e[0].dataType),he=`
+ fn setOutputAtIndex(flatIndex : i32, value : ${E?`vec4<${ce}>`:ce}) {
+ result[flatIndex] = ${E?`vec4<${ce}>`:ce}(value);
+ }
+ fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${E?`vec4<${ce}>`:ce}) {
+ let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3));
+ setOutputAtIndex(flatIndex ${E?"/ 4":""}, value);
+ }`,N=$e("x",e[0].dataType,e[0].dims.length,b===3?1:b),O=$e("w",e[1].dataType,e[1].dims.length,ae),G=[N,O],se=tt("result",e[0].dataType,t.length,ae);if(i){let X=$e("bias",e[2].dataType,e[2].dims.length,ae);G.push(X),he+=`
+ fn getBiasByOutputCoords(coords : vec4) -> ${E?`vec4<${ce}>`:ce} {
+ return bias[coords.${c?"w":"y"}${E?"/ 4":""}];
+ }`}return`
+ ${s0("uniforms.result_strides")}
+ //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4,
+ // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2,
+ // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
+ ${J.registerUniforms(oe).declareVariables(...G,se)}
+ ${he}
+ ${B_(c,z,R,Q,i,r,q[0],q[1],q[2],ce)}
+ ${E?Yc($,y,ce,void 0,!c,k):Zc($,y,ce,void 0,!c,k,!1,void 0,a)}`};return{name:"Conv2DMatMul",shaderCache:{hint:`${r.cacheKey};${b};${E};${z};${R};${Q};${w};${T};${k}`,inputDependencies:Z},getRunData:()=>({outputs:[{dims:l?l(t):t,dataType:e[0].dataType}],dispatchGroup:{x:P[0],y:P[1],z:P[2]},programUniforms:U}),getShaderSource:H}}}),R_,pc,ta,N_,hc,j_,o0,i0,nx=Ve(()=>{mt(),Us(),bt(),xt(),Bn(),Tu(),R_=e=>{let r=1;for(let t=0;ttypeof e=="number"?[e,e,e]:e,ta=(e,r)=>r<=1?e:e+(e-1)*(r-1),N_=(e,r,t,s=1)=>{let o=ta(r,s);return Math.floor((e[0]*(t-1)-t+o)/2)},hc=(e,r,t,s,o)=>{o==null&&(o=N_(e,r[0],s[0]));let n=[0,0,0,t];for(let i=0;i<3;i++)e[i]+2*o>=r[i]&&(n[i]=Math.trunc((e[i]-r[i]+2*o)/s[i]+1));return n},j_=(e,r,t,s,o,n,i,a,l,c)=>{let p,u,h,g;if(e==="VALID"&&(e=0),typeof e=="number"){p={top:e,bottom:e,left:e,right:e,front:e,back:e};let _=hc([r,t,s,1],[a,l,c],1,[o,n,i],e);u=_[0],h=_[1],g=_[2]}else if(Array.isArray(e)){if(!e.every((E,I,M)=>E===M[0]))throw Error(`Unsupported padding parameter: ${e}`);p={top:e[0],bottom:e[1],left:e[2],right:e[3],front:e[4],back:e[5]};let _=hc([r,t,s,1],[a,l,c],1,[o,n,i],e[0]);u=_[0],h=_[1],g=_[2]}else if(e==="SAME_UPPER"){u=Math.ceil(r/o),h=Math.ceil(t/n),g=Math.ceil(s/i);let _=(u-1)*o+a-r,E=(h-1)*n+l-t,I=(g-1)*i+c-s,M=Math.floor(_/2),y=_-M,$=Math.floor(E/2),P=E-$,b=Math.floor(I/2),w=I-b;p={top:$,bottom:P,left:b,right:w,front:M,back:y}}else throw Error(`Unknown padding parameter: ${e}`);return{padInfo:p,outDepth:u,outHeight:h,outWidth:g}},o0=(e,r,t,s,o,n=!1,i="channelsLast")=>{let a,l,c,p,u;if(i==="channelsLast")[a,l,c,p,u]=e;else if(i==="channelsFirst")[a,u,l,c,p]=e;else throw new Error(`Unknown dataFormat ${i}`);let[h,,g,_,E]=r,[I,M,y]=pc(t),[$,P,b]=pc(s),w=ta(g,$),T=ta(_,P),k=ta(E,b),{padInfo:z,outDepth:R,outHeight:Q,outWidth:q}=j_(o,l,c,p,I,M,y,w,T,k),U=n?h*u:h,Z=[0,0,0,0,0];return i==="channelsFirst"?Z=[a,U,R,Q,q]:i==="channelsLast"&&(Z=[a,R,Q,q,U]),{batchSize:a,dataFormat:i,inDepth:l,inHeight:c,inWidth:p,inChannels:u,outDepth:R,outHeight:Q,outWidth:q,outChannels:U,padInfo:z,strideDepth:I,strideHeight:M,strideWidth:y,filterDepth:g,filterHeight:_,filterWidth:E,effectiveFilterDepth:w,effectiveFilterHeight:T,effectiveFilterWidth:k,dilationDepth:$,dilationHeight:P,dilationWidth:b,inShape:e,outShape:Z,filterShape:r}},i0=(e,r,t,s,o,n)=>{let i=n==="channelsLast";i?e[0].dims[3]:e[0].dims[1];let a=[64,1,1],l={x:t.map((I,M)=>M)},c=[Math.ceil(R_(l.x.map(I=>t[I]))/a[0]),1,1];St("verbose",()=>`[conv3d_naive_webgpu] dispatch = ${c}`);let p=1,u=xe.size(t),h=[{type:12,data:u},{type:12,data:s},{type:12,data:o},{type:12,data:r.strides},{type:12,data:r.dilations}];Dn(r,h),h.push(...nt(e[0].dims,e[1].dims));let g=["rank","rank"],_=e.length===3;_&&(h.push(...nt(e[2].dims)),g.push("rank")),h.push(...nt(t));let E=I=>{let M=[{name:"output_size",type:"u32"},{name:"filter_dims",type:"u32",length:s.length},{name:"pads",type:"u32",length:o.length},{name:"strides",type:"u32",length:r.strides.length},{name:"dilations",type:"u32",length:r.dilations.length}];Ln(r,M);let y=1,$=pr(e[0].dataType),P=$e("x",e[0].dataType,e[0].dims.length,p),b=$e("W",e[1].dataType,e[1].dims.length,y),w=[P,b],T=tt("result",e[0].dataType,t.length,y),k="";if(_){let Q=$e("bias",e[2].dataType,e[2].dims.length,y);w.push(Q),k+=`
+ fn getBiasByOutputCoords(coords : array) -> ${$} {
+ return bias[${i?rt("coords",4,5):rt("coords",1,5)}];
+ }`}let z=yr(p,$),R=On(r,z,$);return`
+ ${k}
+ fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 {
+ let aIndices = array(d0, d1, d2, d3, d4);
+ return ${P.getByIndices("aIndices")};
+ }
+ fn getW(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 {
+ let aIndices = array(d0, d1, d2, d3, d4);
+ return ${b.getByIndices("aIndices")};
+ }
+ ${I.registerUniforms(M).declareVariables(...w,T)}
+ ${I.mainStart()}
+ ${I.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let coords = ${T.offsetToIndices("global_idx")};
+ let batch = ${rt("coords",0,P.rank)};
+ let d2 = ${i?rt("coords",P.rank-1,P.rank):rt("coords",1,P.rank)};
+ let xFRCCorner = vec3(${i?rt("coords",1,P.rank):rt("coords",2,P.rank)},
+ ${i?rt("coords",2,P.rank):rt("coords",3,P.rank)},
+ ${i?rt("coords",3,P.rank):rt("coords",4,P.rank)}) * uniforms.strides - uniforms.pads;
+ let xFCorner = xFRCCorner.x;
+ let xRCorner = xFRCCorner.y;
+ let xCCorner = xFRCCorner.z;
+ let xShapeY = ${i?rt("uniforms.x_shape",1,P.rank):rt("uniforms.x_shape",2,P.rank)};
+ let xShapeZ = ${i?rt("uniforms.x_shape",2,P.rank):rt("uniforms.x_shape",3,P.rank)};
+ let xShapeW = ${i?rt("uniforms.x_shape",3,P.rank):rt("uniforms.x_shape",4,P.rank)};
+ let xShapeU = ${i?rt("uniforms.x_shape",4,P.rank):rt("uniforms.x_shape",1,P.rank)};
+ let inputDepthNearestVec4 = (xShapeU / 4) * 4;
+ let inputDepthVec4Remainder = xShapeU % 4;
+
+ var value = 0.0;
+ for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {
+ let xF = xFCorner + wF * uniforms.dilations[0];
+ if (xF < 0 || xF >= xShapeY) {
+ continue;
+ }
+
+ for (var wR = 0u; wR < uniforms.filter_dims[1]; wR++) {
+ let xR = xRCorner + wR * uniforms.dilations[1];
+ if (xR < 0 || xR >= xShapeZ) {
+ continue;
+ }
+
+ for (var wC = 0u; wC < uniforms.filter_dims[2]; wC++) {
+ let xC = xCCorner + wC * uniforms.dilations[2];
+ if (xC < 0 || xC >= xShapeW) {
+ continue;
+ }
+
+ for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) {
+ ${i?`let xValues = vec4(
+ getX(batch, xF, xR, xC, d1),
+ getX(batch, xF, xR, xC, d1 + 1),
+ getX(batch, xF, xR, xC, d1 + 2),
+ getX(batch, xF, xR, xC, d1 + 3));
+ `:`let xValues = vec4(
+ getX(batch, d1, xF, xR, xC),
+ getX(batch, d1 + 1, xF, xR, xC),
+ getX(batch, d1 + 2, xF, xR, xC),
+ getX(batch, d1 + 3, xF, xR, xC));
+ `}
+ let wValues = vec4(
+ getW(d2, d1, wF, wR, wC),
+ getW(d2, d1 + 1, wF, wR, wC),
+ getW(d2, d1 + 2, wF, wR, wC),
+ getW(d2, d1 + 3, wF, wR, wC));
+ value += dot(xValues, wValues);
+ }
+ if (inputDepthVec4Remainder == 1) {
+ ${i?`value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
+ * getW(d2, inputDepthNearestVec4, wF, wR, wC);`:`value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
+ * getW(d2, inputDepthNearestVec4, wF, wR, wC);`}
+ } else if (inputDepthVec4Remainder == 2) {
+ ${i?`let xValues = vec2(
+ getX(batch, xF, xR, xC, inputDepthNearestVec4),
+ getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1));
+ `:`let xValues = vec2(
+ getX(batch, inputDepthNearestVec4, xF, xR, xC),
+ getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC));
+ `}
+ let wValues = vec2(
+ getW(d2, inputDepthNearestVec4, wF, wR, wC),
+ getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC));
+ value += dot(xValues, wValues);
+ } else if (inputDepthVec4Remainder == 3) {
+ ${i?`let xValues = vec3(
+ getX(batch, xF, xR, xC, inputDepthNearestVec4),
+ getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1),
+ getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2));
+ `:`let xValues = vec3(
+ getX(batch, inputDepthNearestVec4, xF, xR, xC),
+ getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC),
+ getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC));
+ `}
+ let wValues = vec3(
+ getW(d2, inputDepthNearestVec4, wF, wR, wC),
+ getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC),
+ getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC));
+ value += dot(xValues, wValues);
+ }
+ }
+ }
+ }
+ ${_?"value = value + getBiasByOutputCoords(coords)":""};
+ ${R}
+ result[global_idx] = f32(value);
+ }`};return{name:"Conv3DNaive",shaderCache:{hint:`${r.cacheKey};${i};${p};${_}`,inputDependencies:g},getRunData:()=>({outputs:[{dims:t,dataType:e[0].dataType}],dispatchGroup:{x:c[0],y:c[1],z:c[2]},programUniforms:h}),getShaderSource:E}}}),a0,l0,ox=Ve(()=>{mt(),bt(),xt(),Bn(),a0=(e,r,t,s)=>{let o=e.length>2,n=o?"value += b[output_channel];":"",i=e[0].dims,a=e[1].dims,l=r.format==="NHWC",c=l?t[3]:t[1],p=c/r.group,u=l&&p>=4?Jt(c):1,h=xe.size(t)/u,g=[{type:12,data:h},{type:12,data:r.dilations},{type:12,data:[r.strides[0],r.strides[1]]},{type:12,data:[r.pads[0],r.pads[1]]},{type:12,data:p}];Dn(r,g),g.push(...nt(i,[a[0],a[1],a[2],a[3]/u]));let _=o?["rank","rank","rank"]:["rank","rank"];g.push(...nt([t[0],t[1],t[2],t[3]/u]));let E=I=>{let M=tt("output",e[0].dataType,t.length,u),y=pr(M.type.tensor),$=On(r,M.type.value,y),P=$e("x",e[0].dataType,i.length),b=$e("w",e[1].dataType,a.length,u),w=[P,b];o&&w.push($e("b",e[2].dataType,e[2].dims,u));let T=[{name:"output_size",type:"u32"},{name:"dilations",type:"u32",length:r.dilations.length},{name:"strides",type:"u32",length:2},{name:"pads",type:"u32",length:2},{name:"output_channels_per_group",type:"u32"}];Ln(r,T);let k=l?`
+ for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[0]; wHeight++) {
+ let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
+
+ if (xHeight < 0u || xHeight >= uniforms.x_shape[1]) {
+ continue;
+ }
+
+ for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[1]; wWidth++) {
+ let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
+ if (xWidth < 0u || xWidth >= uniforms.x_shape[2]) {
+ continue;
+ }
+
+ for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[2]; wInChannel++) {
+ let input_channel = in_channel_offset + wInChannel;
+ let xVal = ${P.get("batch","xHeight","xWidth","input_channel")};
+ let wVal = ${b.get("wHeight","wWidth","wInChannel","output_channel")};
+ value += xVal * wVal;
+ }
+ }
+ }
+ `:`
+ for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
+ let input_channel = in_channel_offset + wInChannel;
+ for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
+ let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
+
+ if (xHeight < 0u || xHeight >= uniforms.x_shape[2]) {
+ continue;
+ }
+
+ for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
+ let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
+ if (xWidth < 0u || xWidth >= uniforms.x_shape[3]) {
+ continue;
+ }
+
+ let xVal = ${P.get("batch","input_channel","xHeight","xWidth")};
+ let wVal = ${b.get("output_channel","wInChannel","wHeight","wWidth")};
+ value += xVal * wVal;
+ }
+ }
+ }
+ `;return`
+ ${I.registerUniforms(T).declareVariables(...w,M)}
+
+ ${I.mainStart()}
+ ${I.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+
+ let outputIndices = ${M.offsetToIndices("global_idx")};
+ let batch: u32 = outputIndices[0];
+ let output_channel: u32 = outputIndices[${l?3:1}];
+ let xRCCorner: vec2 = vec2(outputIndices[${l?1:2}], outputIndices[${l?2:3}]) * uniforms.strides - uniforms.pads;
+ let group_id: u32 = output_channel * ${u} / uniforms.output_channels_per_group;
+ var in_channel_offset = group_id * uniforms.w_shape[${l?2:1}];
+
+ var value: ${M.type.value} = ${M.type.value}(0);
+ ${k}
+ ${n}
+ ${$}
+ ${M.setByOffset("global_idx","value")}
+ }`};return{name:"GroupedConv",shaderCache:{hint:`${r.cacheKey}_${u}`,inputDependencies:_},getRunData:()=>({outputs:[{dims:s?s(t):t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(h/64)},programUniforms:g}),getShaderSource:E}},l0=(e,r,t,s)=>{let o=e.length>2,n=Jt(t[3]),i=Jt(t[2]),a=xe.size(t)/n/i,l=[e[0].dims[0],e[0].dims[1],e[0].dims[2],e[0].dims[3]/n],c=[e[1].dims[0],e[1].dims[1],e[1].dims[2],e[1].dims[3]/n],p=[t[0],t[1],t[2],t[3]/n],u=[{type:12,data:a},{type:6,data:[r.strides[0],r.strides[1]]},{type:6,data:[r.pads[0],r.pads[1]]}];Dn(r,u),u.push(...nt(l,c,p));let h=(i-1)*r.strides[1]+c[1],g=_=>{let E=tt("output",e[0].dataType,p.length,n),I=pr(E.type.tensor),M=On(r,E.type.value,I),y=$e("x",e[0].dataType,l.length,n),$=$e("w",e[1].dataType,c.length,n),P=[y,$];o&&P.push($e("b",e[2].dataType,e[2].dims,n));let b=o?"value += b[output_channel];":"",w=[{name:"output_size",type:"u32"},{name:"strides",type:"i32",length:2},{name:"pads",type:"i32",length:2}];return Ln(r,w),`
+ ${_.registerUniforms(w).declareVariables(...P,E)}
+ ${_.mainStart()}
+ ${_.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let width0 = uniforms.output_shape[3];
+ let output_channel = global_idx % width0;
+ var index1 = global_idx / width0;
+ let width1 = uniforms.output_shape[2] / ${i}u;
+ let col = (index1 % width1) * ${i}u;
+ index1 = index1 / width1;
+ let row = index1 % uniforms.output_shape[1];
+ let batch = index1 / uniforms.output_shape[1];
+
+ let x_corner = vec2(i32(row), i32(col)) * uniforms.strides - uniforms.pads;
+
+ var x_vals: array<${y.type.value}, ${h}>;
+ var values: array<${E.type.value}, ${i}>;
+ let input_channel = output_channel;
+ // Use constant instead of uniform can give better performance for w's height/width.
+ for (var w_height: u32 = 0u; w_height < ${c[0]}; w_height++) {
+ let x_height = x_corner.x + i32(w_height);
+ if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) {
+ for (var i = 0; i < ${h}; i++) {
+ let x_width = x_corner.y + i;
+ if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) {
+ x_vals[i] = ${y.get("batch","u32(x_height)","u32(x_width)","input_channel")};
+ } else {
+ x_vals[i] = ${y.type.value}(0);
+ }
+ }
+ for (var w_width: u32 = 0u; w_width < ${c[1]}; w_width++) {
+ let w_val = ${$.get("w_height","w_width","0","output_channel")};
+ for (var i = 0u; i < ${i}u; i++) {
+ values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]);
+ }
+ }
+ }
+ }
+
+ for (var i = 0u; i < ${i}u; i++) {
+ var value = values[i];
+ ${b}
+ ${M}
+ ${E.set("batch","row","col + i","output_channel","value")};
+ }
+ }`};return{name:"GroupedConv-Vectorize",shaderCache:{hint:`${r.cacheKey};${n};${i};${h};${c[0]};${c[1]}`,inputDependencies:o?["rank","rank","type"]:["rank","rank"]},getRunData:()=>({outputs:[{dims:s?s(t):t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(a/64)},programUniforms:u}),getShaderSource:g}}}),V_,Gl,U_,Kl,eu,mc,W_,G_,tu,ix=Ve(()=>{bt(),sx(),nx(),Cu(),ox(),Bn(),Pu(),cn(),V_=(e,r,t,s,o,n)=>{let i=e[0],a=e.slice(n?1:2,n?3:4),l=a.length,c=r[0],p=r.slice(2).map((h,g)=>h+(h-1)*(t[g]-1)),u=a.map((h,g)=>h+s[g]+s[g+l]).map((h,g)=>Math.floor((h-p[g]+o[g])/o[g]));return u.splice(0,0,i),u.splice(n?3:1,0,c),u},Gl=[2,3,1,0],U_=(e,r)=>{if(!e||e.length!==2&&e.length!==3)throw new Error("Conv requires 2 or 3 inputs");if(e[0].dims.length>5)throw new Error("greater than 5D is not supported");if(e[0].dims.length!==e[1].dims.length)throw new Error("filter does not have same dimension as input");let t=e[0].dims[r.format==="NHWC"?e[0].dims.length-1:1],s=e[1].dims[1]*r.group;if(t!==s)throw new Error("FILTER_IN_CHANNEL should be equal to DATA_CHANNEL");if(e.length===3&&(e[2].dims.length!==1||e[1].dims[0]!==e[2].dims[0]))throw new Error("invalid bias");let o=e[0].dims.length-2;if(r.dilations.length!==o)throw new Error(`dilations should be ${o}D`);if(r.strides.length!==o)throw new Error(`strides should be ${o}D`);if(r.pads.length!==o*2)throw new Error(`pads should be ${o*2}D`);if(r.kernelShape.length!==0&&r.kernelShape.length!==e[1].dims.length-2)throw new Error("invalid kernel shape")},Kl=(e,r)=>{let t=e.kernelShape.slice();t.length{let r=xu(e),t=e.format,s=["NOTSET","VALID","SAME_UPPER","SAME_LOWER"][e.auto_pad],o=e.dilations,n=e.group,i=e.kernel_shape,a=e.pads,l=e.strides,c=e.w_is_const();return{autoPad:s,format:t,dilations:o,group:n,kernelShape:i,pads:a,strides:l,wIsConst:c,...r,cacheKey:`${e.format};${r.activation};`}},mc=(e,r,t,s)=>{let o=t.format==="NHWC",n=V_(r[0].dims,r[1].dims,t.dilations,t.pads,t.strides,o);if(t.group!==1){let w=[r[0]];if(o){let T=e.kernelCustomData.wT??e.compute(Wr(r[1],Gl),{inputs:[1],outputs:[t.wIsConst?-2:-1]})[0];t.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=T),w.push(T)}else w.push(r[1]);r.length===3&&w.push(r[2]),!e.adapterInfo.isArchitecture("ampere")&&o&&r[1].dims[0]===t.group&&r[1].dims[1]===1&&t.dilations[0]===1&&t.dilations[1]===1?e.compute(l0(w,t,n,s),{inputs:w}):e.compute(a0(w,t,n,s),{inputs:w});return}let i=r.length===3,a=r[0].dims[o?1:2],l=r[0].dims[o?2:3],c=r[0].dims[o?3:1],p=r[1].dims[2],u=r[1].dims[3],h=n[o?1:2],g=n[o?2:3],_=n[o?3:1],E=o&&p===a&&u===l&&t.pads[0]===0&&t.pads[1]===0;if(E||p===1&&u===1&&t.dilations[0]===1&&t.dilations[1]===1&&t.strides[0]===1&&t.strides[1]===1&&t.pads[0]===0&&t.pads[1]===0){let w=n[0],T,k,z,R=[];if(o){let U=e.kernelCustomData.wT??e.compute(Wr(r[1],Gl),{inputs:[1],outputs:[t.wIsConst?-2:-1]})[0];if(t.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=U),E){let Z=a*l*c;T=r[0].reshape([1,w,Z]),k=U.reshape([1,Z,_]),z=[1,w,_]}else T=r[0].reshape([w,a*l,c]),k=U.reshape([1,c,_]),z=[w,h*g,_];R.push(T),R.push(k)}else T=r[0].reshape([w,c,a*l]),k=r[1].reshape([1,_,c]),z=[w,_,h*g],R.push(k),R.push(T);i&&R.push(r[2]);let Q=z[2],q=R[0].dims[R[0].dims.length-1];Q<8&&q<8?e.compute(Eu(R,t,n,z,o,s),{inputs:R}):e.compute(ad(R,t,n,z,o,s),{inputs:R});return}let I=!0,M=e.kernelCustomData.wT??e.compute(Wr(r[1],Gl),{inputs:[1],outputs:[t.wIsConst?-2:-1]})[0];t.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=M);let y=[r[0],M];i&&y.push(r[2]);let $=o?h*g:_,P=o?_:h*g,b=p*u*c;e.compute(n0(y,t,n,$,P,b,i,I,s),{inputs:y})},W_=(e,r)=>{let t=r.format==="NHWC",s=[e.inputs[0].reshape(t?[e.inputs[0].dims[0],1,e.inputs[0].dims[1],e.inputs[0].dims[2]]:[e.inputs[0].dims[0],e.inputs[0].dims[1],1,e.inputs[0].dims[2]]),e.inputs[1].reshape([e.inputs[1].dims[0],e.inputs[1].dims[1],1,e.inputs[1].dims[2]])];e.inputs.length===3&&s.push(e.inputs[2]);let o=[0,r.pads[0],0,r.pads[1]],n=[1].concat(r.strides),i=[1].concat(r.dilations),a=[1].concat(r.kernelShape),l=Kl({...r,pads:o,strides:n,dilations:i,kernelShape:a},s);mc(e,s,l,c=>t?[c[0],c[2],c[3]]:[c[0],c[1],c[3]])},G_=(e,r,t)=>{let s=t.format==="NHWC"?"channelsLast":"channelsFirst",o=Kl(t,r),n=t.autoPad==="NOTSET"?t.pads:t.autoPad,i=o0(r[0].dims,r[1].dims,t.strides,t.dilations,n,!1,s);e.compute(i0(r,o,i.outShape,[i.filterDepth,i.filterHeight,i.filterWidth],[i.padInfo.front,i.padInfo.top,i.padInfo.left],s))},tu=(e,r)=>{if(U_(e.inputs,r),e.inputs[0].dims.length===3)W_(e,r);else if(e.inputs[0].dims.length===5)G_(e,e.inputs,r);else{let t=Kl(r,e.inputs);mc(e,e.inputs,t)}}}),d0,ax=Ve(()=>{mt(),Us(),bt(),xt(),d0=(e,r,t)=>{let s=e.length>2,o=r.outputShape,n=r.format==="NHWC",i=r.group,a=e[1].dims,l=a[2]/i,c=a[3],p=n?Jt(l):1,u=n&&c===1&&l>=4,h=u?Math.floor(l/4)*4:Math.floor(l/p)*p,g=l-h,_=n?Jt(c):1,E=n?c===1?p:_:1,I=xe.size(o)/_,M=[Math.ceil(I/64),1,1];St("verbose",()=>`[conv2d_backprop_webgpu] dispatch = ${M}`);let y=["rank","rank"],$=[r.strides[0],r.strides[1]],P=[r.kernelShape[n?1:2],r.kernelShape[n?2:3]],b=[r.dilations[0],r.dilations[1]],w=[P[0]+(r.dilations[0]<=1?0:(r.kernelShape[n?1:2]-1)*(r.dilations[0]-1)),P[1]+(r.dilations[1]<=1?0:(r.kernelShape[n?2:3]-1)*(r.dilations[1]-1))],T=[w[0]-1-Math.floor((r.pads[0]+r.pads[2])/2),w[1]-1-Math.floor((r.pads[1]+r.pads[3])/2)],k=[{type:12,data:I},{type:12,data:$},{type:12,data:P},{type:12,data:b},{type:12,data:w},{type:6,data:T},{type:12,data:h},{type:12,data:l},{type:12,data:c},...nt(e[0].dims,e[1].dims)];s&&(k.push(...nt(e[2].dims)),y.push("rank")),k.push(...nt(o));let z=R=>{let Q=[{name:"output_size",type:"u32"},{name:"strides",type:"u32",length:$.length},{name:"filter_dims",type:"u32",length:P.length},{name:"dilations",type:"u32",length:P.length},{name:"effective_filter_dims",type:"u32",length:w.length},{name:"pads",type:"i32",length:T.length},{name:"input_channels_per_group_int",type:"u32"},{name:"input_channels_per_group",type:"u32"},{name:"output_channels_per_group",type:"u32"}],q=pr(e[0].dataType),U=n?1:2,Z=n?2:3,H=n?3:1,J=$e("W",e[1].dataType,e[1].dims.length,E),oe=$e("Dy",e[0].dataType,e[0].dims.length,p),ae=[oe,J];s&&ae.push($e("bias",e[2].dataType,[o[H]].length,_));let ce=tt("result",e[0].dataType,o.length,_),he=()=>{let G="";if(u)p===4?G+=`
+ let xValue = ${oe.getByOffset("x_offset")};
+ let wValue = ${J.getByOffset("w_offset")};
+ dotProd = dotProd + dot(xValue, wValue);
+ x_offset += 1u;
+ w_offset += 1u;`:p===2?G+=`
+ dotProd = dotProd + dot(vec4<${q}>(${oe.getByOffset("x_offset")}, ${oe.getByOffset("x_offset + 1u")}), vec4<${q}>(${J.getByOffset("w_offset")}, ${J.getByOffset("w_offset + 1u")}));
+ x_offset += 2u;
+ w_offset += 2u;`:p===1&&(G+=`
+ dotProd = dotProd + dot(vec4<${q}>(${oe.getByOffset("x_offset")}, ${oe.getByOffset("x_offset + 1u")}, ${oe.getByOffset("x_offset + 2u")}, ${oe.getByOffset("x_offset + 3u")}), vec4<${q}>(${J.getByOffset("w_offset")}, ${J.getByOffset("w_offset + 1u")}, ${J.getByOffset("w_offset + 2u")}, ${J.getByOffset("w_offset + 3u")}));
+ x_offset += 4u;
+ w_offset += 4u;`);else if(G+=`
+ let xValue = ${n?oe.getByOffset(`${oe.indicesToOffset(`${oe.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${p}`):oe.get("batch","inputChannel","idyR","idyC")};
+ `,p===1)G+=`
+ let w_offset = ${J.indicesToOffset(`${J.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
+ let wValue = ${J.getByOffset(`w_offset / ${E}`)};
+ dotProd = dotProd + xValue * wValue;`;else for(let se=0;se{if(g===0)return"";if(!u)throw new Error(`packInputAs4 ${u} is not true.`);let G="";if(p===1){G+="dotProd = dotProd";for(let se=0;se(i32(r), i32(c)) - uniforms.pads;
+ let dyRCorner = dyCorner.x;
+ let dyCCorner = dyCorner.y;
+ let groupId = d1 / uniforms.output_channels_per_group;
+ let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
+ // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
+ // ? = to be determined. : = across all values in that axis.
+ var dotProd = ${ce.type.value}(0.0);
+ var wR: u32 = 0;
+ if (uniforms.dilations.x == 1) {
+ // Minimum wR >= 0 that satisfies (dyRCorner + wR) % (uniforms.strides.x) == 0
+ wR = u32(((dyRCorner + i32(uniforms.strides.x) - 1) / i32(uniforms.strides.x)) * i32(uniforms.strides.x) - dyRCorner);
+ }
+ for (; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
+ if (wR % uniforms.dilations.x != 0) {
+ continue;
+ }
+ let dyR = (${q}(dyRCorner) + ${q}(wR)) / ${q}(uniforms.strides[0]);
+ let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
+ if (dyR < 0.0 || dyR >= ${q}(uniforms.Dy_shape[${U}]) || fract(dyR) > 0.0 ||
+ wRPerm < 0) {
+ continue;
+ }
+ let idyR: u32 = u32(dyR);
+ var wC: u32 = 0;
+ if (uniforms.dilations.y == 1) {
+ // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0
+ wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);
+ }
+ for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
+ if (wC % uniforms.dilations.y != 0) {
+ continue;
+ }
+ let dyC = (${q}(dyCCorner) + ${q}(wC)) / ${q}(uniforms.strides.y);
+ let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
+ if (dyC < 0.0 || dyC >= ${q}(uniforms.Dy_shape[${Z}]) ||
+ fract(dyC) > 0.0 || wCPerm < 0) {
+ continue;
+ }
+ let idyC: u32 = u32(dyC);
+ var inputChannel = groupId * uniforms.input_channels_per_group;
+ ${u?`
+ var x_offset = ${oe.indicesToOffset(`${oe.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${p};
+ var w_offset = ${J.indicesToOffset(`${J.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${E};
+ `:""}
+ for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${u?4:p}) {
+ ${he()}
+ inputChannel = inputChannel + ${u?4:p};
+ }
+ ${N()}
+ wC = wC + uniforms.strides.y - 1;
+ }
+ wR = wR + uniforms.strides[0] - 1;
+ }
+ let value = dotProd${s?` + bias[d1 / ${_}]`:""};
+ ${ce.setByOffset("global_idx","value")};
+ `;return`
+ ${R.registerUniforms(Q).declareVariables(...ae,ce)}
+ ${R.mainStart()}
+ ${R.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")};
+ ${O}}`};return{name:"ConvTranspose2D",shaderCache:{hint:`${r.cacheKey};${p}${E}${_}${u}${g}`,inputDependencies:y},getRunData:()=>({dispatchGroup:{x:M[0],y:M[1],z:M[2]},outputs:[{dims:t?t(o):o,dataType:e[0].dataType}],programUniforms:k}),getShaderSource:z}}}),K_,H_,q_,fc,c0,Q_,_c,X_,u0,lx=Ve(()=>{ax(),Bn(),cn(),K_=(e,r,t,s,o,n)=>(e-1)*r+t+(s-1)*o+1-n,H_=(e,r,t,s,o)=>{let n=Math.floor(e/2);r==="SAME_UPPER"?(t[s]=n,t[o]=e-n):r==="SAME_LOWER"&&(t[s]=e-n,t[o]=n)},q_=(e,r,t,s,o,n,i,a,l,c)=>{let p=e.length-2,u=c.length===0;l.length{let t=e.kernelShape.slice();if(e.kernelShape.length===0||e.kernelShape.reduce((u,h)=>u*h,1)===0){t.length=0;for(let u=2;uu+h,0)===0){let u=r[0].dims.length-2;l=new Array(u).fill(1)}let c=e.strides.slice();if(c.reduce((u,h)=>u+h,0)===0){let u=r[0].dims.length-2;c=new Array(u).fill(1)}q_(a,t,l,e.autoPad,e.group,o,c,s,i,n);let p=Object.assign({},e);return Object.assign(p,{kernelShape:t,pads:o,outputPadding:i,outputShape:n,dilations:l,strides:c}),p},c0=e=>{let r=xu(e),t=e.format,s=["NOTSET","VALID","SAME_UPPER","SAME_LOWER"][typeof e.autoPad>"u"?0:e.autoPad],o=e.dilations,n=e.group,i=e.kernelShape,a=e.pads,l=e.strides,c=e.wIsConst(),p=e.outputPadding,u=e.outputShape;return{autoPad:s,format:t,dilations:o,group:n,kernelShape:i,outputPadding:p,outputShape:u,pads:a,strides:l,wIsConst:c,...r,cacheKey:`${e.format};${r.activation};`}},Q_=(e,r)=>{if(!e||e.length!==2&&e.length!==3)throw new Error("Conv requires 2 or 3 inputs");if(e[0].dims.length!==4&&e[0].dims.length!==3)throw new Error("currently only support 2-dimensional conv");if(e[0].dims.length!==e[1].dims.length)throw new Error("filter does not have same dimension as input");let t=e[0].dims[r.format==="NHWC"?e[0].dims.length-1:1],s=e[1].dims[0];if(t!==s)throw new Error("FILTER_IN_CHANNEL should be equal to DATA_CHANNEL");let o=e[1].dims[1]*r.group;if(e.length===3&&(e[2].dims.length!==1||e[2].dims[0]!==o))throw new Error("invalid bias");let n=e[0].dims.length-2;if(r.dilations.reduce((i,a)=>i+a,0)>0&&r.dilations.length!==n)throw new Error(`dilations should be ${n}D`);if(r.strides.reduce((i,a)=>i+a,0)>0&&r.strides.length!==n)throw new Error(`strides should be ${n}D`);if(r.pads.reduce((i,a)=>i+a,0)>0&&r.pads.length!==n*2)throw new Error(`pads should be ${n*2}D`);if(r.outputPadding.length!==n&&r.outputPadding.length!==0)throw new Error(`output_padding should be ${n}D`);if(r.kernelShape.reduce((i,a)=>i+a,0)>0&&r.kernelShape.length!==0&&r.kernelShape.length!==e[1].dims.length-2)throw new Error("invalid kernel shape");if(r.outputShape.length!==0&&r.outputShape.length!==e[0].dims.length-2)throw new Error("invalid output shape")},_c=(e,r,t,s)=>{let o=e.kernelCustomData.wT??e.compute(Wr(r[1],[2,3,0,1]),{inputs:[1],outputs:[t.wIsConst?-2:-1]})[0];t.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=o);let n=[r[0],o];r.length===3&&n.push(r[2]),e.compute(d0(n,t,s),{inputs:n})},X_=(e,r)=>{let t=r.format==="NHWC",s=[e.inputs[0].reshape(t?[e.inputs[0].dims[0],1,e.inputs[0].dims[1],e.inputs[0].dims[2]]:[e.inputs[0].dims[0],e.inputs[0].dims[1],1,e.inputs[0].dims[2]]),e.inputs[1].reshape([e.inputs[1].dims[0],e.inputs[1].dims[1],1,e.inputs[1].dims[2]])];e.inputs.length===3&&s.push(e.inputs[2]);let o=r.kernelShape;(o.length===0||o[0]===0)&&(o=[e.inputs[1].dims[2]]);let n=r.dilations;(n.length===0||n[0]===0)&&(n=[1]);let i=r.strides;(i.length===0||i[0]===0)&&(i=[1]);let a=r.pads;a.length===0&&(a=[0,0]),a=[0,a[0],0,a[1]],i=[1].concat(i),n=[1].concat(n),o=[1].concat(o);let l=r.outputPadding;l=[0].concat(l);let c=fc({...r,pads:a,strides:i,dilations:n,kernelShape:o,outputPadding:l},s);_c(e,s,c,p=>t?[p[0],p[2],p[3]]:[p[0],p[1],p[3]])},u0=(e,r)=>{if(Q_(e.inputs,r),e.inputs[0].dims.length===3)X_(e,r);else{let t=fc(r,e.inputs);_c(e,e.inputs,t)}}}),J_,p0,h0,dx=Ve(()=>{mt(),bt(),tr(),xt(),J_=(e,r,t,s)=>{let o=xe.size(r),n=r.length,i=$e("input",e,n),a=tt("output",e,n),l=t.dataType===6?t.getInt32Array()[0]:Number(t.getBigInt64Array()[0]),c=xe.normalizeAxis(l,n),p=u=>{let h=` i32(${i.indicesGet("inputIndices","uniforms.axis")}) `,g=rt("uniforms.input_shape","uniforms.axis",n),_=s.reverse?h+(s.exclusive?" + 1":""):"0",E=s.reverse?g:h+(s.exclusive?"":" + 1");return`
+ ${u.registerUniform("outputSize","u32").registerUniform("axis","u32").declareVariables(i,a)}
+ ${u.mainStart()}
+ ${u.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ var inputIndices = ${a.offsetToIndices("global_idx")};
+ var sum = ${a.type.value}(0);
+ let first : i32 = ${_};
+ let last : i32 = ${E};
+ for (var i : i32 = first; i < last; i++) {
+ ${i.indicesSet("inputIndices","uniforms.axis","u32(i)")};
+ sum = sum + ${i.getByIndices("inputIndices")};
+ }
+ ${a.setByOffset("global_idx","sum")};
+ }`};return{name:"CumSum",shaderCache:{hint:s.cacheKey,inputDependencies:["rank"]},getRunData:()=>({outputs:[{dims:r,dataType:e}],dispatchGroup:{x:Math.ceil(o/64)},programUniforms:[{type:12,data:o},{type:12,data:c},...nt(r,r)]}),getShaderSource:p}},p0=(e,r)=>{let t=e.inputs[0].dims,s=e.inputs[0].dataType,o=e.inputs[1];e.compute(J_(s,t,o,r),{inputs:[0]})},h0=e=>{let r=e.exclusive===1,t=e.reverse===1;return Lt({exclusive:r,reverse:t})}}),Y_,Z_,eg,m0,f0,cx=Ve(()=>{mt(),bt(),tr(),xt(),Y_=e=>{if(!e||e.length!==1)throw new Error("DepthToSpace requires 1 input.");if(e[0].dims.length!==4)throw new Error("DepthToSpace requires 4D input.")},Z_=(e,r,t,s)=>{let o=[];o.push(`fn perm(i: ${s.type.indices}) -> ${t.type.indices} {
+ var a: ${t.type.indices};`);for(let n=0;n{let t,s,o,n,i,a,l=r.format==="NHWC",c=r.blocksize,p=r.mode==="DCR";l?([t,s,o,n]=e.dims,i=p?[t,s,o,c,c,n/c**2]:[t,s,o,n/c**2,c,c],a=p?[0,1,3,2,4,5]:[0,1,4,2,5,3]):([t,s,o,n]=[e.dims[0],e.dims[2],e.dims[3],e.dims[1]],i=p?[t,c,c,n/c**2,s,o]:[t,n/c**2,c,c,s,o],a=p?[0,3,4,1,5,2]:[0,1,4,2,5,3]);let u=e.reshape(i),h=u.dims.length,g=e.dataType,_=$e("a",g,h),E=tt("output",g,h),I=M=>`
+ ${M.registerUniform("output_size","u32").declareVariables(_,E)}
+
+ ${Z_(a,h,_,E)}
+
+ ${M.mainStart()}
+ ${M.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+
+ let indices = ${E.offsetToIndices("global_idx")};
+ let aIndices = perm(indices);
+
+ ${E.setByOffset("global_idx",_.getByIndices("aIndices"))}
+ }`;return{name:"DepthToSpace",shaderCache:{hint:`${e.dims};${r.blocksize};${r.mode}`,inputDependencies:["rank"]},getRunData:M=>{let y=l?[t,s*c,o*c,n/c**2]:[t,n/c**2,s*c,o*c],$=xe.size(y),P=u.dims,b=xe.sortBasedOnPerm(P,a);return{outputs:[{dims:y,dataType:M[0].dataType}],dispatchGroup:{x:Math.ceil($/64)},programUniforms:[{type:12,data:$},...nt(P,b)]}},getShaderSource:I}},m0=(e,r)=>{Y_(e.inputs),e.compute(eg(e.inputs[0],r))},f0=e=>Lt({blocksize:e.blocksize,mode:e.mode,format:e.format})}),Hl,ra,gc,tg,rg,sg,ng,wc,og,_0,g0,ux=Ve(()=>{mt(),bt(),tr(),xt(),Hl="[a-zA-Z]|\\.\\.\\.",ra="("+Hl+")+",gc="^"+ra+"$",tg="("+ra+",)*"+ra,rg="^"+tg+"$",sg=class{constructor(e=-1){this.symbolToIndices=new Map,this.inputIndex=e}addSymbol(e,r){let t=this.symbolToIndices.get(e);t===void 0?t=[r]:t.push(r),this.symbolToIndices.set(e,t)}},ng=class{constructor(e,r){var o;this.equation=r,this.hasEllipsis=!1,this.symbolToInfo=new Map,this.lhs=new Array,this.outputDims=[];let[t,s]=r.includes("->")?r.split("->",2):[r,""];if(!t.match(RegExp(rg)))throw new Error("Invalid LHS term");if(t.split(",").forEach((n,i)=>{let a=e[i].dims.slice();if(!n.match(RegExp(gc)))throw new Error("Invalid LHS term");let l=this.processTerm(n,!0,a,i);this.lhs.push(l)}),s==="")s+=[...this.symbolToInfo.entries()].filter(([n,i])=>i.count===1||n==="...").map(([n])=>n).join("");else if(!s.match(RegExp(ra)))throw new Error("Invalid RHS");(o=s.match(RegExp(Hl,"g")))==null||o.forEach(n=>{if(n==="...")this.outputDims=this.outputDims.concat(this.ellipsisDims);else{let i=this.symbolToInfo.get(n);if(i===void 0)throw new Error("Invalid RHS symbol");this.outputDims.push(i.dimValue)}}),this.rhs=this.processTerm(s,!1,this.outputDims)}addSymbol(e,r,t){let s=this.symbolToInfo.get(e);if(s!==void 0){if(s.dimValue!==r&&s.count!==1)throw new Error("Dimension mismatch");s.count++,s.inputIndices.push(t)}else s={count:1,dimValue:r,inputIndices:[t]};this.symbolToInfo.set(e,s)}processTerm(e,r,t,s=-1){let o=t.length,n=!1,i=[],a=0;if(!e.match(RegExp(gc))&&!r&&e!=="")throw new Error("Invalid LHS term");let l=e.match(RegExp(Hl,"g")),c=new sg(s);return l==null||l.forEach((p,u)=>{if(p==="..."){if(n)throw new Error("Only one ellipsis is allowed per input term");n=!0;let h=o-l.length+1;if(h<0)throw new Error("Ellipsis out of bounds");if(i=t.slice(a,a+h),this.hasEllipsis){if(this.ellipsisDims.length!==i.length||this.ellipsisDims.toString()!==i.toString())throw new Error("Ellipsis dimensions mismatch")}else if(r)this.hasEllipsis=!0,this.ellipsisDims=i;else throw new Error("Ellipsis must be specified in the LHS");for(let g=0;ge+"_max",og=(e,r,t,s)=>{let o=e.map(c=>c.length).map((c,p)=>$e(`input${p}`,r,c)),n=xe.size(s),i=tt("output",r,s.length),a=[...t.symbolToInfo.keys()].filter(c=>!t.rhs.symbolToIndices.has(c)),l=c=>{let p=[],u="var prod = 1.0;",h="var sum = 0.0;",g="sum += prod;",_=[],E=[],I=[],M=[],y=t.symbolToInfo.size===t.rhs.symbolToIndices.size;t.symbolToInfo.forEach((P,b)=>{var w;if(t.rhs.symbolToIndices.has(b)){let T=(w=t.rhs.symbolToIndices.get(b))==null?void 0:w[0];T!==void 0&&t.lhs.forEach((k,z)=>{if(P.inputIndices.includes(z)){let R=k.symbolToIndices.get(b);if(R===void 0)throw new Error("Invalid symbol error");R.forEach(Q=>{p.push(`${o[z].indicesSet(`input${z}Indices`,Q,i.indicesGet("outputIndices",T))}`)})}})}else t.lhs.forEach((T,k)=>{if(P.inputIndices.includes(k)){let z=T.symbolToIndices.get(b);if(z===void 0)throw new Error("Invalid symbol error");z.forEach(R=>{_.push(`${o[k].indicesSet(`input${k}Indices`,R,`${b}`)}`)}),M.push(`prod *= ${o[k].getByIndices(`input${k}Indices`)};`)}}),E.push(`for(var ${b}: u32 = 0; ${b} < uniforms.${wc(b)}; ${b}++) {`),I.push("}")});let $=y?[...p,`let sum = ${o.map((P,b)=>P.getByIndices(`input${b}Indices`)).join(" * ")};`]:[...p,h,...E,..._,u,...M,g,...I];return`
+ ${c.registerUniforms(a.map(P=>({name:`${wc(P)}`,type:"u32"}))).registerUniform("outputSize","u32").declareVariables(...o,i)}
+
+ ${c.mainStart()}
+ ${c.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ var outputIndices = ${i.offsetToIndices("global_idx")};
+ ${o.map((P,b)=>`var input${b}Indices: ${o[b].type.indices};`).join(`
+`)}
+ ${$.join(`
+`)};
+ ${i.setByOffset("global_idx","sum")};
+ }`};return{name:"Einsum",shaderCache:{hint:t.equation,inputDependencies:e.map(()=>"rank")},getRunData:()=>{let c=a.filter(u=>t.symbolToInfo.has(u)).map(u=>{var h;return{type:12,data:((h=t.symbolToInfo.get(u))==null?void 0:h.dimValue)||0}});c.push({type:12,data:n});let p=e.map((u,h)=>[...nt(u)]).reduce((u,h)=>u.concat(h),c);return p.push(...nt(s)),{outputs:[{dims:s,dataType:r}],dispatchGroup:{x:Math.ceil(n/64)},programUniforms:p}},getShaderSource:l}},_0=(e,r)=>{let t=new ng(e.inputs,r.equation),s=t.outputDims,o=e.inputs.map((n,i)=>n.dims);e.compute(og(o,e.inputs[0].dataType,t,s))},g0=e=>{let r=e.equation.replace(/\s+/g,"");return Lt({equation:r})}}),ig,yc,ag,lg,w0,px=Ve(()=>{mt(),bt(),xt(),ig=e=>{if(!e||e.length!==2)throw new Error("Expand requires 2 input.");let r=e[0].dims,t=Array.from(e[1].getBigInt64Array(),Number),s=t.length{let t=e.length-r.length,s=[];for(let o=0;oe.length>r.length?yc(e,r):yc(r,e),lg=e=>{let r=e[0].dims,t=Array.from(e[1].getBigInt64Array(),Number),s=ag(r,t),o=e[0].dataType,n=o===9||xe.size(r)===1,i=o===9||r.length>0&&r[r.length-1]%4===0?4:1,a=n||s.length>0&&s[s.length-1]%4===0?4:1,l=Math.ceil(xe.size(s)/a),c=u=>{let h=$e("input",o,r.length,i),g=tt("output",o,s.length,a),_;if(o===9){let E=(I,M,y="")=>`
+ let outputIndices${M} = ${g.offsetToIndices(`outputOffset + ${M}u`)};
+ let offset${M} = ${h.broadcastedIndicesToOffset(`outputIndices${M}`,g)};
+ let index${M} = offset${M} / 4u;
+ let component${M} = offset${M} % 4u;
+ ${I}[${M}] = ${y}(${h.getByOffset(`index${M}`)}[component${M}]);
+ `;_=`
+ let outputOffset = global_idx * ${a};
+ var data = vec4(0);
+ ${E("data",0,"u32")}
+ ${E("data",1,"u32")}
+ ${E("data",2,"u32")}
+ ${E("data",3,"u32")}
+ ${g.setByOffset("global_idx","data")}
+ }`}else _=`
+ let outputIndices = ${g.offsetToIndices(`global_idx * ${a}`)};
+ let inputOffset = ${h.broadcastedIndicesToOffset("outputIndices",g)};
+ let data = ${g.type.value}(${h.getByOffset(`inputOffset / ${i}`)});
+ ${g.setByOffset("global_idx","data")}
+ }`;return`
+ ${u.registerUniform("vec_size","u32").declareVariables(h,g)}
+ ${u.mainStart()}
+ ${u.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")}
+ ${_}`},p=[{type:12,data:l},...nt(r,s)];return{name:"Expand",shaderCache:{hint:`${s.length};${i}${a}`,inputDependencies:["rank"]},getShaderSource:c,getRunData:()=>({outputs:[{dims:s,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:p})}},w0=e=>{ig(e.inputs),e.compute(lg(e.inputs),{inputs:[0]})}}),dg,y0,hx=Ve(()=>{mt(),bt(),xt(),vu(),dg=e=>{let r=e[0].dataType,t=xe.size(e[0].dims),s=xe.size(e[1].dims),o=s%4===0,n=i=>{let a=$e("x",r,[1],4),l=$e("bias",r,[1],4),c=tt("y",r,[1],4),p=[{name:"output_vec_size",type:"u32"},{name:"bias_size",type:"u32"}],u=g=>`
+ let bias${g}_offset: u32 = (global_idx * 4 + ${g}) % uniforms.bias_size;
+ let bias${g} = ${l.getByOffset(`bias${g}_offset / 4`)}[bias${g}_offset % 4];`,h=o?`
+ let bias = ${l.getByOffset("global_idx % (uniforms.bias_size / 4)")};`:`${u(0)}${u(1)}${u(2)}${u(3)}
+ let bias = ${a.type.value}(bias0, bias1, bias2, bias3);`;return`${i.registerUniforms(p).declareVariables(a,l,c)}
+
+ ${Xc(Cr(r))}
+
+ ${i.mainStart($o)}
+ ${i.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_vec_size")}
+
+ let x = ${a.getByOffset("global_idx")};
+ ${h}
+ let x_in = x + bias;
+ ${c.setByOffset("global_idx",Jc("x_in"))}
+ }`};return{name:"FastGeluWithBias",shaderCache:{hint:`${o}`,inputDependencies:["type","type"]},getShaderSource:n,getRunData:i=>({outputs:[{dims:i[0].dims,dataType:i[0].dataType}],programUniforms:[{type:12,data:Math.ceil(t/4)},{type:12,data:s}],dispatchGroup:{x:Math.ceil(t/$o/4)}})}},y0=e=>{e.inputs.length<2||xe.size(e.inputs[1].dims)===0?RM(e):e.compute(dg(e.inputs))}}),cg,ug,M0,b0,mx=Ve(()=>{mt(),bt(),tr(),xt(),cg=e=>{if(!e||e.length!==2)throw new Error("Gather requires 2 inputs.")},ug=(e,r)=>{let t=e[0].dims,s=e[1].dims,o=t.length,n=xe.normalizeAxis(r.axis,o),i=t.slice(0);i.splice(n,1,...s);let a=t[n],l=e[0].dataType===9?4:1,c=Math.ceil(xe.size(i)/l),p=[{type:12,data:c},{type:6,data:a},{type:12,data:n},...nt(e[0].dims,e[1].dims,i)],u=h=>{let g=$e("data",e[0].dataType,e[0].dims.length,l),_=$e("inputIndices",e[1].dataType,e[1].dims.length),E=tt("output",e[0].dataType,i.length,l),I=y=>{let $=s.length,P=`var indicesIndices${y} = ${_.type.indices}(0);`;for(let b=0;b<$;b++)P+=`${$>1?`indicesIndices${y}[${b}]`:`indicesIndices${y}`} = ${i.length>1?`outputIndices${y}[uniforms.axis + ${b}]`:`outputIndices${y}`};`;P+=`
+ var idx${y} = ${_.getByIndices(`indicesIndices${y}`)};
+ if (idx${y} < 0) {
+ idx${y} = idx${y} + uniforms.axisDimLimit;
+ }
+ var dataIndices${y} : ${g.type.indices};
+ `;for(let b=0,w=0;b1?`dataIndices${y}[${b}]`:`dataIndices${y}`} = u32(idx${y});`,w+=$):(P+=`${o>1?`dataIndices${y}[${b}]`:`dataIndices${y}`} = ${i.length>1?`outputIndices${y}[${w}]`:`outputIndices${y}`};`,w++);return P},M;if(e[0].dataType===9){let y=($,P,b="")=>`
+ let outputIndices${P} = ${E.offsetToIndices(`outputOffset + ${P}u`)};
+ ${I(P)};
+ let offset${P} = ${g.indicesToOffset(`dataIndices${P}`)};
+ let index${P} = offset${P} / 4u;
+ let component${P} = offset${P} % 4u;
+ ${$}[${P}] = ${b}(${g.getByOffset(`index${P}`)}[component${P}]);
+ `;M=`
+ let outputOffset = global_idx * ${l};
+ var value = vec4(0);
+ ${y("value",0,"u32")}
+ ${y("value",1,"u32")}
+ ${y("value",2,"u32")}
+ ${y("value",3,"u32")}
+ ${E.setByOffset("global_idx","value")}
+ `}else M=`
+ let outputIndices = ${E.offsetToIndices("global_idx")};
+ ${I("")};
+ let value = ${g.getByIndices("dataIndices")};
+ ${E.setByOffset("global_idx","value")};
+ `;return`
+ ${h.registerUniform("outputSize","u32").registerUniform("axisDimLimit","i32").registerUniform("axis","u32").declareVariables(g,_,E)}
+ ${h.mainStart()}
+ ${h.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ ${M}
+ }`};return{name:"Gather",shaderCache:{hint:r.cacheKey,inputDependencies:["rank","rank"]},getRunData:()=>({outputs:[{dims:i,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(c/64)},programUniforms:p}),getShaderSource:u}},M0=e=>Lt({axis:e.axis}),b0=(e,r)=>{let t=e.inputs;cg(t),e.compute(ug(e.inputs,r))}}),pg,v0,x0,fx=Ve(()=>{mt(),bt(),xt(),pg=(e,r,t,s,o,n,i,a,l)=>{let c=[{type:12,data:n},{type:12,data:s},{type:12,data:o},{type:12,data:t},{type:12,data:i},{type:12,data:a},{type:12,data:l}],p=[n];c.push(...nt(r.dims,p));let u=h=>{let g=$e("indices_data",r.dataType,r.dims.length),_=tt("input_slice_offsets_data",12,1,1),E=[g,_],I=[{name:"output_size",type:"u32"},{name:"batch_dims",type:"u32"},{name:"input_dims",type:"u32",length:o.length},{name:"sizes_from_slice_dims_data",type:"u32",length:t.length},{name:"num_slices_per_batch",type:"u32"},{name:"input_batch_stride",type:"u32"},{name:"num_slice_dims",type:"u32"}];return`
+ ${h.registerUniforms(I).declareVariables(...E)}
+ ${h.mainStart()}
+ ${h.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let batch_idx = global_idx / uniforms.num_slices_per_batch;
+ let base_offset = batch_idx * uniforms.input_batch_stride;
+
+ let slice_indices_base_offset = global_idx * uniforms.num_slice_dims;
+ var relative_slice_offset = 0;
+ for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) {
+ var index = i32(indices_data[dim_idx + slice_indices_base_offset].x);
+ let input_dim_idx = uniforms.batch_dims + dim_idx;
+ if (index < 0) {
+ ${o.length===1?"index += i32(uniforms.input_dims);":"index += i32(uniforms.input_dims[input_dim_idx]);"}
+ }
+ ${t.length===1?"relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);":"relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);"}
+ }
+
+ input_slice_offsets_data[global_idx] = base_offset + u32(relative_slice_offset);
+ }`};return e.compute({name:"computeSliceOffsets",shaderCache:{hint:`${o.length}_${t.length}`,inputDependencies:["rank"]},getRunData:()=>({outputs:[{dims:p,dataType:e.inputs[1].dataType}],dispatchGroup:{x:Math.ceil(n/64)},programUniforms:c}),getShaderSource:u},{inputs:[r],outputs:[-1]})[0]},v0=(e,r)=>{let t=e.inputs,s=t[0].dims,o=t[0].dataType,n=t[1].dims,i=n[n.length-1],a=xe.sizeToDimension(n,n.length-1),l=xe.sizeFromDimension(s,r.batchDims+i),c=xe.sizeToDimension(s,r.batchDims),p=xe.sizeFromDimension(s,r.batchDims),u=a/c,h=new Array(i),g=l;for(let P=0;Ps.length)throw new Error("last dimension of indices must not be larger than rank of input tensor");let I=n.slice(0,-1).concat(s.slice(E)),M=xe.size(I),y=[{type:12,data:M},{type:12,data:l},...nt(t[0].dims,_.dims,I)],$=P=>{let b=$e("data",t[0].dataType,t[0].dims.length),w=$e("slice_offsets",12,_.dims.length),T=tt("output",t[0].dataType,I.length);return`
+ ${P.registerUniform("output_size","u32").registerUniform("slice_size","u32").declareVariables(b,w,T)}
+ ${P.mainStart()}
+ ${P.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let slice_offset = slice_offsets[global_idx / uniforms.slice_size];
+ output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size];
+ }`};e.compute({name:"GatherND",shaderCache:{hint:r.cacheKey,inputDependencies:["rank","rank"]},getRunData:()=>({outputs:[{dims:I,dataType:o}],dispatchGroup:{x:Math.ceil(M/64)},programUniforms:y}),getShaderSource:$},{inputs:[t[0],_]})},x0=e=>({batchDims:e.batch_dims,cacheKey:""})}),hg,mg,T0,E0,_x=Ve(()=>{mt(),bt(),tr(),xt(),hg=(e,r)=>{if(e.length<3||e.length>4)throw new Error("GatherBlockQuantized requires 3 or 4 inputs.");let t=xe.normalizeAxis(r.quantizeAxis,e[0].dims.length),s=r.blockSize,o=e[0],n=e[2],i=e.length===4?e[3]:void 0;if(n.dims.length!==o.dims.length||!o.dims.map((a,l)=>l===t?Math.ceil(a/s)===n.dims[l]:a===n.dims[l]).reduce((a,l)=>a&&l,!0))throw new Error("Scales must have the same rank as the input tensor and the dims should match except on gatherAxis.");if(i){if(i.dataType!==o.dataType)throw new Error("Zero point must have the same data type as the input tensor.");if(i.dims.length!==n.dims.length||!i.dims.map((a,l)=>a===n.dims[l]).reduce((a,l)=>a&&l,!0))throw new Error("Zero point must have the same rank as the input tensor and the dims should match except on quantizeAxis.")}},mg=(e,r)=>{let t=e[0].dims,s=e[1].dims,o=t.length,n=xe.normalizeAxis(r.gatherAxis,o),i=xe.normalizeAxis(r.quantizeAxis,o),a=t.slice(0);a.splice(n,1,...s);let l=xe.size(a),c=e[2].dataType,p=e[0].dataType===22,u=[{type:12,data:l},{type:12,data:i},{type:12,data:n},{type:12,data:r.blockSize},...nt(...e.map((g,_)=>g.dims),a)],h=g=>{let _=$e("data",e[0].dataType,e[0].dims.length),E=$e("inputIndices",e[1].dataType,e[1].dims.length),I=$e("scales",e[2].dataType,e[2].dims.length),M=e.length>3?$e("zeroPoint",e[3].dataType,e[3].dims.length):void 0,y=tt("output",c,a.length),$=[_,E,I];M&&$.push(M);let P=[{name:"output_size",type:"u32"},{name:"quantize_axis",type:"u32"},{name:"gather_axis",type:"u32"},{name:"block_size",type:"u32"}];return`
+ ${g.registerUniforms(P).declareVariables(...$,y)}
+ ${g.mainStart()}
+ let output_indices = ${y.offsetToIndices("global_idx")};
+ var indices_indices = ${E.type.indices}(0);
+ ${s.length>1?`
+ for (var i: u32 = 0; i < ${s.length}; i++) {
+ let index = ${y.indicesGet("output_indices","uniforms.gather_axis + i")};
+ ${E.indicesSet("indices_indices","i","index")};
+ }`:`indices_indices = ${y.indicesGet("output_indices","uniforms.gather_axis")};`};
+ var data_indices = ${_.type.indices}(0);
+ for (var i: u32 = 0; i < uniforms.gather_axis; i++) {
+ let index = ${y.indicesGet("output_indices","i")};
+ ${_.indicesSet("data_indices","i","index")};
+ }
+ var index_from_indices = ${E.getByIndices("indices_indices")};
+ if (index_from_indices < 0) {
+ index_from_indices += ${t[n]};
+ }
+ ${_.indicesSet("data_indices","uniforms.gather_axis","u32(index_from_indices)")};
+ for (var i = uniforms.gather_axis + 1; i < ${a.length}; i++) {
+ let index = ${y.indicesGet("output_indices",`i + ${s.length} - 1`)};
+ ${_.indicesSet("data_indices","i","index")};
+ }
+ let data_offset = ${_.indicesToOffset("data_indices")};
+ let data_index = data_offset % 8;
+ // Convert 4-bit packed data to 8-bit packed data.
+ let packed_4bit_quantized_data = ${_.getByOffset("data_offset / 8")};
+ let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;
+ let quantized_data_vec = ${p?"unpack4xI8":"unpack4xU8"}(u32(packed_8bit_quantized_data));
+ let quantized_data = quantized_data_vec[data_index / 2];
+ var scale_indices = data_indices;
+ let quantize_axis_index = ${I.indicesGet("data_indices","uniforms.quantize_axis")} / uniforms.block_size;
+ ${I.indicesSet("scale_indices","uniforms.quantize_axis","quantize_axis_index")};
+ var scale = ${I.getByIndices("scale_indices")};
+ ${M?`
+ let zero_point_indices = scale_indices;
+ let zero_point_offset = ${M.indicesToOffset("zero_point_indices")};
+ let zero_point_index = zero_point_offset % 8;
+ let packed_4bit_zero_points = ${M.getByOffset("zero_point_offset / 8")};
+ let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;
+ let zero_point_vec = ${p?"unpack4xI8":"unpack4xU8"}(u32(packed_8bit_zero_points));
+ let zero_point = zero_point_vec[zero_point_index / 2];`:"var zero_point = 0"};
+ let dequantized_data = ${Cr(c)}(quantized_data - zero_point) * scale;
+ ${y.setByOffset("global_idx","dequantized_data")};
+ }`};return{name:"GatherBlockQuantized",shaderCache:{hint:`${r.cacheKey};${e.filter((g,_)=>_!==1).map(g=>g.dims.join("_")).join(";")}`,inputDependencies:Array.from({length:e.length},(g,_)=>"rank")},getRunData:()=>({outputs:[{dims:a,dataType:c}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:u}),getShaderSource:h}},T0=(e,r)=>{let t=e.inputs;hg(t,r),e.compute(mg(e.inputs,r))},E0=e=>Lt({blockSize:e.blockSize,gatherAxis:e.gatherAxis,quantizeAxis:e.quantizeAxis})}),fg,_g,P0,C0,gx=Ve(()=>{mt(),bt(),tr(),xt(),fg=e=>{if(!e||e.length!==2)throw new Error("GatherElements requires 2 inputs.");if(e[0].dims.length<1)throw new Error("GatherElements requires that the data input be rank >= 1.");if(e[0].dims.length!==e[1].dims.length)throw new Error(`GatherElements requires that the data input and
+ indices input tensors be of same rank.`)},_g=(e,r)=>{let t=e[0].dims,s=e[0].dataType,o=t.length,n=e[1].dims,i=e[1].dataType,a=xe.normalizeAxis(r.axis,o),l=t[a],c=n.slice(0),p=xe.size(c),u=$e("input",s,o),h=$e("indicesInput",i,n.length),g=tt("output",s,c.length),_=[{type:12,data:p},{type:6,data:l},{type:12,data:a}];return _.push(...nt(t,n,c)),{name:"GatherElements",shaderCache:{inputDependencies:["rank","rank"]},getRunData:()=>({outputs:[{dims:c,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(p/64)},programUniforms:_}),getShaderSource:E=>`
+ ${E.registerUniform("outputSize","u32").registerUniform("axisDimLimit","i32").registerUniform("axis","u32").declareVariables(u,h,g)}
+ ${E.mainStart()}
+ ${E.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+
+ let outputIndices = ${g.offsetToIndices("global_idx")};
+
+ var idx = ${h.getByOffset("global_idx")};
+ if (idx < 0) {
+ idx = idx + uniforms.axisDimLimit;
+ }
+ var inputIndices = ${u.type.indices}(outputIndices);
+ ${u.indicesSet("inputIndices","uniforms.axis","u32(idx)")};
+ let value = ${u.getByIndices("inputIndices")};
+
+ ${g.setByOffset("global_idx","value")};
+ }`}},P0=e=>Lt({axis:e.axis}),C0=(e,r)=>{let t=e.inputs;fg(t),e.compute(_g(e.inputs,r))}}),gg,wg,S0,$0,wx=Ve(()=>{mt(),bt(),xt(),gg=e=>{if(!e)throw new Error("Input is missing");if(e.length<2||e.length>3)throw new Error("Invaid input number.");if(e.length===3&&e[2].dims.length>2)throw new Error("Invalid input shape of C");if(e[0].dataType!==e[1].dataType||e.length===3&&e[0].dataType!==e[2].dataType)throw new Error("Input types are mismatched")},wg=(e,r)=>{let t=e[0].dims.slice(),s=e[1].dims.slice(),[o,n,i]=ky.getShapeOfGemmResult(t,r.transA,s,r.transB,e.length===3?e[2].dims:void 0),a=[o,n];if(!a)throw new Error("Can't use gemm on the given tensors");let l=16,c=Math.ceil(n/l),p=Math.ceil(o/l),u=!0,h=xe.size(a),g=[{type:12,data:u?c:h},{type:12,data:o},{type:12,data:n},{type:12,data:i},{type:1,data:r.alpha},{type:1,data:r.beta}],_=["type","type"];e.length===3&&(g.push(...nt(e[2].dims)),_.push("rank")),g.push(...nt(a));let E=M=>{let y="";r.transA&&r.transB?y="value += a[k * uniforms.M + m] * b[n * uniforms.K + k];":r.transA&&!r.transB?y="value += a[k * uniforms.M + m] * b[k * uniforms.N + n];":!r.transA&&r.transB?y="value += a[m * uniforms.K + k] * b[n * uniforms.K + k];":!r.transA&&!r.transB&&(y="value += a[m * uniforms.K + k] * b[k * uniforms.N + n];");let $=r.alpha===1?"":"value *= uniforms.alpha;",P=$e("a",e[0].dataType,e[0].dims),b=$e("b",e[1].dataType,e[1].dims),w=P.type.value,T=null,k=[P,b];e.length===3&&(T=$e("c",e[2].dataType,e[2].dims.length),k.push(T));let z=tt("output",e[0].dataType,a.length);k.push(z);let R=[{name:"output_size",type:"u32"},{name:"M",type:"u32"},{name:"N",type:"u32"},{name:"K",type:"u32"},{name:"alpha",type:"f32"},{name:"beta",type:"f32"}];return`
+ ${M.registerUniforms(R).declareVariables(...k)}
+
+ ${M.mainStart()}
+ ${M.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+
+ let m = global_idx / uniforms.N;
+ let n = global_idx % uniforms.N;
+
+ var value = ${w}(0);
+ for (var k: u32 = 0u; k < uniforms.K; k++) {
+ ${y}
+ }
+
+ ${$}
+ ${T!=null?`let cOffset = ${T.broadcastedIndicesToOffset("vec2(m, n)",z)}; value += ${w}(uniforms.beta) * ${T.getByOffset("cOffset")};`:""}
+ output[global_idx] = value;
+ }`},I=M=>{let y=$e("a",e[0].dataType,e[0].dims),$=$e("b",e[1].dataType,e[1].dims),P=null,b=[y,$];e.length===3&&(P=$e("c",e[2].dataType,e[2].dims.length),b.push(P));let w=tt("output",e[0].dataType,a.length);b.push(w);let T=[{name:"num_tile_n",type:"u32"},{name:"M",type:"u32"},{name:"N",type:"u32"},{name:"K",type:"u32"},{name:"alpha",type:"f32"},{name:"beta",type:"f32"}],k="",z="";r.transA&&r.transB?(z=`
+ var col = tile_row_start + local_id.x;
+ var row = k_start + local_id.y;
+ if (col < uniforms.M && row < uniforms.K) {
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col];
+ } else {
+ tile_a[local_id.y][local_id.x] = ${y.type.value}(0);
+ }
+
+ col = k_start + local_id.x;
+ row = tile_col_start + local_id.y;
+ if (col < uniforms.K && row < uniforms.N) {
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col];
+ } else {
+ tile_b[local_id.y][local_id.x] = ${$.type.value}(0);
+ }
+ `,k="value += tile_a[k][local_id.y] * tile_b[local_id.x][k];"):r.transA&&!r.transB?(z=`
+ var col = tile_row_start + local_id.x;
+ var row = k_start + local_id.y;
+ if (col < uniforms.M && row < uniforms.K) {
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col];
+ } else {
+ tile_a[local_id.y][local_id.x] = ${y.type.value}(0);
+ }
+
+ col = tile_col_start + local_id.x;
+ row = k_start + local_id.y;
+ if (col < uniforms.N && row < uniforms.K) {
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col];
+ } else {
+ tile_b[local_id.y][local_id.x] = ${$.type.value}(0);
+ }
+ `,k="value += tile_a[k][local_id.y] * tile_b[k][local_id.x];"):!r.transA&&r.transB?(z=`
+ var col = k_start + local_id.x;
+ var row = tile_row_start + local_id.y;
+ if (col < uniforms.K && row < uniforms.M) {
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col];
+ } else {
+ tile_a[local_id.y][local_id.x] = ${y.type.value}(0);
+ }
+
+ col = k_start + local_id.x;
+ row = tile_col_start + local_id.y;
+ if (col < uniforms.K && row < uniforms.N) {
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col];
+ } else {
+ tile_b[local_id.y][local_id.x] = ${$.type.value}(0);
+ }
+ `,k="value += tile_a[local_id.y][k] * tile_b[local_id.x][k];"):!r.transA&&!r.transB&&(z=`
+ var col = k_start + local_id.x;
+ var row = tile_row_start + local_id.y;
+ if (col < uniforms.K && row < uniforms.M) {
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col];
+ } else {
+ tile_a[local_id.y][local_id.x] = ${y.type.value}(0);
+ }
+
+ col = tile_col_start + local_id.x;
+ row = k_start + local_id.y;
+ if (col < uniforms.N && row < uniforms.K) {
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col];
+ } else {
+ tile_b[local_id.y][local_id.x] = ${$.type.value}(0);
+ }
+ `,k="value += tile_a[local_id.y][k] * tile_b[k][local_id.x];");let R=r.alpha===1?"":"value *= uniforms.alpha;";return`
+ ${M.registerUniforms(T).declareVariables(...b)}
+ var tile_a: array, ${l}>;
+ var tile_b: array, ${l}>;
+ ${M.mainStart([l,l,1])}
+ let tile_col_start = (workgroup_index % uniforms.num_tile_n) * ${l};
+ let tile_row_start = (workgroup_index / uniforms.num_tile_n) * ${l};
+ let num_tiles = (uniforms.K - 1) / ${l} + 1;
+ var k_start = 0u;
+ var value = ${w.type.value}(0);
+ for (var t: u32 = 0u; t < num_tiles; t++) {
+ ${z}
+ k_start = k_start + ${l};
+ workgroupBarrier();
+
+ for (var k: u32 = 0u; k < ${l}; k++) {
+ ${k}
+ }
+ workgroupBarrier();
+ }
+
+ ${R}
+ let m = tile_row_start + local_id.y;
+ let n = tile_col_start + local_id.x;
+ ${P!=null?`let cOffset = ${P.broadcastedIndicesToOffset("vec2(m, n)",w)}; value += ${w.type.value}(uniforms.beta) * ${P.getByOffset("cOffset")};`:""}
+ if (m < uniforms.M && n < uniforms.N) {
+ output[m * uniforms.N + n] = value;
+ }
+ }`};return u?{name:"GemmShared",shaderCache:{hint:`${r.cacheKey}`,inputDependencies:_},getRunData:()=>({outputs:[{dims:a,dataType:e[0].dataType}],dispatchGroup:{x:c*p},programUniforms:g}),getShaderSource:I}:{name:"Gemm",shaderCache:{hint:`${r.cacheKey}`,inputDependencies:_},getRunData:()=>({outputs:[{dims:a,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(h/64)},programUniforms:g}),getShaderSource:E}},S0=e=>{let r=e.transA,t=e.transB,s=e.alpha,o=e.beta;return{transA:r,transB:t,alpha:s,beta:o,cacheKey:`${e.transA};${e.transB};${e.alpha===1}`}},$0=(e,r)=>{gg(e.inputs),e.compute(wg(e.inputs,r))}}),$s,js,En,Pn,yg,Mg,bg,vg,xg,Tg,Eg,Pg,k0,I0,yx=Ve(()=>{mt(),bt(),tr(),xt(),[$s,js,En,Pn]=[0,1,2,3],yg=e=>{if(e[0].dims.length!==4)throw new Error("only 4-D tensor is supported.");if(e[0].dims.length!==e[1].dims.length)throw new Error("input dimensions must be equal to grid dimensions");if(e[0].dims.length-2!==e[1].dims[e[1].dims.length-1])throw new Error(`last dimension of grid must be equal to ${e[0].dims.length-2}`);if(e[0].dims[0]!==e[1].dims[0])throw new Error("grid batch size must match input batch size")},Mg=`
+ fn gs_get_cubic_coeffs(x: f32) -> vec4 {
+ let cubic_alpha = -0.75f;
+ let x_abs = abs(x);
+ var coeffs: vec4;
+ coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha);
+ coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1);
+ coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1);
+ coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha);
+ return coeffs;
+ }
+`,bg=e=>`
+ fn gs_bicubic_interpolate(p: mat4x4<${e}>, x: f32, y: f32) -> ${e} {
+ var v: vec4;
+ var coeffs = gs_get_cubic_coeffs(x);
+ for (var i = 0; i < 4; i++) {
+ v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
+ }
+ coeffs = gs_get_cubic_coeffs(y);
+ let pixel = ${e}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
+ return pixel;
+ }
+`,vg=e=>`
+ fn gs_denormalize(n: f32, length: i32) -> f32 {
+ ${e.alignCorners===0?`
+ // alignCorners: false => [-1, 1] to [-0.5, length - 0.5]
+ return ((n + 1.0) * f32(length) - 1.0) / 2.0;
+ `:`
+ // alignCorners: true => [-1, 1] to [0, length - 1]
+ return (n + 1.0) / 2.0 * (f32(length - 1));
+ `}
+ }
+`,xg=e=>`
+ ${e.paddingMode==="reflection"?`
+ fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 {
+ var dx = 0.0;
+ var fx = f32(x);
+ let range = x_max - x_min;
+ if (fx < x_min) {
+ dx = x_min - fx;
+ let n = u32(dx / range);
+ let r = dx - f32(n) * range;
+ if (n % 2 == 0) {
+ fx = x_min + r;
+ } else {
+ fx = x_max - r;
+ }
+ } else if (fx > x_max) {
+ dx = fx - x_max;
+ let n = u32(dx / range);
+ let r = dx - f32(n) * range;
+ if (n % 2 == 0) {
+ fx = x_max - r;
+ } else {
+ fx = x_min + r;
+ }
+ }
+ return u32(fx);
+ }`:""}
+`,Tg=(e,r,t)=>`
+ fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4) -> ${r} {
+ var pixel = ${r}(0);
+ var indices = vec4(0);
+ indices[${$s}] = batch;
+ indices[${js}] = channel;`+(()=>{switch(t.paddingMode){case"zeros":return`
+ if (r >= 0 && r < H && c >=0 && c < W) {
+ indices[${En}] = u32(r);
+ indices[${Pn}] = u32(c);
+ }
+ `;case"border":return`
+ indices[${En}] = u32(clamp(r, 0, H - 1));
+ indices[${Pn}] = u32(clamp(c, 0, W - 1));
+ `;case"reflection":return`
+ indices[${En}] = gs_reflect(r, border[1], border[3]);
+ indices[${Pn}] = gs_reflect(c, border[0], border[2]);
+ `;default:throw new Error(`padding mode ${t.paddingMode} is not supported`)}})()+`
+ return ${e.getByIndices("indices")};
+ }
+`,Eg=(e,r,t)=>(()=>{switch(t.mode){case"nearest":return`
+ let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${$s}], indices[${js}], border);
+ `;case"bilinear":return`
+ let x1 = i32(floor(x));
+ let y1 = i32(floor(y));
+ let x2 = x1 + 1;
+ let y2 = y1 + 1;
+
+ let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${$s}], indices[${js}], border);
+ let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${$s}], indices[${js}], border);
+ let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${$s}], indices[${js}], border);
+ let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${$s}], indices[${js}], border);
+
+ let dx2 = ${r}(f32(x2) - x);
+ let dx1 = ${r}(x - f32(x1));
+ let dy2 = ${r}(f32(y2) - y);
+ let dy1 = ${r}(y - f32(y1));
+ let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
+ `;case"bicubic":return`
+ let x0 = i32(floor(x)) - 1;
+ let y0 = i32(floor(y)) - 1;
+ var p: mat4x4<${r}>;
+ for (var h = 0; h < 4; h++) {
+ for (var w = 0; w < 4; w++) {
+ p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${$s}], indices[${js}], border);
+ }
+ }
+
+ let dx = x - f32(x0 + 1);
+ let dy = y - f32(y0 + 1);
+ let result = gs_bicubic_interpolate(p, dx, dy);
+ `;default:throw new Error(`mode ${t.mode} is not supported`)}})()+`${e.setByOffset("global_idx","result")}`,Pg=(e,r)=>{let t=$e("x",e[0].dataType,e[0].dims.length),s=[e[1].dims[0],e[1].dims[1],e[1].dims[2]],o=$e("grid",e[1].dataType,s.length,2),n=[e[0].dims[0],e[0].dims[1],e[1].dims[1],e[1].dims[2]];r.format==="NHWC"&&(n=[e[0].dims[0],e[1].dims[1],e[1].dims[2],e[0].dims[3]],[$s,js,En,Pn]=[0,3,1,2]);let i=tt("output",e[0].dataType,n.length),a=t.type.value,l=xe.size(n),c=[{type:12,data:l},...nt(e[0].dims,s,n)],p=u=>`
+ ${u.registerUniform("output_size","u32").declareVariables(t,o,i)}
+ ${Mg}
+ ${bg(a)}
+ ${vg(r)}
+ ${xg(r)}
+ ${Tg(t,a,r)}
+
+ ${u.mainStart()}
+ ${u.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let H_in = i32(uniforms.x_shape[${En}]);
+ let W_in = i32(uniforms.x_shape[${Pn}]);
+
+ ${r.alignCorners===0?`
+ let x_min = -0.5;
+ let x_max = f32(W_in) - 0.5;
+ let y_min = -0.5;
+ let y_max = f32(H_in) - 0.5;
+ `:`
+ let x_min = 0.0;
+ let x_max = f32(W_in) - 1.0;
+ let y_min = 0.0;
+ let y_max = f32(H_in) - 1.0;
+ `};
+ let border = vec4(x_min, y_min, x_max, y_max);
+
+ let indices = ${i.offsetToIndices("global_idx")};
+ var grid_indices = vec3(indices[${$s}], indices[${En}], indices[${Pn}]);
+ let nxy = ${o.getByIndices("grid_indices")};
+ var x = gs_denormalize(f32(nxy[0]), W_in);
+ var y = gs_denormalize(f32(nxy[1]), H_in);
+
+ ${Eg(i,a,r)}
+ }`;return{name:"GridSample",shaderCache:{hint:`${r.cacheKey}`,inputDependencies:["type","type"]},getRunData:u=>{let h=xe.size(n);return{outputs:[{dims:n,dataType:u[0].dataType}],dispatchGroup:{x:Math.ceil(h/64)},programUniforms:c}},getShaderSource:p}},k0=(e,r)=>{yg(e.inputs),e.compute(Pg(e.inputs,r))},I0=e=>Lt({alignCorners:e.align_corners,mode:e.mode,paddingMode:e.padding_mode,format:e.format})}),Fr,Cg,A0,Mc,Sg,ca,F0,O0=Ve(()=>{mt(),bt(),tr(),wu(),bu(),xt(),cn(),Fr=(e,r)=>e.length>r&&e[r].dims.length>0?e[r]:void 0,Cg=(e,r)=>{let t=e[0],s=Fr(e,1),o=Fr(e,2),n=Fr(e,3),i=Fr(e,4),a=Fr(e,5),l=Fr(e,6),c=Fr(e,7);if(t.dims.length!==3&&t.dims.length!==5)throw new Error("Input query is expected to have 3 or 5 dimensions");let p=t.dims[0],u=t.dims[1],h=t.dims.length===3?t.dims[2]:r.numHeads*t.dims[4],g=u,_=0,E=0,I=Math.floor(h/r.numHeads);if(l&&c&&xe.size(l.dims)&&xe.size(c.dims)){if(l.dims.length!==4)throw new Error('Input "past_key" is expected to have 4 dimensions');if(l.dims[0]!==p||l.dims[1]!==r.numHeads||l.dims[3]!==I)throw new Error('Input "past_key" shape (batch_size, num_heads, past_sequence_length, head_size)');if(c.dims[0]!==p||c.dims[1]!==r.numHeads||c.dims[3]!==I)throw new Error('Input "past_value" shape (batch_size, num_heads, past_sequence_length, head_size)');if(l.dims[2]!==c.dims[2])throw new Error('Input "past_key" and "past_value" shall have same dim 2 (past_sequence_length)');if(c.dims.length!==4)throw new Error('Input "past_value" is expected to have 4 dimensions');_=l.dims[2],E=l.dims[2]}else if(l&&xe.size(l.dims)||c&&xe.size(c.dims))throw new Error('Input "past_key" and "past_value" shall be both present or both absent');let M;if(s&&xe.size(s.dims)>0){if(t.dims.length!==3)throw new Error('Input "query" is expected to have 3 dimensions when key is given');if(s.dims.length<3||s.dims.length>5)throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');if(t.dims[0]!==s.dims[0])throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');if(s.dims.length===3){if(s.dims[2]!==t.dims[2])throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)');M=2,g=s.dims[1]}else if(s.dims.length===5){if(s.dims[2]!==r.numHeads||s.dims[3]!==2||s.dims[4]!==I)throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');if(o)throw new Error('Expect "value" be none when "key" has packed kv format.');M=5,g=s.dims[1]}else{if(s.dims[1]!==r.numHeads||s.dims[3]!==I)throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');M=0,g=s.dims[2]}}else{if(t.dims.length!==5)throw new Error('Input "query" is expected to have 5 dimensions when key is empty');if(t.dims[2]!==r.numHeads||t.dims[3]!==3)throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');M=3}if(n&&xe.size(n.dims)>0){if(n.dims.length!==1)throw new Error('Input "bias" is expected to have 1 dimension');if(s&&s.dims.length===5&&s.dims[3]===2)throw new Error("bias is not allowed for packed kv.")}let y=_+g,$=0;if(i&&xe.size(i.dims)>0){$=8;let T=i.dims;throw T.length===1?T[0]===p?$=1:T[0]===3*p+2&&($=3):T.length===2&&T[0]===p&&T[1]===y&&($=5),$===8?new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, total_sequence_length)'):new Error("Mask not supported")}let P=!1,b=h;if(o&&xe.size(o.dims)>0){if(o.dims.length!==3&&o.dims.length!==4)throw new Error('Input "value" is expected to have 3 or 4 dimensions');if(t.dims[0]!==o.dims[0])throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');if(o.dims.length===3){if(g!==o.dims[1])throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');b=o.dims[2]}else{if(g!==o.dims[2])throw new Error('Input "key" and "value" shall have the same dim 2 (kv_sequence_length)');b=o.dims[1]*o.dims[3],P=!0}}let w=!1;if(i&&xe.size(i.dims)>0)throw new Error("Key padding mask is not supported");if(a&&xe.size(a.dims)>0){if(a.dims.length!==4)throw new Error('Input "attention_bias" is expected to have 4 dimensions');if(a.dims[0]!==p||a.dims[1]!==r.numHeads||a.dims[2]!==u||a.dims[3]!==y)throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)')}return{batchSize:p,sequenceLength:u,pastSequenceLength:_,kvSequenceLength:g,totalSequenceLength:y,maxSequenceLength:E,inputHiddenSize:0,hiddenSize:h,vHiddenSize:b,headSize:I,vHeadSize:Math.floor(b/r.numHeads),numHeads:r.numHeads,isUnidirectional:!1,pastPresentShareBuffer:!1,maskFilterValue:r.maskFilterValue,maskType:$,scale:r.scale,broadcastResPosBias:w,passPastInKv:P,qkvFormat:M}},A0=e=>Lt({...e}),Mc=Lt({perm:[0,2,1,3]}),Sg=(e,r,t,s,o,n,i)=>{let a=[s,o,n],l=xe.size(a),c=[{type:12,data:l},{type:12,data:i},{type:12,data:n}],p=u=>{let h=tt("qkv_with_bias",r.dataType,a),g=$e("qkv",r.dataType,a),_=$e("bias",t.dataType,a),E=[{name:"output_size",type:"u32"},{name:"bias_offset",type:"u32"},{name:"hidden_size",type:"u32"}];return`
+ ${u.registerUniforms(E).declareVariables(g,_,h)}
+ ${u.mainStart()}
+ ${u.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset;
+
+ qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx];
+ }`};return e.compute({name:"MultiHeadAttentionAddBias",shaderCache:{inputDependencies:["type","type"]},getRunData:()=>({outputs:[{dims:a,dataType:r.dataType,gpuDataType:0}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:c}),getShaderSource:p},{inputs:[r,t],outputs:[-1]})[0]},ca=(e,r,t,s,o,n,i,a)=>{let l=n;if(i&&xe.size(i.dims)>0){if(s===1)throw new Error("AddBiasReshape is not implemented. Please export your model with packed QKV or KV");return l=Sg(e,n,i,r,s,t*o,a),l=l.reshape([r,s,t,o]),t===1||s===1?l:e.compute(Wr(l,Mc.perm),{inputs:[l],outputs:[-1]})[0]}else return n.dims.length===3&&(l=n.reshape([r,s,t,o])),t===1||s===1?l:e.compute(Wr(l,Mc.perm),{inputs:[l],outputs:[-1]})[0]},F0=(e,r)=>{let t=Cg(e.inputs,r),s=e.inputs[0],o=Fr(e.inputs,1),n=Fr(e.inputs,2),i=Fr(e.inputs,3),a=Fr(e.inputs,4),l=Fr(e.inputs,5),c=Fr(e.inputs,6),p=Fr(e.inputs,7);if(s.dims.length===5)throw new Error("Packed QKV is not implemented");if((o==null?void 0:o.dims.length)===5)throw new Error("Packed KV is not implemented");let u=o&&n&&o.dims.length===4&&n.dims.length===4,h=ca(e,t.batchSize,t.numHeads,t.sequenceLength,t.headSize,s,i,0);if(u)return ha(e,h,o,n,a,void 0,c,p,l,t);if(!o||!n)throw new Error("key and value must be provided");let g=ca(e,t.batchSize,t.numHeads,t.kvSequenceLength,t.headSize,o,i,t.hiddenSize),_=ca(e,t.batchSize,t.numHeads,t.kvSequenceLength,t.vHeadSize,n,i,2*t.hiddenSize);ha(e,h,g,_,a,void 0,c,p,l,t)}}),$g,kg,Ig,Ag,ru,D0,L0,z0=Ve(()=>{mt(),bt(),tr(),xt(),$g=e=>{if(!e||e.length<1)throw new Error("too few inputs")},kg=(e,r)=>{let t=[],s=r.numOutputs;return e[1].dims[0]>0&&(e[1].getBigInt64Array().forEach(o=>t.push(Number(o))),s=t.length),Lt({numOutputs:s,axis:r.axis,splitSizes:t})},Ig=e=>`
+fn calculateOutputIndex(index: u32) -> u32 {
+ for (var i: u32 = 0u; i < ${e}u; i += 1u ) {
+ if (index < ${rt("uniforms.size_in_split_axis","i",e)}) {
+ return i;
+ }
+ }
+ return ${e}u;
+}`,Ag=e=>{let r=e.length,t=[];for(let s=0;s{let t=e[0].dims,s=xe.size(t),o=e[0].dataType,n=xe.normalizeAxis(r.axis,t.length),i=new Array(r.numOutputs),a=$e("input",o,t.length),l=new Array(r.numOutputs),c=[],p=[],u=0,h=[{type:12,data:s}];for(let _=0;_`
+ ${_.registerUniform("input_size","u32").registerUniform("size_in_split_axis","u32",l.length).declareVariables(a,...i)}
+ ${Ig(l.length)}
+ ${Ag(i)}
+
+ ${_.mainStart()}
+ ${_.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")}
+
+ var indices = ${a.offsetToIndices("global_idx")};
+ var index = ${a.indicesGet("indices",n)};
+ let output_number = calculateOutputIndex(index);
+ if (output_number != 0) {
+ index -= ${rt("uniforms.size_in_split_axis","output_number - 1u",l.length)};
+ ${a.indicesSet("indices",n,"index")};
+ }
+ writeBufferData(output_number, indices, global_idx);
+ }`;return{name:"Split",shaderCache:{hint:r.cacheKey,inputDependencies:["rank"]},getShaderSource:g,getRunData:()=>({outputs:c,dispatchGroup:{x:Math.ceil(s/64)},programUniforms:h})}},D0=(e,r)=>{$g(e.inputs);let t=e.inputs.length===1?r:kg(e.inputs,r);e.compute(ru(e.inputs,t),{inputs:[0]})},L0=e=>{let r=e.axis,t=e.splitSizes,s=e.numOutputs<0?t.length:e.numOutputs;if(s!==t.length)throw new Error("numOutputs and splitSizes lengh must be equal");return Lt({axis:r,numOutputs:s,splitSizes:t})}}),Fg,Og,bc,B0,Mx=Ve(()=>{tr(),bu(),O0(),z0(),cn(),Fg=(e,r)=>{if(r.doRotary)throw new Error("GroupQuerryAttention do_rotary attribute is not supported");if(r.doRotary&&e.length<=7)throw new Error("cos_cache and sin_cache inputs are required if do_rotary is specified");let t=e[0],s=e[1],o=e[2],n=e[3],i=e[4];if(r.localWindowSize!==-1)throw new Error("Local attention is not supported");if(r.softcap!==0)throw new Error("Softcap is not supported");if(r.rotaryInterleaved!==0)throw new Error("Rotary interleaved is not supported");if(r.smoothSoftmax)throw new Error("Smooth softmax is not supported");if(t.dims.length!==3&&t.dims.length!==5)throw new Error("Input query is expected to have 3 or 5 dimensions");let a=!1,l=t.dims[0],c=t.dims[1],p=t.dims.length===3?a?t.dims[2]/3:t.dims[2]:r.numHeads*t.dims[4],u=c,h=0,g=!s||s.dims.length===0,_=Math.floor(g?p/(r.numHeads+2*r.kvNumHeads):p/r.numHeads);g&&(p=_*r.numHeads);let E=n&&n.dims.length!==0,I=i&&i.dims.length!==0;if(E&&n.dims.length===4&&n.dims[0]===l&&n.dims[1]!==r.kvNumHeads&&n.dims[2]===r.kvNumHeads&&n.dims[3]===_)throw new Error("BSNH pastKey/pastValue is not supported");if(E&&I){if(n.dims.length!==4)throw new Error('Input "past_key" is expected to have 4 dimensions');if(i.dims.length!==4)throw new Error('Input "past_value" is expected to have 4 dimensions');h=n.dims[2]}else if(E||I)throw new Error('Input "past_key" and "past_value" shall be both present or both absent');let M=1;if(s&&s.dims.length>0){if(t.dims.length!==3)throw new Error('Input "query" is expected to have 3 dimensions when key is given');if(s.dims.length<3||s.dims.length>5)throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');if(t.dims[0]!==s.dims[0])throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');if(s.dims.length===3){if(t.dims[2]%s.dims[2]!==0)throw new Error('Dimension 2 of "query" should be a multiple of "key"');u=s.dims[1]}else if(s.dims.length===5){if(s.dims[2]!==r.numHeads||s.dims[3]!==2||s.dims[4]!==_)throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');if(o)throw new Error('Expect "value" be none when "key" has packed kv format.');u=s.dims[1]}else{if(s.dims[1]!==r.numHeads||s.dims[3]!==_)throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');u=s.dims[2]}}else{if(t.dims.length!==3&&t.dims.length!==5)throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');if(t.dims.length===5&&(t.dims[2]!==r.numHeads||t.dims[3]!==3))throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');M=3}let y=0,$=!1,P=r.kvNumHeads?_*r.kvNumHeads:p;if(o&&o.dims.length>0){if(o.dims.length!==3&&o.dims.length!==4)throw new Error('Input "value" is expected to have 3 or 4 dimensions');if(t.dims[0]!==o.dims[0])throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');if(o.dims.length===3){if(u!==o.dims[1])throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');P=o.dims[2]}else{if(u!==o.dims[2])throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');P=o.dims[1]*o.dims[3],$=!0}}let b=e.length>4?e[5]:void 0;if(b&&b.dims.length!==1&&b.dims[0]!==l)throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size');return{batchSize:l,sequenceLength:c,pastSequenceLength:h,kvSequenceLength:u,totalSequenceLength:-1,maxSequenceLength:-1,inputHiddenSize:0,hiddenSize:p,vHiddenSize:P,headSize:_,vHeadSize:Math.floor(P/r.kvNumHeads),numHeads:r.numHeads,kvNumHeads:r.kvNumHeads,nReps:r.numHeads/r.kvNumHeads,pastPresentShareBuffer:!1,maskType:y,scale:r.scale,broadcastResPosBias:!1,passPastInKv:$,qkvFormat:M}},Og=Lt({perm:[0,2,1,3]}),bc=(e,r,t)=>{let s=r,o=t.kvNumHeads;return r.dims.length===3&&t.kvSequenceLength!==0&&(s=r.reshape([t.batchSize,t.kvSequenceLength,o,t.headSize]),s=e.compute(Wr(s,Og.perm),{inputs:[s],outputs:[-1]})[0]),s},B0=(e,r)=>{var I;let t=Fg(e.inputs,r);if(e.inputs[0].dims.length===5)throw new Error("Packed QKV is not implemented");if(((I=e.inputs[1])==null?void 0:I.dims.length)===5)throw new Error("Packed KV is not implemented");let s=e.inputs[0],o=e.inputs[1]&&e.inputs[1].dims.length>0?e.inputs[1]:void 0,n=e.inputs[2]&&e.inputs[2].dims.length>0?e.inputs[2]:void 0,i=e.inputs[3]&&e.inputs[3].dims.length!==0?e.inputs[3]:void 0,a=e.inputs[4]&&e.inputs[4].dims.length!==0?e.inputs[4]:void 0,l=e.inputs.length>4?e.inputs[5]:void 0,c=e.inputs.length>5?e.inputs[6]:void 0,p=t.kvNumHeads?t.kvNumHeads:t.numHeads,u=Lt({axis:2,numOutputs:3,splitSizes:[t.numHeads*t.headSize,p*t.headSize,p*t.headSize]}),[h,g,_]=!o&&!n?e.compute(ru([s],u),{inputs:[s],outputs:[-1,-1,-1]}):[s,o,n],E=ca(e,t.batchSize,t.numHeads,t.sequenceLength,t.headSize,h,void 0,0);ha(e,E,bc(e,g,t),bc(e,_,t),void 0,void 0,i,a,void 0,t,l,c)}}),vc,Dg,Lg,R0,bx=Ve(()=>{mt(),bt(),cn(),xt(),vc=(e,r,t,s,o,n,i,a)=>{let l=Jt(n),c=l===1?"f32":`vec${l}f`,p=l===1?"vec2f":`mat2x${l}f`,u=o*i,h=64;u===1&&(h=256);let g=[o,i,n/l],_=[o,i,2],E=["rank","type","type"],I=[];I.push(...nt(g,_));let M=y=>{let $=$e("x",r.dataType,3,l),P=$e("scale",t.dataType,t.dims),b=$e("bias",s.dataType,s.dims),w=tt("output",1,3,2),T=[$,P,b,w];return`
+ var workgroup_shared : array<${p}, ${h}>;
+ const workgroup_size = ${h}u;
+ ${y.declareVariables(...T)}
+ ${y.mainStart(h)}
+ let batch = workgroup_index / uniforms.x_shape[1];
+ let channel = workgroup_index % uniforms.x_shape[1];
+ let hight = uniforms.x_shape[2];
+ // initialize workgroup memory
+ var sum = ${c}(0);
+ var squared_sum = ${c}(0);
+ for (var h = local_idx; h < hight; h += workgroup_size) {
+ let value = ${c}(${$.get("batch","channel","h")});
+ sum += value;
+ squared_sum += value * value;
+ }
+ workgroup_shared[local_idx] = ${p}(sum, squared_sum);
+ workgroupBarrier();
+
+ for (var currSize = workgroup_size >> 1; currSize > 0; currSize = currSize >> 1) {
+ if (local_idx < currSize) {
+ workgroup_shared[local_idx] = workgroup_shared[local_idx] + workgroup_shared[local_idx + currSize];
+ }
+ workgroupBarrier();
+ }
+ if (local_idx == 0) {
+ let sum_final = ${dn("workgroup_shared[0][0]",l)} / f32(hight * ${l});
+ let squared_sum_final = ${dn("workgroup_shared[0][1]",l)} / f32(hight * ${l});
+
+ let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(${a}));
+ let channel_scale = inv_std_dev * f32(scale[channel]);
+ let channel_shift = f32(bias[channel]) - sum_final * channel_scale;
+ output[workgroup_index] = vec2f(channel_scale, channel_shift);
+ }
+ }`};return e.compute({name:"InstanceNormComputeChannelScaleShift",shaderCache:{hint:`${l};${a};${h}`,inputDependencies:E},getRunData:()=>({outputs:[{dims:_,dataType:1}],dispatchGroup:{x:u},programUniforms:I}),getShaderSource:M},{inputs:[r,t,s],outputs:[-1]})[0]},Dg=(e,r,t)=>{let s=r[0].dims,o=s,n=2,i=s[0],a=s[1],l=xe.sizeFromDimension(s,n),c=Jt(l),p=xe.size(o)/c,u=vc(e,r[0],r[1],r[2],i,l,a,t.epsilon),h=[i,a,l/c],g=[i,a],_=["type","none"],E=I=>{let M=$e("x",r[0].dataType,h.length,c),y=$e("scale_shift",1,g.length,2),$=tt("output",r[0].dataType,h.length,c),P=[M,y,$];return`
+ ${I.registerUniform("output_size","u32").declareVariables(...P)}
+ ${I.mainStart()}
+ ${I.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let outputIndices = ${$.offsetToIndices("global_idx")};
+ let batch = outputIndices[0];
+ let channel = outputIndices[1];
+ let scale_shift = ${y.getByIndices("vec2(batch, channel)")};
+ let value = ${M.getByOffset("global_idx")} * ${$.type.value}(scale_shift.x) + ${$.type.value}(scale_shift.y);
+ ${$.setByOffset("global_idx","value")};
+ }`};e.compute({name:"InstanceNormalization",shaderCache:{hint:`${c}`,inputDependencies:_},getRunData:()=>({outputs:[{dims:o,dataType:r[0].dataType}],dispatchGroup:{x:Math.ceil(p/64)},programUniforms:[{type:12,data:p},...nt(h,g,h)]}),getShaderSource:E},{inputs:[r[0],u]})},Lg=(e,r,t)=>{let s=r[0].dims,o=s,n=s[0],i=s[s.length-1],a=xe.sizeFromDimension(s,1)/i,l=Jt(i),c=xe.size(o)/l,p=[{type:12,data:a},{type:12,data:Math.floor(i/l)}],u=["type","type"],h=!1,g=[0,s.length-1];for(let M=0;Ms[g[y]])),E=vc(e,_,r[1],r[2],n,a,i,t.epsilon),I=M=>{let y=pr(r[0].dataType),$=l===1?"vec2f":`mat${l}x2f`,P=T=>{let k=T===0?"x":"y",z=l===1?"f32":`vec${l}f`;switch(l){case 1:return`${y}(${z}(scale.${k}))`;case 2:return`vec2<${y}>(${z}(scale[0].${k}, scale[1].${k}))`;case 4:return`vec4<${y}>(${z}(scale[0].${k}, scale[1].${k}, scale[2].${k}, scale[3].${k}))`;default:throw new Error(`Not supported compoents ${l}`)}},b=$e("input",r[0].dataType,r[0].dims,l),w=tt("output",r[0].dataType,o,l);return`
+ @group(0) @binding(0) var input : array<${b.type.storage}>;
+ @group(0) @binding(1) var scale_input : array<${$}>;
+ @group(0) @binding(2) var output : array<${w.type.storage}>;
+ struct Uniforms {H: u32, C : u32};
+ @group(0) @binding(3) var uniforms: Uniforms;
+
+ ${M.mainStart()}
+ let current_image_number = global_idx / (uniforms.C * uniforms.H);
+ let current_channel_number = global_idx % uniforms.C;
+
+ let scale_offset = current_image_number * uniforms.C + current_channel_number;
+ let scale = scale_input[scale_offset];
+ output[global_idx] = fma(input[global_idx], ${P(0)}, ${P(1)});
+ }`};e.compute({name:"InstanceNormalizationNHWC",shaderCache:{hint:`${l}`,inputDependencies:u},getRunData:()=>({outputs:[{dims:o,dataType:r[0].dataType}],dispatchGroup:{x:Math.ceil(c/64)},programUniforms:p}),getShaderSource:I},{inputs:[r[0],E]})},R0=(e,r)=>{r.format==="NHWC"?Lg(e,e.inputs,r):Dg(e,e.inputs,r)}}),zg,Bg,N0,vx=Ve(()=>{mt(),bt(),xt(),zg=e=>{if(!e||e.length<2)throw new Error("layerNorm requires at least 2 inputs.")},Bg=(e,r,t)=>{let s=r.simplified,o=e[0].dims,n=e[1],i=!s&&e[2],a=o,l=xe.normalizeAxis(r.axis,o.length),c=xe.sizeToDimension(o,l),p=xe.sizeFromDimension(o,l),u=xe.size(n.dims),h=i?xe.size(i.dims):0;if(u!==p||i&&h!==p)throw new Error(`Size of X.shape()[axis:] == ${p}.
+ Size of scale and bias (if provided) must match this.
+ Got scale size of ${u} and bias size of ${h}`);let g=[];for(let b=0;b1,y=t>2,$=b=>{let w=pr(e[0].dataType),T=[$e("x",e[0].dataType,e[0].dims,_),$e("scale",n.dataType,n.dims,_)];i&&T.push($e("bias",i.dataType,i.dims,_)),T.push(tt("output",e[0].dataType,a,_)),M&&T.push(tt("mean_data_output",1,g)),y&&T.push(tt("inv_std_output",1,g));let k=[{name:"norm_count",type:"u32"},{name:"norm_size",type:"f32"},{name:"norm_size_vectorized",type:"u32"},{name:"epsilon",type:"f32"}];return`
+ ${b.registerUniforms(k).declareVariables(...T)}
+ ${b.mainStart()}
+ ${b.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count")}
+ let offset = global_idx * uniforms.norm_size_vectorized;
+ var mean_vector = ${Hc("f32",_)};
+ var mean_square_vector = ${Hc("f32",_)};
+
+ for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {
+ let value = ${Co(w,_,"x[h + offset]")};
+ mean_vector += value;
+ mean_square_vector += value * value;
+ }
+ let mean = ${dn("mean_vector",_)} / uniforms.norm_size;
+ let inv_std_dev = inverseSqrt(${dn("mean_square_vector",_)} / uniforms.norm_size ${s?"":"- mean * mean"} + uniforms.epsilon);
+
+ for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
+ let f32input = ${Co(w,_,"x[j + offset]")};
+ let f32scale = ${Co(w,_,"scale[j]")};
+ output[j + offset] = ${T[0].type.value}((f32input ${s?"":"- mean"}) * inv_std_dev * f32scale
+ ${i?`+ ${Co(w,_,"bias[j]")}`:""}
+ );
+ }
+
+ ${M?"mean_data_output[global_idx] = mean":""};
+ ${y?"inv_std_output[global_idx] = inv_std_dev":""};
+ }`},P=[{dims:a,dataType:e[0].dataType}];return M&&P.push({dims:g,dataType:1}),y&&P.push({dims:g,dataType:1}),{name:"LayerNormalization",shaderCache:{hint:`${_};${t};${s}`,inputDependencies:E},getRunData:()=>({outputs:P,dispatchGroup:{x:Math.ceil(c/64)},programUniforms:I}),getShaderSource:$}},N0=(e,r)=>{zg(e.inputs),e.compute(Bg(e.inputs,r,e.outputCount))}}),Rg,j0,xx=Ve(()=>{bt(),Pu(),Cu(),Rg=e=>{if(!e||e.length!==2)throw new Error("MatMul requires 2 inputs.");if(e[0].dims[e[0].dims.length-1]!==e[1].dims[e[1].dims.length-2])throw new Error("shared dimension does not match.")},j0=e=>{Rg(e.inputs);let r=So.calcShape(e.inputs[0].dims,e.inputs[1].dims,!0);if(!r)throw new Error("Can't use matmul on the given tensors");let t=r[r.length-1],s=e.inputs[0].dims[e.inputs[0].dims.length-1];if(t<8&&s<8)e.compute(Eu(e.inputs,{activation:""},r));else{let o=r[r.length-2],n=xe.size(e.inputs[0].dims.slice(0,-2)),i=xe.size(e.inputs[1].dims.slice(0,-2));if(n!==1&&o===1&&i===1){let a=e.inputs[0].reshape([1,n,s]),l=e.inputs[1].reshape([1,s,t]),c=[1,n,t],p=[a,l];e.compute(ad(p,{activation:""},r,c),{inputs:p})}else e.compute(ad(e.inputs,{activation:""},r))}}}),Ng,jg,Vg,V0,U0,Tx=Ve(()=>{mt(),bt(),tr(),xt(),Ng=(e,r)=>{if(e.length<3||e.length>4)throw new Error("MatMulNBits requires 3 or 4 inputs");let t=e[0],s=t.dims.length;if(t.dims[s-1]!==r.k)throw new Error("The last dim of input shape does not match the k value");let o=Math.floor((r.k+r.blockSize-1)/r.blockSize),n=r.blockSize/8*r.bits,i=e[1];if(!xe.areEqual(i.dims,[r.n,o,n]))throw new Error("The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize");let a=e[2].dims;if(xe.size(a)!==r.n*o)throw new Error("scales input size error.");if(e.length===4){let l=e[3].dims,c=r.bits>4?r.n*o:r.n*Math.floor((o+1)/2);if(xe.size(l)!==c)throw new Error("zeroPoints input size error.")}},jg=(e,r)=>{let t=e[0].dims,s=t.length,o=t[s-2],n=r.k,i=r.n,a=t.slice(0,s-2),l=xe.size(a),c=e[1].dims[2]/4,p=e[0].dataType,u=Jt(r.k),h=Jt(c),g=Jt(i),_=a.concat([o,i]),E=o>1&&i/g%2===0?2:1,I=xe.size(_)/g/E,M=64,y=[],$=[l,o,n/u],P=xe.convertShape(e[1].dims).slice();P.splice(-1,1,c/h),y.push(...nt($)),y.push(...nt(P)),y.push(...nt(e[2].dims)),e.length===4&&y.push(...nt(xe.convertShape(e[3].dims)));let b=[l,o,i/g];y.push(...nt(b));let w=T=>{let k=$.length,z=$e("a",e[0].dataType,k,u),R=$e("b",12,P.length,h),Q=$e("scales",e[2].dataType,e[2].dims.length),q=[z,R,Q],U=e.length===4?$e("zero_points",12,e[3].dims.length):void 0;U&&q.push(U);let Z=b.length,H=tt("output",e[0].dataType,Z,g),J=pr(e[0].dataType),oe=(()=>{switch(u){case 1:return`array<${J}, 8>`;case 2:return`mat4x2<${J}>`;case 4:return`mat2x4<${J}>`;default:throw new Error(`${u}-component is not supported.`)}})(),ae=()=>{let N=`
+ // reuse a data
+ var input_offset = ${z.indicesToOffset(`${z.type.indices}(batch, row, word_offset)`)};
+ var a_data: ${oe};
+ for (var j: u32 = 0; j < ${8/u}; j++) {
+ a_data[j] = ${z.getByOffset("input_offset")};
+ input_offset++;
+ }
+ `;for(let O=0;O> 4) & b_mask);
+ b_quantized_values = ${oe}(${Array.from({length:4},(G,se)=>`${J}(b_value_lower[${se}]), ${J}(b_value_upper[${se}])`).join(", ")});
+ b_dequantized_values = ${u===1?`${oe}(${Array.from({length:8},(G,se)=>`(b_quantized_values[${se}] - ${U?`zero_point${O}`:"zero_point"}) * scale${O}`).join(", ")});`:`(b_quantized_values - ${oe}(${Array(8).fill(`${U?`zero_point${O}`:"zero_point"}`).join(",")})) * scale${O};`};
+ workgroup_shared[local_id.x * ${E} + ${Math.floor(O/g)}]${g>1?`[${O%g}]`:""} += ${Array.from({length:8/u},(G,se)=>`${u===1?`a_data[${se}] * b_dequantized_values[${se}]`:`dot(a_data[${se}], b_dequantized_values[${se}])`}`).join(" + ")};
+ `;return N},ce=()=>{let N=`
+ var col_index = col * ${g};
+ ${U?`
+ let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
+ var zero_point_byte_count: u32;
+ var zero_point_word_index: u32;
+ var zero_point_byte_offset: u32;
+ let zero_point_nibble_offset: u32 = block & 0x1u;
+ var zero_point_bits_offset: u32;
+ var zero_point_word: u32;`:`
+ // The default zero point is 8 for unsigned 4-bit quantization.
+ let zero_point = ${J}(8);`}
+ `;for(let O=0;O> 0x1u);
+ zero_point_word_index = zero_point_byte_count >> 0x2u;
+ zero_point_byte_offset = zero_point_byte_count & 0x3u;
+ zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
+ zero_point_word = ${U.getByOffset("zero_point_word_index")} >> zero_point_bits_offset;
+ let zero_point${O} = ${J}((zero_point_word) & 0xFu);`:""}
+ col_index += 1;`;return N},he=()=>{let N=`col_index = col * ${g};`;for(let O=0;O;
+ var b_value_upper: vec4;
+ var b_quantized_values: ${oe};
+ var b_dequantized_values: ${oe};`,N};return`
+ var workgroup_shared: array<${H.type.value}, ${E*M}>;
+ ${T.declareVariables(...q,H)}
+ ${T.mainStart([M,1,1])}
+ let output_indices = ${H.offsetToIndices(`(global_idx / ${M}) * ${E}`)};
+ let col = output_indices[2];
+ let row = output_indices[1];
+ let batch = output_indices[0];
+ let nBlocksPerCol = uniforms.b_shape[1];
+
+ for (var block = local_id.x; block < nBlocksPerCol; block += ${M}) {
+ //process one block
+ var word_offset: u32 = block * ${r.blockSize/u};
+ ${ce()}
+ for (var word: u32 = 0; word < ${c}; word += ${h}) {
+ ${he()}
+ for (var i: u32 = 0; i < ${h}; i++) {
+ ${ae()}
+ word_offset += ${8/u};
+ }
+ }
+ }
+ workgroupBarrier();
+
+ if (local_id.x < ${E}) {
+ var output_value: ${H.type.value} = ${H.type.value}(0);
+ var workgroup_shared_offset: u32 = local_id.x;
+ for (var b: u32 = 0u; b < ${M}u; b++) {
+ output_value += workgroup_shared[workgroup_shared_offset];
+ workgroup_shared_offset += ${E};
+ }
+ ${H.setByIndices(`${H.type.indices}(batch, row, col + local_id.x)`,"output_value")};
+ }
+ }`};return{name:"MatMulNBits",shaderCache:{hint:`${r.blockSize};${r.bits};${u};${h};${g};${E};${M}`,inputDependencies:Array(e.length).fill("rank")},getRunData:()=>({outputs:[{dims:_,dataType:p}],dispatchGroup:{x:I},programUniforms:y}),getShaderSource:w}},Vg=(e,r)=>{let t=e[0].dims,s=t.length,o=t[s-2],n=r.k,i=r.n,a=t.slice(0,s-2),l=xe.size(a),c=e[1].dims[2]/4,p=e[0].dataType,u=Jt(r.k),h=Jt(c),g=a.concat([o,i]),_=128,E=i%8===0?8:i%4===0?4:1,I=_/E,M=I*h*8,y=M/u,$=M/r.blockSize,P=xe.size(g)/E,b=[],w=[l,o,n/u],T=xe.convertShape(e[1].dims).slice();T.splice(-1,1,c/h),b.push(...nt(w)),b.push(...nt(T)),b.push(...nt(e[2].dims)),e.length===4&&b.push(...nt(xe.convertShape(e[3].dims)));let k=[l,o,i];b.push(...nt(k));let z=R=>{let Q=w.length,q=$e("a",e[0].dataType,Q,u),U=$e("b",12,T.length,h),Z=$e("scales",e[2].dataType,e[2].dims.length),H=[q,U,Z],J=e.length===4?$e("zero_points",12,e[3].dims.length):void 0;J&&H.push(J);let oe=k.length,ae=tt("output",e[0].dataType,oe),ce=pr(e[0].dataType),he=()=>{switch(u){case 1:return`
+ let a_data0 = vec4<${ce}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);
+ let a_data1 = vec4<${ce}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`;case 2:return`
+ let a_data0 = vec4<${ce}>(sub_a[word_offset], sub_a[word_offset + 1]);
+ let a_data1 = vec4<${ce}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`;case 4:return`
+ let a_data0 = sub_a[word_offset];
+ let a_data1 = sub_a[word_offset + 1];`;default:throw new Error(`${u}-component is not supported.`)}};return`
+ var sub_a: array<${q.type.value}, ${y}>;
+ var inter_results: array, ${E}>;
+ ${R.declareVariables(...H,ae)}
+ ${R.mainStart([I,E,1])}
+ let output_indices = ${ae.offsetToIndices(`workgroup_index * ${E}`)};
+ let col = output_indices[2];
+ let row = output_indices[1];
+ let batch = output_indices[0];
+ let n_blocks_per_col = uniforms.b_shape[1];
+ let num_tiles = (n_blocks_per_col - 1) / ${$} + 1;
+
+ // Loop over shared dimension.
+ for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
+ let a_col_start = tile * ${y};
+ // load one tile A data into shared memory.
+ for (var a_offset = local_idx; a_offset < ${y}; a_offset += ${_})
+ {
+ let a_col = a_col_start + a_offset;
+ if (a_col < uniforms.a_shape[2])
+ {
+ sub_a[a_offset] = ${q.getByIndices(`${q.type.indices}(batch, row, a_col)`)};
+ } else {
+ sub_a[a_offset] = ${q.type.value}(0);
+ }
+ }
+ workgroupBarrier();
+
+ // each thread process one block
+ let b_row = col + local_id.y;
+ let block = tile * ${$} + local_id.x;
+ ${J?`
+ let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
+ let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
+ let zero_point_word_index = zero_point_byte_count >> 0x2u;
+ let zero_point_byte_offset = zero_point_byte_count & 0x3u;
+ let zero_point_nibble_offset: u32 = block & 0x1u;
+ let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
+ let zero_point_word = ${J.getByOffset("zero_point_word_index")} >> zero_point_bits_offset;
+ let zero_point = ${ce}((zero_point_word) & 0xFu);`:`
+ // The default zero point is 8 for unsigned 4-bit quantization.
+ let zero_point = ${ce}(8);`}
+ let scale = ${Z.getByOffset("b_row * n_blocks_per_col + block")};
+ let b_data = ${U.getByIndices(`${U.type.indices}(b_row, block, 0)`)};
+ var word_offset = local_id.x * ${r.blockSize/u};
+ for (var i: u32 = 0; i < ${h}; i++) {
+ ${he()}
+ let b_value = ${h===1?"b_data":"b_data[i]"};
+ let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
+ let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
+ let b_quantized_values = mat2x4<${ce}>(${Array.from({length:4},(N,O)=>`${ce}(b_value_lower[${O}]), ${ce}(b_value_upper[${O}])`).join(", ")});
+ let b_dequantized_values = (b_quantized_values - mat2x4<${ce}>(${Array(8).fill("zero_point").join(",")})) * scale;
+ inter_results[local_id.y][local_id.x] += ${Array.from({length:2},(N,O)=>`${`dot(a_data${O}, b_dequantized_values[${O}])`}`).join(" + ")};
+ word_offset += ${8/u};
+ }
+ workgroupBarrier();
+ }
+
+ if (local_idx < ${E}) {
+ var output_value: ${ae.type.value} = ${ae.type.value}(0);
+ for (var b = 0u; b < ${I}; b++) {
+ output_value += inter_results[local_idx][b];
+ }
+ if (col + local_idx < uniforms.output_shape[2])
+ {
+ ${ae.setByIndices(`${ae.type.indices}(batch, row, col + local_idx)`,"output_value")}
+ }
+ }
+ }`};return{name:"BlockwiseMatMulNBits32",shaderCache:{hint:`${r.blockSize};${u};${h};${I};${E}`,inputDependencies:Array(e.length).fill("rank")},getRunData:()=>({outputs:[{dims:g,dataType:p}],dispatchGroup:{x:P},programUniforms:b}),getShaderSource:z}},V0=(e,r)=>{Ng(e.inputs,r),r.blockSize===32&&e.adapterInfo.isVendor("intel")&&e.adapterInfo.isArchitecture("gen-12lp")?e.compute(Vg(e.inputs,r)):e.compute(jg(e.inputs,r))},U0=e=>Lt(e)}),Ug,Wg,Gg,Kg,Hg,qg,Qg,Xg,W0,Ex=Ve(()=>{mt(),bt(),xt(),Ug=e=>{if(!e||e.length<1)throw new Error("Too few inputs");if(e[0].dataType!==1&&e[0].dataType!==10)throw new Error("Input type must be float or float16.");if(e.length>=2){let r=e[0].dims.length*2===e[1].dims[0];if(e.length===4&&(r=e[3].dims[0]*2===e[1].dims[0]),!r)throw new Error("The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].")}},Wg=(e,r,t)=>{let s="";for(let o=r-1;o>=0;--o)s+=`
+ k = i32(${e.indicesGet("indices",o)}) - ${rt("uniforms.pads",o,t)};
+ if (k < 0) {
+ break;
+ }
+ if (k >= i32(${rt("uniforms.x_shape",o,r)})) {
+ break;
+ }
+ offset += k * i32(${rt("uniforms.x_strides",o,r)});
+ `;return`
+ value = ${e.type.value}(uniforms.constant_value);
+ for (var i = 0; i < 1; i++) {
+ var offset = 0;
+ var k = 0;
+ ${s}
+ value = x[offset];
+ }
+ `},Gg=(e,r,t)=>{let s="";for(let o=r-1;o>=0;--o)s+=`
+ k = i32(${e.indicesGet("indices",o)}) - ${rt("uniforms.pads",o,t)};
+ if (k < 0) {
+ k = -k;
+ }
+ {
+ let _2n_1 = 2 * (i32(${rt("uniforms.x_shape",o,r)}) - 1);
+ k = k % _2n_1;
+ if(k >= i32(${rt("uniforms.x_shape",o,r)})) {
+ k = _2n_1 - k;
+ }
+ }
+ offset += k * i32(${rt("uniforms.x_strides",o,r)});
+ `;return`
+ var offset = 0;
+ var k = 0;
+ ${s}
+ value = x[offset];
+ `},Kg=(e,r,t)=>{let s="";for(let o=r-1;o>=0;--o)s+=`
+ k = i32(${e.indicesGet("indices",o)}) - ${rt("uniforms.pads",o,t)};
+ if (k < 0) {
+ k = 0;
+ }
+ if (k >= i32(${rt("uniforms.x_shape",o,r)})) {
+ k = i32(${rt("uniforms.x_shape",o,r)}) - 1;
+ }
+ offset += k * i32(${rt("uniforms.x_strides",o,r)});
+ `;return`
+ var offset = 0;
+ var k = 0;
+ ${s}
+ value = x[offset];
+ `},Hg=(e,r,t)=>{let s="";for(let o=r-1;o>=0;--o)s+=`
+ k = i32(${e.indicesGet("indices",o)}) - ${rt("uniforms.pads",o,t)};
+ if (k < 0) {
+ k += i32(${rt("uniforms.x_shape",o,r)}]);
+ }
+ if (k >= i32(${rt("uniforms.x_shape",o,r)})) {
+ k -= i32(${rt("uniforms.x_shape",o,r)});
+ }
+ offset += k * i32(${rt("uniforms.x_strides",o,r)});
+ `;return`
+ var offset = 0;
+ var k = 0;
+ ${s}
+ value = x[offset];
+ `},qg=(e,r,t)=>{switch(t.mode){case 0:return Wg(e,r,t.pads.length);case 1:return Gg(e,r,t.pads.length);case 2:return Kg(e,r,t.pads.length);case 3:return Hg(e,r,t.pads.length);default:throw new Error("Invalid mode")}},Qg=(e,r)=>{let t=xe.padShape(e[0].dims.slice(),r.pads),s=e[0].dims,o=xe.size(t),n=[{type:12,data:o},{type:6,data:r.pads}],i=e.length>=3&&e[2].data;r.mode===0&&n.push({type:i?e[2].dataType:1,data:r.value}),n.push(...nt(e[0].dims,t));let a=["rank"],l=c=>{let p=tt("output",e[0].dataType,t.length),u=$e("x",e[0].dataType,s.length),h=u.type.value,g=qg(p,s.length,r),_=[{name:"output_size",type:"u32"},{name:"pads",type:"i32",length:r.pads.length}];return r.mode===0&&_.push({name:"constant_value",type:i?h:"f32"}),`
+ ${c.registerUniforms(_).declareVariables(u,p)}
+ ${c.mainStart()}
+ ${c.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+
+ let indices = ${p.offsetToIndices("global_idx")};
+
+ var value = ${h}(0);
+ ${g}
+ output[global_idx] = value;
+ }`};return{name:"Pad",shaderCache:{hint:`${r.mode}${i}`,inputDependencies:a},getRunData:()=>({outputs:[{dims:t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(xe.size(t)/64)},programUniforms:n}),getShaderSource:l}},Xg=(e,r)=>{if(e.length>1){let t=e[1].getBigInt64Array(),s=e.length>=3&&e[2].data?e[2].dataType===10?e[2].getUint16Array()[0]:e[2].getFloat32Array()[0]:0,o=e[0].dims.length,n=new Int32Array(2*o).fill(0);if(e.length>=4){let a=e[3].getBigInt64Array();for(let l=0;ln[Number(l)]=Number(a));let i=[];return n.forEach(a=>i.push(a)),{mode:r.mode,value:s,pads:i}}else return r},W0=(e,r)=>{Ug(e.inputs);let t=Xg(e.inputs,r);e.compute(Qg(e.inputs,t),{inputs:[0]})}}),sa,xc,Tc,Ec,Pc,Jg,Yg,Cc,Sc,G0,K0,$c,H0,q0,kc,Q0,X0,J0,Y0,Px=Ve(()=>{Ms(),mt(),bt(),xt(),sa=e=>{if(Kt.webgpu.validateInputContent&&(!e||e.length!==1))throw new Error("Pool ops requires 1 input.")},xc=(e,r,t)=>{let s=r.format==="NHWC",o=e.dims.slice();s&&o.splice(1,0,o.pop());let n=Object.hasOwnProperty.call(r,"dilations"),i=r.kernelShape.slice(),a=r.strides.slice(),l=n?r.dilations.slice():[],c=r.pads.slice();od.adjustPoolAttributes(t,o,i,a,l,c);let p=od.computePoolOutputShape(t,o,a,l,i,c,r.autoPad),u=Object.assign({},r);n?Object.assign(u,{kernelShape:i,strides:a,pads:c,dilations:l,cacheKey:r.cacheKey}):Object.assign(u,{kernelShape:i,strides:a,pads:c,cacheKey:r.cacheKey});let h=p.slice();return h.push(h.splice(1,1)[0]),[u,s?h:p]},Tc=(e,r)=>{let t=r.format==="NHWC",s=xe.size(e),o=xe.size(r.kernelShape),n=[{type:12,data:s},{type:12,data:o}],i=[{name:"outputSize",type:"u32"},{name:"kernelSize",type:"u32"}];if(r.kernelShape.length<=2){let a=r.kernelShape[r.kernelShape.length-1],l=r.strides[r.strides.length-1],c=r.pads[r.pads.length/2-1],p=r.pads[r.pads.length-1],u=!!(c+p);n.push({type:12,data:a},{type:12,data:l},{type:12,data:c},{type:12,data:p}),i.push({name:"kw",type:"u32"},{name:"sw",type:"u32"},{name:"pwStart",type:"u32"},{name:"pwEnd",type:"u32"});let h=!1;if(r.kernelShape.length===2){let g=r.kernelShape[r.kernelShape.length-2],_=r.strides[r.strides.length-2],E=r.pads[r.pads.length/2-2],I=r.pads[r.pads.length-2];h=!!(E+I),n.push({type:12,data:g},{type:12,data:_},{type:12,data:E},{type:12,data:I}),i.push({name:"kh",type:"u32"},{name:"sh",type:"u32"},{name:"phStart",type:"u32"},{name:"phEnd",type:"u32"})}return[n,i,!0,u,h]}else{if(t)throw new Error("Pooling with kernelShape.length > 2 is not supported for NHWC format.");let a=xe.computeStrides(r.kernelShape);n.push({type:12,data:a},{type:12,data:r.pads},{type:12,data:r.strides}),i.push({name:"kernelStrides",type:"u32",length:a.length},{name:"pads",type:"u32",length:r.pads.length},{name:"strides",type:"u32",length:r.strides.length});let l=r.pads.reduce((c,p)=>c+p);return[n,i,!!l,!1,!1]}},Ec=(e,r,t,s,o,n,i,a,l,c,p,u)=>{let h=o.format==="NHWC",g=r.type.value,_=tt("output",r.type.tensor,s);if(o.kernelShape.length<=2){let E="",I="",M="",y=t-(h?2:1);if(p?E=`
+ for (var i: u32 = 0u; i < uniforms.kw; i++) {
+ xIndices[${y}] = indices[${y}] * uniforms.sw - uniforms.pwStart + i;
+ if (xIndices[${y}] < 0 || xIndices[${y}]
+ >= uniforms.x_shape[${y}]) {
+ pad++;
+ continue;
+ }
+ let x_val = x[${r.indicesToOffset("xIndices")}];
+ ${n}
+ }`:E=`
+ for (var i: u32 = 0u; i < uniforms.kw; i++) {
+ xIndices[${y}] = indices[${y}] * uniforms.sw - uniforms.pwStart + i;
+ let x_val = x[${r.indicesToOffset("xIndices")}];
+ ${n}
+ }`,o.kernelShape.length===2){let $=t-(h?3:2);u?I=`
+ for (var j: u32 = 0u; j < uniforms.kh; j++) {
+ xIndices[${$}] = indices[${$}] * uniforms.sh - uniforms.phStart + j;
+ if (xIndices[${$}] < 0 || xIndices[${$}] >= uniforms.x_shape[${$}]) {
+ pad += i32(uniforms.kw);
+ continue;
+ }
+ `:I=`
+ for (var j: u32 = 0u; j < uniforms.kh; j++) {
+ xIndices[${$}] = indices[${$}] * uniforms.sh - uniforms.phStart + j;
+ `,M=`
+ }
+ `}return`
+ ${e.registerUniforms(l).declareVariables(r,_)}
+
+ ${e.mainStart()}
+ ${e.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+
+ let indices = ${_.offsetToIndices("global_idx")};
+ var xIndices = ${_.offsetToIndices("global_idx")};
+
+ var value = ${g}(${a});
+ var pad = 0;
+ ${I}
+ ${E}
+ ${M}
+ ${i}
+
+ output[global_idx] = value;
+ }`}else{if(h)throw new Error("Pooling with kernelShape.length > 2 is not supported for NHWC format.");let E=o.kernelShape.length,I=o.pads.length,M="";return c?M=`
+ if (xIndices[j] >= uniforms.x_shape[j]) {
+ pad++;
+ isPad = true;
+ break;
+ }
+ }
+ if (!isPad) {
+ let x_val = x[${r.indicesToOffset("xIndices")}];
+ ${n}
+ }`:M=`
+ }
+ let x_val = x[${r.indicesToOffset("xIndices")}];
+ ${n}
+ `,`
+ ${e.registerUniforms(l).declareVariables(r,_)}
+
+ ${e.mainStart()}
+ ${e.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ let indices = ${_.offsetToIndices("global_idx")};
+ var xIndices = ${_.offsetToIndices("global_idx")};
+
+ var offsets: array;
+
+ var value = ${g}(${a});
+ var pad = 0;
+ var isPad = false;
+
+ for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
+ var offset = i;
+ for (var j = 0u; j < ${E-1}u; j++) {
+ offsets[j] = offset / ${rt("uniforms.kernelStrides","j",E)};
+ offset -= offsets[j] * ${rt("uniforms.kernelStrides","j",E)};
+ }
+ offsets[${E-1}] = offset;
+
+ isPad = false;
+ for (var j = ${t-E}u; j < ${t}u; j++) {
+ xIndices[j] = indices[j] * ${rt("uniforms.strides",`j - ${t-E}u`,E)}
+ + offsets[j - ${t-E}u] - ${rt("uniforms.pads","j - 2u",I)};
+ ${M}
+ }
+ ${i}
+
+ output[global_idx] = value;
+ }`}},Pc=e=>`${e.format};${e.ceilMode};${e.autoPad};${e.kernelShape.length}`,Jg=e=>`${Pc(e)};${e.countIncludePad}`,Yg=e=>`${Pc(e)};${e.storageOrder};${e.dilations}`,Cc=e=>({format:e.format,autoPad:["NOTSET","VALID","SAME_UPPER","SAME_LOWER"][e.auto_pad],ceilMode:e.ceil_mode,kernelShape:e.kernel_shape,strides:e.strides,pads:e.pads}),Sc=(e,r,t,s)=>{let[o,n]=xc(r,s,t),i=$e("x",r.dataType,r.dims.length),a=i.type.value,l="value += x_val;",c="";o.countIncludePad?c+=`value /= ${a}(uniforms.kernelSize);`:c+=`value /= ${a}(i32(uniforms.kernelSize) - pad);`;let[p,u,h,g,_]=Tc(n,o);p.push(...nt(r.dims,n));let E=["rank"];return{name:e,shaderCache:{hint:`${s.cacheKey};${h};${g};${_}`,inputDependencies:E},getRunData:()=>({outputs:[{dims:n,dataType:r.dataType}],dispatchGroup:{x:Math.ceil(xe.size(n)/64)},programUniforms:p}),getShaderSource:I=>Ec(I,i,r.dims.length,n.length,o,l,c,0,u,h,g,_)}},G0=e=>{let r=e.count_include_pad!==0,t=Cc(e);if(t.ceilMode!==0)throw new Error("using ceil() in shape computation is not yet supported for AveragePool");let s={countIncludePad:r,...t,cacheKey:""};return{...s,cacheKey:Jg(s)}},K0=(e,r)=>{sa(e.inputs),e.compute(Sc("AveragePool",e.inputs[0],!1,r))},$c={autoPad:"",ceilMode:0,countIncludePad:!1,kernelShape:[],strides:[],pads:[],storageOrder:0,dilations:[]},H0=e=>{let r=e.format;return{format:r,...$c,cacheKey:r}},q0=(e,r)=>{sa(e.inputs),e.compute(Sc("GlobalAveragePool",e.inputs[0],!0,r))},kc=(e,r,t,s)=>{let[o,n]=xc(r,s,t),i=`
+ value = max(x_val, value);
+ `,a="",l=$e("x",r.dataType,r.dims.length),c=["rank"],[p,u,h,g,_]=Tc(n,o);return p.push(...nt(r.dims,n)),{name:e,shaderCache:{hint:`${s.cacheKey};${h};${g};${_}`,inputDependencies:c},getRunData:()=>({outputs:[{dims:n,dataType:r.dataType}],dispatchGroup:{x:Math.ceil(xe.size(n)/64)},programUniforms:p}),getShaderSource:E=>Ec(E,l,r.dims.length,n.length,o,i,a,r.dataType===10?-65504:-1e5,u,h,g,_)}},Q0=(e,r)=>{sa(e.inputs),e.compute(kc("MaxPool",e.inputs[0],!1,r))},X0=e=>{let r=e.storage_order,t=e.dilations,s=Cc(e);if(r!==0)throw new Error("column major storage order is not yet supported for MaxPool");if(s.ceilMode!==0)throw new Error("using ceil() in shape computation is not yet supported for MaxPool");let o={storageOrder:r,dilations:t,...s,cacheKey:""};return{...o,cacheKey:Yg(o)}},J0=e=>{let r=e.format;return{format:r,...$c,cacheKey:r}},Y0=(e,r)=>{sa(e.inputs),e.compute(kc("GlobalMaxPool",e.inputs[0],!0,r))}}),Zg,ew,Z0,eb,Cx=Ve(()=>{mt(),bt(),tr(),xt(),Zg=(e,r)=>{if(e.length<2||e.length>3)throw new Error("DequantizeLinear requires 2 or 3 inputs.");if(e.length===3&&e[1].dims===e[2].dims)throw new Error("x-scale and x-zero-point must have the same shape.");if(e.length===3&&e[0].dataType!==e[2].dataType)throw new Error("x and x-zero-point must have the same data type.");if(e[0].dataType===6&&e.length>2)throw new Error("In the case of dequantizing int32 there is no zero point.");if(e[1].dims.length!==0&&e[1].dims.length!==1&&e[1].dims.length!==e[0].dims.length)throw new Error("scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.");if(e.length>2){if(e[0].dataType!==e[2].dataType)throw new Error("x and x-zero-point must have the same data type.");if(e[1].dims.length!==e[2].dims.length)throw new Error("scale and zero-point inputs must have the same rank.");if(!e[1].dims.map((t,s)=>t===e[2].dims[s]).reduce((t,s)=>t&&s,!0))throw new Error("scale and zero-point inputs must have the same shape.")}if(r.blockSize>0){if(e[1].dims.length===0||e[1].dims.length===1&&e[1].dims[0]===1)throw new Error("blockSize must be set only for block quantization.");if(!e[1].dims.map((o,n)=>n===r.axis||o===e[0].dims[n]).reduce((o,n)=>o&&n,!0))throw new Error("For block qunatization, scale input shape to match the input shape except for the axis");if(e[1].dims.length!==e[0].dims.length)throw new Error("For block qunatization the scale input rank must be the same as the x rank.");let t=e[0].dims[r.axis],s=e[1].dims[r.axis];if(r.blockSizeMath.ceil(t/(s-1)-1))throw new Error("blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].")}},ew=(e,r)=>{let t=xe.normalizeAxis(r.axis,e[0].dims.length),s=e[0].dataType,o=s===3,n=e[0].dims,i=e[1].dataType,a=xe.size(n),l=s===3||s===2,c=l?[Math.ceil(xe.size(e[0].dims)/4)]:e[0].dims,p=e[1].dims,u=e.length>2?e[2]:void 0,h=u?l?[Math.ceil(xe.size(u.dims)/4)]:u.dims:void 0,g=p.length===0||p.length===1&&p[0]===1,_=g===!1&&p.length===1,E=Jt(a),I=g&&(!l||E===4),M=I?E:1,y=I&&!l?E:1,$=$e("input",l?12:s,c.length,y),P=$e("scale",i,p.length),b=u?$e("zero_point",l?12:s,h.length):void 0,w=tt("output",i,n.length,M),T=[$,P];b&&T.push(b);let k=[c,p];u&&k.push(h);let z=[{type:12,data:a/M},{type:12,data:t},{type:12,data:r.blockSize},...nt(...k,n)],R=Q=>{let q=[{name:"output_size",type:"u32"},{name:"axis",type:"u32"},{name:"block_size",type:"u32"}];return`
+ ${Q.registerUniforms(q).declareVariables(...T,w)}
+ ${Q.mainStart()}
+ ${Q.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let output_indices = ${w.offsetToIndices("global_idx")};
+
+ // Set input x
+ ${l?`
+ let input = ${$.getByOffset("global_idx / 4")};
+ let x_vec = ${o?"unpack4xI8(input)":"unpack4xU8(input)"};
+ let x_value = ${M===1?"x_vec[global_idx % 4]":"x_vec"};`:`let x_value = ${$.getByOffset("global_idx")};`};
+
+ // Set scale input
+ ${g?`let scale_value= ${P.getByOffset("0")}`:_?`
+ let scale_index = ${w.indicesGet("output_indices","uniforms.axis")};
+ let scale_value= ${P.getByOffset("scale_index")};`:`
+ var scale_indices: ${P.type.indices} = output_indices;
+ let index = ${P.indicesGet("scale_indices","uniforms.axis")} / uniforms.block_size;
+ ${P.indicesSet("scale_indices","uniforms.axis","index")};
+ let scale_value= ${P.getByIndices("scale_indices")};`};
+
+ // Set zero-point input
+ ${b?g?l?`
+ let zero_point_input = ${b.getByOffset("0")};
+ let zero_point_vec = ${o?"unpack4xI8(zero_point_input)":"unpack4xU8(zero_point_input)"};
+ let zero_point_value= zero_point_vec[0]`:`let zero_point_value = ${b.getByOffset("0")}`:_?l?`
+ let zero_point_index = ${w.indicesGet("output_indices","uniforms.axis")};
+ let zero_point_input = ${b.getByOffset("zero_point_index / 4")};
+ let zero_point_vec = ${o?"unpack4xI8(zero_point_input)":"unpack4xU8(zero_point_input)"};
+ let zero_point_value = zero_point_vec[zero_point_index % 4]`:`
+ let zero_point_index = ${w.indicesGet("output_indices","uniforms.axis")};
+ let zero_point_value = ${b.getByOffset("zero_point_index")};`:l?`
+ let zero_point_offset = ${P.indicesToOffset("scale_indices")};
+ let zero_point_input = ${b.getByOffset("zero_point_offset / 4")};
+ let zero_point_vec = ${o?"unpack4xI8(zero_point_input)":"unpack4xU8(zero_point_input)"};
+ let zero_point_value = zero_point_vec[zero_point_offset % 4];`:`let zero_point_value = ${b.getByIndices("scale_indices")};`:`let zero_point_value = ${l?o?"i32":"u32":$.type.value}(0);`};
+ // Compute and write output
+ ${w.setByOffset("global_idx",`${w.type.value}(x_value - zero_point_value) * scale_value`)};
+ }`};return{name:"DequantizeLinear",shaderCache:{hint:r.cacheKey,inputDependencies:b?["rank","rank","rank"]:["rank","rank"]},getShaderSource:R,getRunData:()=>({outputs:[{dims:n,dataType:i}],dispatchGroup:{x:Math.ceil(a/M/64),y:1,z:1},programUniforms:z})}},Z0=(e,r)=>{Zg(e.inputs,r),e.compute(ew(e.inputs,r))},eb=e=>Lt({axis:e.axis,blockSize:e.blockSize})}),tw,rw,tb,Sx=Ve(()=>{Ms(),mt(),xt(),tw=(e,r,t)=>{let s=e===r,o=er&&t>0;if(s||o||n)throw new Error("Range these inputs' contents are invalid.")},rw=(e,r,t,s)=>{let o=Math.abs(Math.ceil((r-e)/t)),n=[o],i=o,a=[{type:12,data:i},{type:s,data:e},{type:s,data:t},...nt(n)],l=c=>{let p=tt("output",s,n.length),u=p.type.value,h=[{name:"outputSize",type:"u32"},{name:"start",type:u},{name:"delta",type:u}];return`
+ ${c.registerUniforms(h).declareVariables(p)}
+ ${c.mainStart()}
+ ${c.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ output[global_idx] = uniforms.start + ${u}(global_idx) * uniforms.delta;
+ }`};return{name:"Range",shaderCache:{hint:`${s}`},getShaderSource:l,getRunData:()=>({outputs:[{dims:n,dataType:s}],dispatchGroup:{x:Math.ceil(i/64)},programUniforms:a})}},tb=e=>{let r=0,t=0,s=0;e.inputs[0].dataType===6?(r=e.inputs[0].getInt32Array()[0],t=e.inputs[1].getInt32Array()[0],s=e.inputs[2].getInt32Array()[0]):e.inputs[0].dataType===1&&(r=e.inputs[0].getFloat32Array()[0],t=e.inputs[1].getFloat32Array()[0],s=e.inputs[2].getFloat32Array()[0]),Kt.webgpu.validateInputContent&&tw(r,t,s),e.compute(rw(r,t,s,e.inputs[0].dataType),{inputs:[]})}}),sw,nw,rb,sb,$x=Ve(()=>{mt(),bt(),tr(),xt(),sw=(e,r,t,s)=>{if(e!=="none"&&s!=="i32"&&s!=="u32"&&s!=="f32")throw new Error(`Input ${s} is not supported with reduction ${e}.`);let o=`{
+ var oldValue = 0;
+ loop {
+ let newValueF32 =`,n=`;
+ let newValue = bitcast(newValueF32);
+ let res = atomicCompareExchangeWeak(&${r}, oldValue, newValue);
+ if res.exchanged {
+ break;
+ }
+ oldValue = res.old_value;
+ }
+ }`;switch(e){case"none":return`${r}=${t};`;case"add":return s==="i32"||s==="u32"?`atomicAdd(&${r}, bitcast<${s}>(${t}));`:`
+ ${o}bitcast<${s}>(oldValue) + (${t})${n}`;case"max":return s==="i32"||s==="u32"?`atomicMax(&${r}, bitcast<${s}>(${t}));`:`
+ ${o}max(bitcast(oldValue), (${t}))${n}`;case"min":return s==="i32"||s==="u32"?`atomicMin(&${r}, bitcast<${s}>(${t}));`:`${o}min(bitcast<${s}>(oldValue), (${t}))${n}`;case"mul":return`${o}(bitcast<${s}>(oldValue) * (${t}))${n}`;default:throw new Error(`Reduction ${e} is not supported.`)}},nw=(e,r)=>{let t=e[0].dims,s=e[1].dims,o=t,n=1,i=Math.ceil(xe.size(s)/n),a=s[s.length-1],l=xe.sizeFromDimension(t,a),c=[{type:12,data:i},{type:12,data:a},{type:12,data:l},...nt(e[1].dims,e[2].dims,o)],p=u=>{let h=$e("indices",e[1].dataType,e[1].dims.length),g=$e("updates",e[2].dataType,e[2].dims.length,n),_=r.reduction!=="none"&&r.reduction!==""?Fy("output",e[0].dataType,o.length):tt("output",e[0].dataType,o.length,n);return`
+ ${u.registerUniform("output_size","u32").registerUniform("last_index_dimension","u32").registerUniform("num_updates_elements","u32").declareVariables(h,g,_)}
+ ${u.mainStart()}
+ ${u.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ var hasDuplicates = false;
+ if (${r.reduction==="none"}) {
+ let n = ${xe.size(s)};
+ for (var i = 0; i < n; i = i + 1) {
+ for (var j = i + 1; j < n; j = j + 1) {
+ var index_i = i32(indices[i].x);
+ var index_j = i32(indices[j].x);
+ if (index_i == index_j) {
+ hasDuplicates = true;
+ break;
+ }
+ }
+ if (hasDuplicates) {
+ break;
+ }
+ }
+ }
+
+ var data_offset = 0u;
+ var indices_start = uniforms.last_index_dimension * global_idx;
+ if (${r.reduction==="none"} && hasDuplicates) {
+ if (global_idx != 0u) {
+ return;
+ }
+ indices_start = 0u;
+ }
+ let indices_end = indices_start + uniforms.last_index_dimension;
+ for (var i = indices_start; i < indices_end; i++) {
+ var index = i32(indices[i].x);
+ ${e[0].dims.length===1?`
+ let element_count_dim = uniforms.output_strides;
+ let dim_value = uniforms.output_shape;`:`
+ let element_count_dim = uniforms.output_strides[i - indices_start];
+ let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];`}
+ if (index >= 0) {
+ if (index >= i32(dim_value)) {
+ index = i32(dim_value - 1);
+ }
+ } else {
+ if (index < -i32(dim_value)) {
+ index = 0;
+ } else {
+ index += i32(dim_value);
+ }
+ }
+ data_offset += u32((u32(index) * element_count_dim));
+ }
+
+ for (var i = 0u; i < uniforms.num_updates_elements; i++) {
+ let value = updates[uniforms.num_updates_elements * global_idx + i];
+ ${sw(r.reduction,"output[data_offset + i]","value",_.type.value)}
+ }
+
+ }`};return{name:"ScatterND",shaderCache:{hint:`${r.cacheKey}_${r.reduction}`,inputDependencies:["rank","rank"]},getRunData:()=>({outputs:[{dims:o,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(i/64)},programUniforms:c}),getShaderSource:p}},rb=e=>Lt({reduction:e.reduction}),sb=(e,r)=>{e.compute(nw(e.inputs,r),{inputs:[e.inputs[1],e.inputs[2]],outputs:[]})}}),ow,iw,aw,Ic,lw,dw,cw,uw,pw,hw,mw,fw,Ac,_w,gw,ww,yw,Mw,nb,ob,kx=Ve(()=>{mt(),bt(),tr(),xt(),ow=(e,r)=>{if(e.every(t=>t>0||(()=>{throw new Error("Resize requires scales input values to be positive")})),e.length>0){if(r.mode==="linear"){if(!(e.length===2||e.length===3||e.length===4&&e[0]===1&&e[1]===1||e.length===4&&e[0]===1&&e[3]===1||e.length===5&&e[0]===1&&e[1]===1))throw new Error(`For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and
+ one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`)}else if(r.mode==="cubic"&&!(e.length===2||e.length===4&&e[0]===1&&e[1]===1||e.length===4&&e[0]===1&&e[3]===1))throw new Error("Resize requires scales input size to be 2 or 4 for cubic mode")}},iw=(e,r,t)=>{r.every(o=>o>=0&&o{throw new Error("Resize requires axes input values to be positive and less than rank")}));let s=new Array(t).fill(1);return r.forEach((o,n)=>s[o]=e[n]),s},aw=(e,r,t,s,o,n)=>{let[i,a,l]=t>10?[1,2,3]:[-1,e.length>1?1:-1,-1],c=e[0].dims.length;if(i>0&&e.length>i&&e[i].dims.length>0)e[i].getFloat32Array().forEach(p=>n.push(p));else if(r.coordinateTransformMode==="tf_crop_and_resize")throw new Error("Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize");if(a>0&&e.length>a&&e[a].dims.length===1&&e[a].dims[0]>0){if(e[a].getFloat32Array().forEach(p=>s.push(p)),s.length!==0&&s.length!==c&&t>=18&&s.length!==r.axes.length)throw new Error("Resize requires scales input size to be same as input rank or axes size for opset 18 and up");ow(s,r),r.axes.length>0&&iw(s,r.axes,c).forEach((p,u)=>s[u]=p)}if(l>0&&e.length>l&&e[l].dims.length===1&&e[l].dims[0]>0&&(e[l].getBigInt64Array().forEach(p=>o.push(Number(p))),o.length!==0&&o.length!==c&&t>=18&&o.length!==r.axes.length))throw new Error("Resize requires sizes input size to be same as input rank or axes size for opset 18 and up");if(r.axes.length>0){if(s.length!==0&&s.length!==r.axes.length)throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified');if(o.length!==0&&o.length!==r.axes.length)throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified')}if(typeof s<"u"&&typeof o<"u"&&s.length>0&&o.length>c)throw new Error("Resize requires only of scales or sizes to be specified")},Ic=(e,r,t,s)=>`
+ // The whole part and the fractional part are calculated separately due to inaccuracy of floating
+ // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
+ // offset-by-one error later in floor().
+ let big = (${e}) * (${r});
+ let whole = ${s}(big / (${t}));
+ let fract = ${s}(big % (${t})) / ${s}(${t});
+ return whole + fract;
+`,lw=(e,r)=>`fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32,
+ lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${r} { `+(()=>{switch(e){case"asymmetric":return`
+ if (xScale < 1.0 || floor(xScale) != xScale) {
+ return ${r}(xResized) / ${r}(xScale);
+ } else {
+ ${Ic("xResized","lengthOriginal","lengthResized",r)}
+ }
+ `;case"pytorch_half_pixel":return`if (lengthResized > 1) {
+ return (${r}(xResized) + 0.5) / ${r}(xScale) - 0.5;
+ } else {
+ return 0.0;
+ }`;case"tf_half_pixel_for_nn":return`return (${r}(xResized) + 0.5) / ${r}(xScale);`;case"align_corners":return`if (lengthResized == 1) {
+ return 0.0;
+ } else {
+ ${Ic("xResized","lengthOriginal - 1","lengthResized - 1",r)}
+ }`;case"tf_crop_and_resize":return`if (lengthResized > 1) {
+ return ${r}(roiStart) * ${r}(lengthOriginal - 1) +
+ (${r}(xResized) * ${r}(roiEnd - roiStart) * ${r}(lengthOriginal - 1)) /
+ ${r}(lengthResized - 1);
+ } else {
+ return 0.5 * ${r}(roiStart + roiEnd) * ${r}(lengthOriginal - 1);
+ }`;case"half_pixel_symmetric":return`const outputWidth = ${r}xScale * ${r}(lengthResized);
+ const adjustment = ${r}(lengthResized) / outputWidth;
+ const center = ${r}(lengthOriginal) / 2;
+ const offset = center * (1 - adjustment);
+ return offset + ((${r}(xResized) + 0.5) / ${r}(xScale)) - 0.5;`;case"half_pixel":return`return ((${r}(xResized) + 0.5) / ${r}(xScale)) - 0.5;`;default:throw new Error(`Coordinate transform mode ${e} is not supported`)}})()+"}",dw=(e,r,t)=>`fn getNearestPixelFromOriginal(xOriginal: ${t}, isDownSample: bool) -> ${t} {`+(()=>{switch(e){case"round_prefer_ceil":return"if (fract(xOriginal) == 0.5) { return ceil(xOriginal); } else { return round(xOriginal); }";case"floor":return"return floor(xOriginal);";case"ceil":return"return ceil(xOriginal);";case"round_prefer_floor":return"if (fract(xOriginal) == 0.5) { return floor(xOriginal); } else { return round(xOriginal); }";case"simple":default:if(r<11)return"if (isDownSample) { return ceil(xOriginal); } else { return xOriginal; }";throw new Error(`Nearest mode ${e} is not supported`)}})()+"}",cw=(e,r,t)=>{let s=new Array(t).fill(0).concat(new Array(t).fill(1)),o=e.length===0?s:e.slice();return r.length>0?(r.forEach((n,i)=>{s[n]=o[i],s[i+t]=o[r.length+i]}),s):o},uw=(e,r,t,s)=>{let o=[];if(t.length>0)if(s.length>0){if(e.forEach(n=>o.push(n)),Math.max(...s)>e.length)throw new Error("axes is out of bound");s.forEach((n,i)=>o[n]=t[i])}else t.forEach(n=>o.push(n));else{if(r.length===0)throw new Error("Resize requires either scales or sizes.");o=e.map((n,i)=>Math.round(n*r[i]))}return o},pw=(e,r,t)=>{let s=(()=>{switch(t.keepAspectRatioPolicy){case"not_larger":return t.axes.length>0?Math.min(...t.axes.map(n=>r[n]),Number.MAX_VALUE):Math.min(...r,Number.MAX_VALUE);case"not_smaller":return t.axes.length>0?Math.max(...t.axes.map(n=>r[n]),Number.MIN_VALUE):Math.max(...r,Number.MIN_VALUE);default:throw new Error(`Keep aspect ratio policy ${t.keepAspectRatioPolicy} is not supported`)}})();r.fill(1,0,r.length);let o=e.slice();return t.axes.length>0?(t.axes.forEach(n=>r[n]=s),t.axes.forEach(n=>o[n]=Math.round(e[n]*r[n]))):(r.fill(s,0,r.length),o.forEach((n,i)=>o[i]=Math.round(n*r[i]))),o},hw=(e,r,t,s,o)=>`
+ fn calculateOriginalIndicesFromOutputIndices(output_indices: ${e.type.indices}) -> array<${e.type.value}, ${t.length}> {
+ var original_indices: array<${e.type.value}, ${t.length}>;
+ for (var i:u32 = 0; i < ${t.length}; i++) {
+ var output_index = ${e.indicesGet("output_indices","i")};
+ var scale = ${rt("uniforms.scales","i",s)};
+ var roi_low = ${rt("uniforms.roi","i",o)};
+ var roi_hi = ${rt("uniforms.roi",`i + ${r.length}`,o)};
+ if (scale == 1.0) {
+ original_indices[i] = ${e.type.value}(output_index);
+ } else {
+ var input_shape_i = ${rt("uniforms.input_shape","i",r.length)};
+ var output_shape_i = ${rt("uniforms.output_shape","i",t.length)};
+ original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
+ input_shape_i, roi_low, roi_hi);
+ }
+ }
+ return original_indices;
+ }`,mw=(e,r,t,s,o,n,i)=>`
+ fn calculateInputIndicesFromOutputIndices(output_indices: ${r.type.indices}) -> ${e.type.indices} {
+ var input_indices: ${e.type.indices};
+ for (var i:u32 = 0; i < ${s.length}; i++) {
+ var output_index = ${r.indicesGet("output_indices","i")};
+ var input_index: u32;
+ var scale = ${rt("uniforms.scales","i",o)};
+ if (scale == 1.0) {
+ input_index = output_index;
+ } else {
+ var roi_low = ${rt("uniforms.roi","i",n)};
+ var roi_hi = ${rt("uniforms.roi",`i + ${t.length}`,n)};
+ var input_shape_i = ${rt("uniforms.input_shape","i",t.length)};
+ var output_shape_i = ${rt("uniforms.output_shape","i",s.length)};
+ var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
+ input_shape_i, roi_low, roi_hi);
+ if (!${i} || (original_idx >= 0 && original_idx < ${r.type.value}(input_shape_i))) {
+ if (original_idx < 0) {
+ input_index = 0;
+ } else if (original_idx > ${r.type.value}(input_shape_i - 1)) {
+ input_index = input_shape_i - 1;
+ } else {
+ input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
+ }
+ } else {
+ input_index = u32(original_idx);
+ }
+ }
+ ${e.indicesSet("input_indices","i","input_index")}
+ }
+ return input_indices;
+ }`,fw=(e,r)=>`
+ fn checkInputIndices(input_indices: ${e.type.indices}) -> bool {
+ for (var i:u32 = 0; i < ${r.length}; i++) {
+ var input_index = ${e.indicesGet("input_indices","i")};
+ if (input_index < 0 || input_index >= ${rt("uniforms.input_shape","i",r.length)}) {
+ return false;
+ }
+ }
+ return true;
+ }`,Ac=(e,r,t,s)=>e.rank>s?`
+ ${e.indicesSet("input_indices",r,"channel")};
+ ${e.indicesSet("input_indices",t,"batch")};
+`:"",_w=(e,r,t,s,o)=>{let[n,i,a,l]=t.length===2?[-1,0,1,-1]:[0,2,3,1],c=e.type.value;return`
+ fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${c} {
+ var input_indices: ${e.type.indices};
+ ${e.indicesSet("input_indices",i,`max(0, min(row, ${t[i]} - 1))`)};
+ ${e.indicesSet("input_indices",a,`max(0, min(col, ${t[a]} - 1))`)};
+ ${Ac(e,l,n,2)}
+ return ${e.getByIndices("input_indices")};
+ }
+
+ fn bilinearInterpolation(output_indices: ${r.type.indices}) -> ${c} {
+ var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
+ var row:${c} = originalIndices[${i}];
+ var col:${c} = originalIndices[${a}];
+ ${s?`if (row < 0 || row > (${t[i]} - 1) || col < 0 || col > (${t[a]} - 1)) {
+ return ${o};
+ }`:""};
+ row = max(0, min(row, ${t[i]} - 1));
+ col = max(0, min(col, ${t[a]} - 1));
+ var row1: u32 = u32(row);
+ var col1: u32 = u32(col);
+ var row2: u32 = u32(row + 1);
+ var col2: u32 = u32(col + 1);
+ var channel: u32 = ${t.length>2?`u32(originalIndices[${l}])`:"0"};
+ var batch: u32 = ${t.length>2?`u32(originalIndices[${n}])`:"0"};
+ var x11: ${c} = getInputValue(batch, channel, row1, col1);
+ var x12: ${c} = getInputValue(batch, channel, row1, col2);
+ var x21: ${c} = getInputValue(batch, channel, row2, col1);
+ var x22: ${c} = getInputValue(batch, channel, row2, col2);
+ var dx1: ${c} = abs(row - ${c}(row1));
+ var dx2: ${c} = abs(${c}(row2) - row);
+ var dy1: ${c} = abs(col - ${c}(col1));
+ var dy2: ${c} = abs(${c}(col2) - col);
+ if (row1 == row2) {
+ dx1 = 0.5;
+ dx2 = 0.5;
+ }
+ if (col1 == col2) {
+ dy1 = 0.5;
+ dy2 = 0.5;
+ }
+ return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
+ }`},gw=(e,r,t,s,o,n,i,a,l,c)=>{let p=t.length===2,[u,h]=p?[0,1]:[2,3],g=e.type.value,_=E=>{let I=E===u?"row":"col";return`
+ fn ${I}CubicInterpolation(input_indices: ${e.type.indices}, output_indices: ${r.type.indices}) -> ${g} {
+ var output_index = ${r.indicesGet("output_indices",E)};
+ var originalIdx: ${g} = getOriginalCoordinateFromResizedCoordinate(output_index, ${o[E]},
+ ${s[E]}, ${t[E]}, ${n[E]}, ${n[E]} + ${t.length});
+ var fractOriginalIdx: ${g} = originalIdx - floor(originalIdx);
+ var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
+
+ if (${a} && (originalIdx < 0 || originalIdx > (${t[E]} - 1))) {
+ return ${l};
+ }
+ var data: array<${g}, 4> = array<${g}, 4>(0.0, 0.0, 0.0, 0.0);
+ for (var i: i32 = -1; i < 3; i++) {
+ var ${I}: ${g} = originalIdx + ${g}(i);
+ if (${I} < 0 || ${I} >= ${t[E]}) {
+ ${c?`coefs[i + 1] = 0.0;
+ continue;`:a?`return ${l};`:`${I} = max(0, min(${I}, ${t[E]} - 1));`};
+ }
+ var input_indices_copy: ${e.type.indices} = input_indices;
+ ${e.indicesSet("input_indices_copy",E,`u32(${I})`)};
+ data[i + 1] = ${E===u?e.getByIndices("input_indices_copy"):"rowCubicInterpolation(input_indices_copy, output_indices)"};
+ }
+ return cubicInterpolation1D(data, coefs);
+ }`};return`
+ ${_(u)};
+ ${_(h)};
+ fn getCubicInterpolationCoefs(s: ${g}) -> array<${g}, 4> {
+ var absS = abs(s);
+ var coeffs: array<${g}, 4> = array<${g}, 4>(0.0, 0.0, 0.0, 0.0);
+ var oneMinusAbsS: ${g} = 1.0 - absS;
+ var twoMinusAbsS: ${g} = 2.0 - absS;
+ var onePlusAbsS: ${g} = 1.0 + absS;
+ coeffs[0] = ((${i} * onePlusAbsS - 5 * ${i}) * onePlusAbsS + 8 * ${i}) * onePlusAbsS - 4 * ${i};
+ coeffs[1] = ((${i} + 2) * absS - (${i} + 3)) * absS * absS + 1;
+ coeffs[2] = ((${i} + 2) * oneMinusAbsS - (${i} + 3)) * oneMinusAbsS * oneMinusAbsS + 1;
+ coeffs[3] = ((${i} * twoMinusAbsS - 5 * ${i}) * twoMinusAbsS + 8 * ${i}) * twoMinusAbsS - 4 * ${i};
+ return coeffs;
+ }
+
+ fn cubicInterpolation1D(x: array<${g}, 4>, coefs: array<${g}, 4>) -> ${g} {
+ var coefsSum: ${g} = coefs[0] + coefs[1] + coefs[2] + coefs[3];
+ return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
+ }
+
+ fn bicubicInterpolation(output_indices: ${r.type.indices}) -> ${g} {
+ var input_indices: ${e.type.indices} = output_indices;
+ return colCubicInterpolation(input_indices, output_indices);
+ }
+ `},ww=(e,r,t,s,o)=>{let[n,i,a,l,c]=t.length===3?[-1,0,1,2,-1]:[0,2,3,4,1],p=e.type.value;return`
+ fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${p} {
+ var input_indices: ${e.type.indices};
+ ${e.indicesSet("input_indices",i,`max(0, min(depth, ${t[i]} - 1))`)};
+ ${e.indicesSet("input_indices",a,`max(0, min(height, ${t[a]} - 1))`)};
+ ${e.indicesSet("input_indices",l,`max(0, min(width, ${t[l]} - 1))`)};
+ ${Ac(e,c,n,3)}
+ return ${e.getByIndices("input_indices")};
+ }
+
+ fn trilinearInterpolation(output_indices: ${r.type.indices}) -> ${p} {
+ var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
+ var depth:${p} = originalIndices[${i}];
+ var height:${p} = originalIndices[${a}];
+ var width:${p} = originalIndices[${l}];
+ ${s?`if (depth < 0 || depth > (${t[i]} - 1) || height < 0 || height > (${t[a]} - 1) || width < 0 || (width > ${t[l]} - 1)) {
+ return ${o};
+ }`:""};
+
+ depth = max(0, min(depth, ${t[i]} - 1));
+ height = max(0, min(height, ${t[a]} - 1));
+ width = max(0, min(width, ${t[l]} - 1));
+ var depth1: u32 = u32(depth);
+ var height1: u32 = u32(height);
+ var width1: u32 = u32(width);
+ var depth2: u32 = u32(depth + 1);
+ var height2: u32 = u32(height + 1);
+ var width2: u32 = u32(width + 1);
+ var channel: u32 = ${t.length>3?`u32(originalIndices[${c}])`:"0"};
+ var batch: u32 = ${t.length>3?`u32(originalIndices[${n}])`:"0"};
+
+ var x111: ${p} = getInputValue(batch, channel, depth1, height1, width1);
+ var x112: ${p} = getInputValue(batch, channel, depth1, height1, width2);
+ var x121: ${p} = getInputValue(batch, channel, depth1, height2, width1);
+ var x122: ${p} = getInputValue(batch, channel, depth1, height2, width2);
+ var x211: ${p} = getInputValue(batch, channel, depth2, height1, width1);
+ var x212: ${p} = getInputValue(batch, channel, depth2, height1, width2);
+ var x221: ${p} = getInputValue(batch, channel, depth2, height2, width1);
+ var x222: ${p} = getInputValue(batch, channel, depth2, height2, width2);
+ var dx1: ${p} = abs(depth - ${p}(depth1));
+ var dx2: ${p} = abs(${p}(depth2) - depth);
+ var dy1: ${p} = abs(height - ${p}(height1));
+ var dy2: ${p} = abs(${p}(height2) - height);
+ var dz1: ${p} = abs(width - ${p}(width1));
+ var dz2: ${p} = abs(${p}(width2) - width);
+ if (depth1 == depth2) {
+ dx1 = 0.5;
+ dx2 = 0.5;
+ }
+ if (height1 == height2) {
+ dy1 = 0.5;
+ dy2 = 0.5;
+ }
+ if (width1 == width2) {
+ dz1 = 0.5;
+ dz2 = 0.5;
+ }
+ return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
+ x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
+ }`},yw=(e,r,t,s,o,n)=>{let i=e.dims,a=cw(n,r.axes,i.length),l=uw(i,s,o,r.axes),c=s.slice();s.length===0&&(c=i.map((y,$)=>y===0?1:l[$]/y),r.keepAspectRatioPolicy!=="stretch"&&(l=pw(i,c,r)));let p=tt("output",e.dataType,l.length),u=$e("input",e.dataType,i.length),h=xe.size(l),g=i.length===l.length&&i.every((y,$)=>y===l[$]),_=r.coordinateTransformMode==="tf_crop_and_resize",E=r.extrapolationValue,I=u.type.value,M=y=>`
+ ${g?"":`
+ ${lw(r.coordinateTransformMode,I)};
+ ${(()=>{switch(r.mode){case"nearest":return`
+ ${fw(u,i)};
+ ${dw(r.nearestMode,t,I)};
+ ${mw(u,p,i,l,c.length,a.length,_)};
+ `;case"linear":return`
+ ${hw(p,i,l,c.length,a.length)};
+ ${(()=>{if(i.length===2||i.length===4)return`${_w(u,p,i,_,E)}`;if(i.length===3||i.length===5)return`${ww(u,p,i,_,E)}`;throw Error("Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.")})()};
+ `;case"cubic":return`
+ ${(()=>{if(i.length===2||i.length===4)return`${gw(u,p,i,l,c,a,r.cubicCoeffA,_,r.extrapolationValue,r.excludeOutside)}`;throw Error("Cubic mode only supports input dims 2 and 4 are supported in linear mode.")})()};
+ `;default:throw Error("Invalid resize mode")}})()};
+ `}
+ ${y.registerUniform("output_size","u32").registerUniform("scales","f32",c.length).registerUniform("roi","f32",a.length).declareVariables(u,p)}
+ ${y.mainStart()}
+ ${y.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ ${g?"output[global_idx] = input[global_idx];":`
+ let output_indices = ${p.offsetToIndices("global_idx")};
+ var input_indices: ${u.type.indices};
+ ${(()=>{switch(r.mode){case"nearest":return`input_indices = calculateInputIndicesFromOutputIndices(output_indices);
+ if (checkInputIndices(input_indices)) {
+ output[global_idx] = ${u.getByIndices("input_indices")};
+ } else {
+ output[global_idx] = ${r.extrapolationValue};
+ }`;case"linear":return`output[global_idx] = ${i.length===2||i.length===4?"bilinearInterpolation":"trilinearInterpolation"}(output_indices);`;case"cubic":return"output[global_idx] = bicubicInterpolation(output_indices);";default:throw Error(`Unsupported resize mode: ${r.mode}`)}})()};
+`}
+ }`;return{name:"Resize",shaderCache:{hint:`${r.cacheKey}|${t}|${c.length>0?r.mode==="cubic"?c:c.length:""}|${o.length>0?o:""}|${a.length>0?a:""}|${g}|${r.mode==="nearest"?i.length:i}`,inputDependencies:["rank"]},getShaderSource:M,getRunData:()=>({outputs:[{dims:l,dataType:e.dataType}],dispatchGroup:{x:Math.ceil(h/64)},programUniforms:[{type:12,data:h},{type:1,data:c},{type:1,data:a},...nt(i,l)]})}},Mw=e=>{let r=e.customDataBuffer;return new Uint32Array(r,r.byteOffset,1)[0]},nb=(e,r)=>{let t=[],s=[],o=[],n=Mw(e);if(r.antialias!==0)throw Error("Only default value (0) for Antialias attribute is supported");aw(e.inputs,r,n,t,s,o),e.compute(yw(e.inputs[0],r,n,t,s,o),{inputs:[0]})},ob=e=>{let r=e.antialias,t=e.axes,s=e.coordinateTransformMode,o=e.cubicCoeffA,n=e.excludeOutside!==0,i=e.extrapolationValue,a=e.keepAspectRatioPolicy,l=e.mode,c=e.nearestMode===""?"simple":e.nearestMode;return Lt({antialias:r,axes:t,coordinateTransformMode:s,cubicCoeffA:o,excludeOutside:n,extrapolationValue:i,keepAspectRatioPolicy:a,mode:l,nearestMode:c})}}),bw,vw,ib,Ix=Ve(()=>{mt(),bt(),tr(),xt(),bw=(e,r)=>{let[t,s,o,n]=e,{numHeads:i,rotaryEmbeddingDim:a}=r;if(t.dims.length!==3&&t.dims.length!==4)throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${t.dims.length}`);if(!xe.areEqual(s.dims,[])&&!xe.areEqual(s.dims,[1])&&s.dims.length!==2)throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${s.dims.length}`);if(o.dims.length!==2)throw new Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${o.dims.length}`);if(n.dims.length!==2)throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${n.dims.length}`);if(!xe.areEqual(o.dims,n.dims))throw new Error("Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape");if(a>0&&i===0)throw new Error("num_heads must be provided if rotary_embedding_dim is specified");let l=t.dims[0],c=t.dims[t.dims.length-2],p=o.dims[0],u=xe.sizeFromDimension(t.dims,1)/c,h=a===0?o.dims[1]*2:u/i;if(a>h)throw new Error("rotary_embedding_dim must be less than or equal to head_size");if(s.dims.length===2){if(l!==s.dims[0])throw new Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${s.dims[0]}`);if(c!==s.dims[1])throw new Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${s.dims[1]}`)}if(h/2!==o.dims[1]&&a/2!==o.dims[1])throw new Error(`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${o.dims[1]}`);if(c>p)throw new Error("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported")},vw=(e,r)=>{let{interleaved:t,numHeads:s,rotaryEmbeddingDim:o,scale:n}=r,i=e[0].dims[0],a=xe.sizeFromDimension(e[0].dims,1),l=e[0].dims[e[0].dims.length-2],c=a/l,p=e[2].dims[1],u=o===0?p*2:c/s,h=new Array(i,l,c/u,u-p),g=xe.computeStrides(h),_=[{type:1,data:n},{type:12,data:h},{type:12,data:g},...e[0].dims.length===3?new Array({type:12,data:[a,c,u,1]}):[],...e[0].dims.length===4?new Array({type:12,data:[a,u,l*u,1]}):[],...nt(e[0].dims,e[1].dims,e[2].dims,e[3].dims,e[0].dims)],E=I=>{let M=$e("input",e[0].dataType,e[0].dims.length),y=$e("position_ids",e[1].dataType,e[1].dims.length),$=$e("cos_cache",e[2].dataType,e[2].dims.length),P=$e("sin_cache",e[3].dataType,e[3].dims.length),b=tt("output",e[0].dataType,e[0].dims.length);return I.registerUniforms([{name:"scale",type:"f32"},{name:"global_shape",type:"u32",length:h.length},{name:"global_strides",type:"u32",length:g.length},{name:"input_output_strides",type:"u32",length:g.length}]),`
+ ${I.declareVariables(M,y,$,P,b)}
+
+ ${I.mainStart($o)}
+ let half_rotary_emb_dim = uniforms.${$.name}_shape[1];
+ let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;
+ let size = uniforms.global_shape[0] * uniforms.global_strides[0];
+ ${I.guardAgainstOutOfBoundsWorkgroupSizes("size")}
+
+ if (bsnh[3] < half_rotary_emb_dim) {
+ let position_ids_idx =
+ ${y.broadcastedIndicesToOffset("bsnh.xy",tt("",y.type.tensor,2))};
+ let position_id =
+ u32(${y.getByOffset("position_ids_idx")}) + select(0, bsnh[1], position_ids_idx == 0);
+ let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${t});
+ let j = i + select(half_rotary_emb_dim, 1, ${t});
+ let re = ${M.getByOffset("i")} * ${$.get("position_id","bsnh[3]")} -
+ ${M.getByOffset("j")} * ${P.get("position_id","bsnh[3]")};
+ ${b.setByOffset("i","re")}
+ let im = ${M.getByOffset("i")} * ${P.get("position_id","bsnh[3]")} +
+ ${M.getByOffset("j")} * ${$.get("position_id","bsnh[3]")};
+ ${b.setByOffset("j","im")}
+ } else {
+ let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;
+ ${b.setByOffset("k",M.getByOffset("k"))}
+ }
+ }`};return{name:"RotaryEmbedding",shaderCache:{hint:Lt({interleaved:t}).cacheKey,inputDependencies:["rank","rank","rank","rank"]},getShaderSource:E,getRunData:()=>({outputs:[{dims:e[0].dims,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(xe.size(h)/$o)},programUniforms:_})}},ib=(e,r)=>{bw(e.inputs,r),e.compute(vw(e.inputs,r))}}),xw,Tw,ab,Ax=Ve(()=>{mt(),bt(),xt(),xw=e=>{if(!e||e.length<3)throw new Error("layerNorm requires at least 3 inputs.");let r=e[0],t=e[1],s=e[2];if(r.dataType!==t.dataType||r.dataType!==s.dataType)throw new Error("All inputs must have the same data type");if(r.dims.length!==3&&r.dims.length!==2)throw new Error("Input must be 2D or 3D");if(t.dims.length!==3&&t.dims.length!==2)throw new Error("Skip must be 2D or 3D");let o=r.dims[r.dims.length-1],n=r.dims[r.dims.length-2];if(t.dims[t.dims.length-1]!==o)throw new Error("Skip must have the same hidden size as input");if(t.dims[t.dims.length-2]!==n)throw new Error("Skip must have the same sequence length as input");if(s.dims.length!==1)throw new Error("Gamma must be 1D");if(s.dims[s.dims.length-1]!==o)throw new Error("Gamma must have the same hidden size as input");if(e.length>3){let i=e[3];if(i.dims.length!==1)throw new Error("Beta must be 1D");if(i.dims[i.dims.length-1]!==o)throw new Error("Beta must have the same hidden size as input")}if(e.length>4){let i=e[4];if(i.dims.length!==1)throw new Error("Bias must be 1D");if(i.dims[i.dims.length-1]!==o)throw new Error("Bias must have the same hidden size as input")}},Tw=(e,r,t,s)=>{let o=r.simplified,n=e[0].dims,i=xe.size(n),a=n,l=i,c=n.slice(-1)[0],p=s?n.slice(0,-1).concat(1):[],u=!o&&e.length>3,h=e.length>4,g=s&&t>1,_=s&&t>2,E=t>3,I=64,M=Jt(c),y=[{type:12,data:l},{type:12,data:M},{type:12,data:c},{type:1,data:r.epsilon}],$=b=>{let w=[{name:"output_size",type:"u32"},{name:"components",type:"u32"},{name:"hidden_size",type:"u32"},{name:"epsilon",type:"f32"}],T=[$e("x",e[0].dataType,e[0].dims,M),$e("skip",e[1].dataType,e[1].dims,M),$e("gamma",e[2].dataType,e[2].dims,M)];u&&T.push($e("beta",e[3].dataType,e[3].dims,M)),h&&T.push($e("bias",e[4].dataType,e[4].dims,M)),T.push(tt("output",e[0].dataType,a,M)),g&&T.push(tt("mean_output",1,p)),_&&T.push(tt("inv_std_output",1,p)),E&&T.push(tt("input_skip_bias_sum",e[0].dataType,a,M));let k=pr(e[0].dataType),z=pr(1,M);return`
+
+ ${b.registerUniforms(w).declareVariables(...T)}
+ var sum_shared : array<${z}, ${I}>;
+ var sum_squared_shared : array<${z}, ${I}>;
+
+ ${b.mainStart([I,1,1])}
+ let ix = local_id.x;
+ let iy = global_id.x / ${I};
+
+ let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
+ var stride = hidden_size_vectorized / ${I};
+ let offset = ix * stride + iy * hidden_size_vectorized;
+ let offset1d = stride * ix;
+ if (ix == ${I-1}) {
+ stride = hidden_size_vectorized - stride * ix;
+ }
+ for (var i: u32 = 0; i < stride; i++) {
+ let skip_value = skip[offset + i];
+ let bias_value = ${h?"bias[offset1d + i]":k+"(0.0)"};
+ let input_value = x[offset + i];
+ let value = input_value + skip_value + bias_value;
+ ${E?"input_skip_bias_sum[offset + i] = value;":""}
+ output[offset + i] = value;
+ let f32_value = ${Co(k,M,"value")};
+ sum_shared[ix] += f32_value;
+ sum_squared_shared[ix] += f32_value * f32_value;
+ }
+ workgroupBarrier();
+
+ var reduce_size : u32 = ${I};
+ for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
+ reduce_size = curr_size + (reduce_size & 1);
+ if (ix < curr_size) {
+ sum_shared[ix] += sum_shared[ix + reduce_size];
+ sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];
+ }
+ workgroupBarrier();
+ }
+
+ let sum = sum_shared[0];
+ let square_sum = sum_squared_shared[0];
+ let mean = ${dn("sum",M)} / f32(uniforms.hidden_size);
+ let inv_std_dev = inverseSqrt(${dn("square_sum",M)} / f32(uniforms.hidden_size) ${o?"":"- mean * mean"} + uniforms.epsilon);
+ ${g?"mean_output[global_idx] = mean;":""}
+ ${_?"inv_std_output[global_idx] = inv_std_dev;":""}
+
+ for (var i: u32 = 0; i < stride; i++) {
+ output[offset + i] = (output[offset + i] ${o?"":`- ${k}(mean)`}) *
+ ${k}(inv_std_dev) * gamma[offset1d + i]
+ ${u?"+ beta[offset1d + i]":""};
+ }
+ }`},P=[{dims:a,dataType:e[0].dataType}];return t>1&&P.push({dims:p,dataType:1}),t>2&&P.push({dims:p,dataType:1}),t>3&&P.push({dims:n,dataType:e[0].dataType}),{name:"SkipLayerNormalization",shaderCache:{hint:`${M};${g};${_};${E}`,inputDependencies:e.map((b,w)=>"type")},getShaderSource:$,getRunData:()=>({outputs:P,dispatchGroup:{x:Math.ceil(l/c)},programUniforms:y})}},ab=(e,r)=>{xw(e.inputs);let t=[0];e.outputCount>1&&t.push(-3),e.outputCount>2&&t.push(-3),e.outputCount>3&&t.push(3),e.compute(Tw(e.inputs,r,e.outputCount,!1),{outputs:t})}}),Ew,na,Pw,Fc,Cw,Sw,lb,db,Fx=Ve(()=>{mt(),bt(),tr(),xt(),Ew=(e,r)=>{if(!e||e.length<1)throw new Error("too few inputs");if(r.axes.length!==0){if(r.axes.length!==r.starts.length||r.axes.length!==r.ends.length)throw new Error("axes, starts and ends must have the same length")}else if(r.starts.length!==r.ends.length)throw new Error("starts and ends must have the same length");e.slice(1).forEach((t,s)=>{if(e[s+1].dataType!==6&&e[s+1].dataType!==7)throw new Error(`Input ${s} must be an array of int32 or int64`)})},na=(e,r)=>{let t=[];if(e.length>r)if(e[r].dataType===7)e[r].getBigInt64Array().forEach(s=>t.push(Number(s)));else if(e[r].dataType===6)e[r].getInt32Array().forEach(s=>t.push(Number(s)));else throw new Error(`Input ${r} must be an array of int32 or int64`);return t},Pw=(e,r)=>{if(e.length>1){let t=na(e,1),s=na(e,2),o=na(e,3);return o.length===0&&(o=[...Array(e[0].dims.length).keys()]),Lt({starts:t,ends:s,axes:o})}else return r},Fc=(e,r,t,s,o)=>{let n=e;return e<0&&(n+=t[s[r]]),o[r]<0?Math.max(0,Math.min(n,t[s[r]]-1)):Math.max(0,Math.min(n,t[s[r]]))},Cw=(e,r,t)=>`fn calculateInputIndices(output_indices: ${r.type.indices}) -> ${e.type.indices} {
+ var input_indices: ${e.type.indices};
+ var carry = 0u;
+ for (var i = ${t.length}; i >= 0; i--) {
+ let input_shape_i = ${rt("uniforms.input_shape","i",t.length)};
+ let steps_i = ${rt("uniforms.steps","i",t.length)};
+ let signs_i = ${rt("uniforms.signs","i",t.length)};
+ let starts_i = ${rt("uniforms.starts","i",t.length)};
+ var output_index = ${r.indicesGet("output_indices","i")};
+ var input_index = output_index * steps_i + starts_i + carry;
+ carry = input_index / input_shape_i;
+ input_index = input_index % input_shape_i;
+ if (signs_i < 0) {
+ input_index = input_shape_i - input_index - 1u + starts_i;
+ }
+ ${e.indicesSet("input_indices","i","input_index")};
+ }
+ return input_indices;
+ }`,Sw=(e,r)=>{let t=e[0].dims,s=xe.size(t),o=r.axes.length>0?xe.normalizeAxes(r.axes,t.length):[...Array(t.length).keys()],n=na(e,4);n.forEach(M=>M!==0||(()=>{throw new Error("step cannot be 0")})),n.length===0&&(n=Array(o.length).fill(1));let i=r.starts.map((M,y)=>Fc(M,y,t,o,n)),a=r.ends.map((M,y)=>Fc(M,y,t,o,n));if(o.length!==i.length||o.length!==a.length)throw new Error("start, ends and axes should have the same number of elements");if(o.length!==t.length)for(let M=0;MMath.sign(M));n.forEach((M,y,$)=>{if(M<0){let P=(a[y]-i[y])/M,b=i[y],w=b+P*n[y];i[y]=w,a[y]=b,$[y]=-M}});let c=t.slice(0);o.forEach((M,y)=>{c[M]=Math.ceil((a[M]-i[M])/n[M])});let p={dims:c,dataType:e[0].dataType},u=tt("output",e[0].dataType,c.length),h=$e("input",e[0].dataType,e[0].dims.length),g=xe.size(c),_=[{name:"outputSize",type:"u32"},{name:"starts",type:"u32",length:i.length},{name:"signs",type:"i32",length:l.length},{name:"steps",type:"u32",length:n.length}],E=[{type:12,data:g},{type:12,data:i},{type:6,data:l},{type:12,data:n},...nt(e[0].dims,c)],I=M=>`
+ ${M.registerUniforms(_).declareVariables(h,u)}
+ ${Cw(h,u,t)}
+ ${M.mainStart()}
+ ${M.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.outputSize")}
+ let output_indices = ${u.offsetToIndices("global_idx")};
+ let input_indices = calculateInputIndices(output_indices);
+ ${u.setByOffset("global_idx",h.getByIndices("input_indices"))}
+ }`;return{name:"Slice",shaderCache:{hint:`${l.length}_${i.length}_${n.length}`,inputDependencies:["rank"]},getShaderSource:I,getRunData:()=>({outputs:[p],dispatchGroup:{x:Math.ceil(s/64)},programUniforms:E})}},lb=(e,r)=>{Ew(e.inputs,r);let t=Pw(e.inputs,r);e.compute(Sw(e.inputs,t),{inputs:[0]})},db=e=>{let r=e.starts,t=e.ends,s=e.axes;return Lt({starts:r,ends:t,axes:s})}}),$w,kw,cb,ub,Ox=Ve(()=>{mt(),bt(),tr(),cn(),xt(),$w=e=>{if(!e||e.length!==1)throw new Error("Softmax op requires 1 input.")},kw=(e,r)=>{let t=e.inputs[0],s=t.dims,o=xe.size(s),n=s.length,i=xe.normalizeAxis(r.axis,n),a=ik),c[i]=n-1,c[n-1]=i,l=e.compute(Wr(t,c),{inputs:[t],outputs:[-1]})[0]):l=t;let p=l.dims,u=p[n-1],h=o/u,g=Jt(u),_=u/g,E=64;h===1&&(E=256);let I=(T,k)=>k===4?`max(max(${T}.x, ${T}.y), max(${T}.z, ${T}.w))`:k===2?`max(${T}.x, ${T}.y)`:k===3?`max(max(${T}.x, ${T}.y), ${T}.z)`:T,M=$e("x",l.dataType,l.dims,g),y=tt("result",l.dataType,l.dims,g),$=M.type.value,P=pr(l.dataType)==="f32"?`var threadMax = ${$}(-3.402823e+38f);`:`var threadMax = ${$}(-65504.0h);`,b=T=>`
+ var rowMaxShared : ${$};
+ var rowSumShared : ${$};
+ var threadShared : array<${$}, ${E}>;
+
+ fn getValue(row: i32, col: i32, row_stride: i32) -> ${$} {
+ let index = row * row_stride + col;
+ return x[index];
+ }
+
+ fn setValue(row: i32, col: i32, row_stride: i32, value: ${$}) {
+ let index = row * row_stride + col;
+ result[index] = value;
+ }
+ ${T.registerUniform("packedCols","i32").declareVariables(M,y)}
+ ${T.mainStart(E)}
+ let gindex = i32(global_idx);
+ let lindex = i32(local_idx);
+ const wg = ${E};
+ let row = gindex / wg;
+ let cols = uniforms.packedCols;
+ let row_stride : i32 = uniforms.packedCols;
+
+ // find the rows max
+ ${P}
+ for (var col = lindex; col < cols; col += wg) {
+ let value = getValue(row, col, row_stride);
+ threadMax = max(threadMax, value);
+ }
+ if (lindex < cols) {
+ threadShared[lindex] = threadMax;
+ }
+ workgroupBarrier();
+
+ var reduceSize = min(cols, wg);
+ for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
+ reduceSize = currSize + (reduceSize & 1);
+ if (lindex < currSize) {
+ threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);
+ }
+ workgroupBarrier();
+ }
+ if (lindex == 0) {
+ rowMaxShared = ${$}(${I("threadShared[0]",g)});
+ }
+ workgroupBarrier();
+
+ // find the rows sum
+ var threadSum = ${$}(0.0);
+ for (var col = lindex; col < cols; col += wg) {
+ let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
+ threadSum += subExp;
+ }
+ threadShared[lindex] = threadSum;
+ workgroupBarrier();
+
+ for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {
+ if (lindex < currSize) {
+ threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];
+ }
+ workgroupBarrier();
+ }
+ if (lindex == 0) {
+ rowSumShared = ${$}(${dn("threadShared[0]",g)});
+ }
+ workgroupBarrier();
+
+ // calculate final value for each element in the row
+ for (var col = lindex; col < cols; col += wg) {
+ let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
+ setValue(row, col, row_stride, value);
+ }
+ }`,w=e.compute({name:"Softmax",shaderCache:{hint:`${g};${E}`,inputDependencies:["type"]},getRunData:()=>({outputs:[{dims:p,dataType:l.dataType}],dispatchGroup:{x:h},programUniforms:[{type:6,data:_}]}),getShaderSource:b},{inputs:[l],outputs:[a?-1:0]})[0];a&&e.compute(Wr(w,c),{inputs:[w]})},cb=(e,r)=>{$w(e.inputs),kw(e,r)},ub=e=>Lt({axis:e.axis})}),Oc,Iw,Aw,Fw,pb,Dx=Ve(()=>{mt(),bt(),xt(),Oc=e=>Array.from(e.getBigInt64Array(),Number),Iw=e=>{if(!e||e.length!==2)throw new Error("Tile requires 2 inputs.");if(e[0].dataType!==1&&e[0].dataType!==10&&e[0].dataType!==6&&e[0].dataType!==12)throw new Error("Tile only support float, float16, int32, and uint32 data types");if(e[1].dataType!==7)throw new Error("Tile `repeats` input should be of int64 data type");if(e[1].dims.length!==1)throw new Error("Tile `repeats` input should be 1-D");if(Oc(e[1]).length!==e[0].dims.length)throw new Error("Tile `repeats` input should have same number of elements as rank of input data tensor")},Aw=(e,r)=>{let t=[];for(let s=0;s{let t=e[0].dims,s=r??Oc(e[1]),o=Aw(t,s),n=xe.size(o),i=e[0].dataType,a=$e("input",i,t.length),l=tt("output",i,o.length),c=p=>`
+ const inputShape = ${a.indices(...t)};
+ ${p.registerUniform("output_size","u32").declareVariables(a,l)}
+ ${p.mainStart()}
+ ${p.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")}
+ let output_indices = ${l.offsetToIndices("global_idx")};
+ var input_indices: ${a.type.indices};
+ for (var i = 0; i < ${t.length}; i++) {
+ let input_dim_i = ${a.indicesGet("uniforms.input_shape","i")};
+ let input_dim_value = ${l.indicesGet("output_indices","i")} % input_dim_i;
+
+ ${a.indicesSet("input_indices","i","input_dim_value")}
+ }
+ ${l.setByOffset("global_idx",a.getByIndices("input_indices"))}
+ }`;return{name:"Tile",shaderCache:{hint:`${s}`,inputDependencies:["rank"]},getRunData:()=>({outputs:[{dims:o,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(n/64)},programUniforms:[{type:12,data:n},...nt(e[0].dims,o)]}),getShaderSource:c}},pb=e=>{Iw(e.inputs),e.compute(Fw(e.inputs),{inputs:[0]})}}),Ow,Dw,hb,Lx=Ve(()=>{mt(),bt(),xt(),Ow=(e,r,t,s,o)=>{let n=tt("output_data",o,t.length,4),i=$e("a_data",r[1].dataType,r[1].dims.length,4),a=$e("b_data",r[2].dataType,r[2].dims.length,4),l=$e("c_data",r[0].dataType,r[0].dims.length,4),c,p=(u,h,g)=>`select(${h}, ${u}, ${g})`;if(!s)c=n.setByOffset("global_idx",p(i.getByOffset("global_idx"),a.getByOffset("global_idx"),l.getByOffset("global_idx")));else{let u=(h,g,_="")=>{let E=`a_data[index_a${g}][component_a${g}]`,I=`b_data[index_b${g}][component_b${g}]`,M=`bool(c_data[index_c${g}] & (0xffu << (component_c${g} * 8)))`;return`
+ let output_indices${g} = ${n.offsetToIndices(`global_idx * 4u + ${g}u`)};
+ let offset_a${g} = ${i.broadcastedIndicesToOffset(`output_indices${g}`,n)};
+ let offset_b${g} = ${a.broadcastedIndicesToOffset(`output_indices${g}`,n)};
+ let offset_c${g} = ${l.broadcastedIndicesToOffset(`output_indices${g}`,n)};
+ let index_a${g} = offset_a${g} / 4u;
+ let index_b${g} = offset_b${g} / 4u;
+ let index_c${g} = offset_c${g} / 4u;
+ let component_a${g} = offset_a${g} % 4u;
+ let component_b${g} = offset_b${g} % 4u;
+ let component_c${g} = offset_c${g} % 4u;
+ ${h}[${g}] = ${_}(${p(E,I,M)});
+ `};o===9?c=`
+ var data = vec4(0);
+ ${u("data",0,"u32")}
+ ${u("data",1,"u32")}
+ ${u("data",2,"u32")}
+ ${u("data",3,"u32")}
+ output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`:c=`
+ ${u("output_data[global_idx]",0)}
+ ${u("output_data[global_idx]",1)}
+ ${u("output_data[global_idx]",2)}
+ ${u("output_data[global_idx]",3)}
+ `}return`
+ ${e.registerUniform("vec_size","u32").declareVariables(l,i,a,n)}
+ ${e.mainStart()}
+ ${e.guardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")}
+ ${c}
+ }`},Dw=e=>{let r=e[1].dims,t=e[2].dims,s=e[0].dims,o=e[1].dataType,n=!(xe.areEqual(r,t)&&xe.areEqual(t,s)),i=r,a=xe.size(r);if(n){let c=So.calcShape(So.calcShape(r,t,!1),s,!1);if(!c)throw new Error("Can't perform where op on the given tensors");i=c,a=xe.size(i)}let l=Math.ceil(a/4);return{name:"Where",shaderCache:{inputDependencies:["rank","rank","rank"]},getShaderSource:c=>Ow(c,e,i,n,o),getRunData:()=>({outputs:[{dims:i,dataType:o}],dispatchGroup:{x:Math.ceil(a/64/4)},programUniforms:[{type:12,data:l},...nt(s,r,t,i)]})}},hb=e=>{e.compute(Dw(e.inputs))}}),mb,zx=Ve(()=>{Xv(),bu(),Jv(),Yv(),Zv(),ex(),tx(),ix(),lx(),dx(),cx(),ux(),px(),hx(),mx(),fx(),_x(),gx(),wx(),yx(),Mx(),bx(),vx(),xx(),Tx(),O0(),Ex(),Px(),Cx(),Sx(),$x(),Mu(),kx(),Ix(),Ax(),Fx(),Ox(),z0(),Dx(),cn(),vu(),Lx(),mb=new Map([["Abs",[lM]],["Acos",[dM]],["Acosh",[cM]],["Add",[WM]],["ArgMax",[nM,Qc]],["ArgMin",[sM,Qc]],["Asin",[uM]],["Asinh",[pM]],["Atan",[hM]],["Atanh",[mM]],["Attention",[oM]],["AveragePool",[K0,G0]],["BatchNormalization",[iM]],["BiasAdd",[aM]],["BiasSplitGelu",[UM]],["Cast",[_M,fM]],["Ceil",[wM]],["Clip",[gM]],["Concat",[e0,t0]],["Conv",[tu,eu]],["ConvTranspose",[u0,c0]],["Cos",[yM]],["Cosh",[MM]],["CumSum",[p0,h0]],["DepthToSpace",[m0,f0]],["DequantizeLinear",[Z0,eb]],["Div",[GM]],["Einsum",[_0,g0]],["Elu",[bM,da]],["Equal",[KM]],["Erf",[vM]],["Exp",[xM]],["Expand",[w0]],["FastGelu",[y0]],["Floor",[TM]],["FusedConv",[tu,eu]],["Gather",[b0,M0]],["GatherElements",[C0,P0]],["GatherBlockQuantized",[T0,E0]],["GatherND",[v0,x0]],["Gelu",[EM]],["Gemm",[$0,S0]],["GlobalAveragePool",[q0,H0]],["GlobalMaxPool",[Y0,J0]],["Greater",[XM]],["GreaterOrEqual",[YM]],["GridSample",[k0,I0]],["GroupQueryAttention",[B0]],["HardSigmoid",[FM,AM]],["InstanceNormalization",[R0]],["LayerNormalization",[N0]],["LeakyRelu",[PM,da]],["Less",[JM]],["LessOrEqual",[ZM]],["Log",[jM]],["MatMul",[j0]],["MatMulNBits",[V0,U0]],["MaxPool",[Q0,X0]],["Mul",[HM]],["MultiHeadAttention",[F0,A0]],["Neg",[SM]],["Not",[CM]],["Pad",[W0]],["Pow",[qM]],["QuickGelu",[VM,da]],["Range",[tb]],["Reciprocal",[$M]],["ReduceMin",[Yy]],["ReduceMean",[Hy]],["ReduceMax",[Jy]],["ReduceSum",[eM]],["ReduceProd",[Zy]],["ReduceL1",[qy]],["ReduceL2",[Qy]],["ReduceLogSum",[rM]],["ReduceLogSumExp",[Xy]],["ReduceSumSquare",[tM]],["Relu",[kM]],["Resize",[nb,ob]],["RotaryEmbedding",[ib]],["ScatterND",[sb,rb]],["Sigmoid",[IM]],["Sin",[OM]],["Sinh",[DM]],["Slice",[lb,db]],["SkipLayerNormalization",[ab]],["Split",[D0,L0]],["Sqrt",[LM]],["Softmax",[cb,ub]],["Sub",[QM]],["Tan",[zM]],["Tanh",[BM]],["ThresholdedRelu",[NM,da]],["Tile",[pb]],["Transpose",[Dy,Ly]],["Where",[hb]]])}),fb,Bx=Ve(()=>{Ms(),Us(),xt(),fb=class{constructor(e){this.backend=e,this.repo=new Map,this.attributesBound=!1}getArtifact(e){return this.repo.get(e)}setArtifact(e,r){this.repo.set(e,r)}run(e,r,t,s,o){ys(e.programInfo.name);let n=this.backend.device,i=this.backend.getComputePassEncoder();this.backend.writeTimestamp(this.backend.pendingDispatchNumber*2);let a=[];for(let c of r)a.push({binding:a.length,resource:{buffer:c.buffer}});for(let c of t)a.push({binding:a.length,resource:{buffer:c.buffer}});o&&a.push({binding:a.length,resource:o});let l=n.createBindGroup({layout:e.computePipeline.getBindGroupLayout(0),entries:a,label:e.programInfo.name});if(this.backend.sessionStatus==="capturing"){let c={kernelId:this.backend.currentKernelId,computePipeline:e.computePipeline,bindGroup:l,dispatchGroup:s};this.backend.capturedCommandList.get(this.backend.currentSessionId).push(c)}i.setPipeline(e.computePipeline),i.setBindGroup(0,l),i.dispatchWorkgroups(...s),this.backend.writeTimestamp(this.backend.pendingDispatchNumber*2+1),this.backend.pendingDispatchNumber++,(this.backend.pendingDispatchNumber>=this.backend.maxDispatchNumber||this.backend.queryType==="at-passes")&&this.backend.endComputePass(),this.backend.pendingDispatchNumber>=this.backend.maxDispatchNumber&&this.backend.flush(),Zr(e.programInfo.name)}dispose(){}build(e,r){ys(e.name);let t=this.backend.device,s=[];[{feature:"shader-f16",extension:"f16"},{feature:"subgroups",extension:"subgroups"}].forEach(c=>{t.features.has(c.feature)&&s.push(`enable ${c.extension};`)});let o=Oy(r,this.backend.device.limits),n=e.getShaderSource(o),i=`${s.join(`
+`)}
+${o.additionalImplementations}
+${n}`,a=t.createShaderModule({code:i,label:e.name});St("verbose",()=>`[WebGPU] ${e.name} shader code: ${i}`);let l=t.createComputePipeline({compute:{module:a,entryPoint:"main"},layout:"auto",label:e.name});return Zr(e.name),{programInfo:e,computePipeline:l,uniformVariablesInfo:o.variablesInfo}}normalizeDispatchGroupSize(e){let r=typeof e=="number"?e:e.x,t=typeof e=="number"?1:e.y||1,s=typeof e=="number"?1:e.z||1,o=this.backend.device.limits.maxComputeWorkgroupsPerDimension;if(r<=o&&t<=o&&s<=o)return[r,t,s];let n=r*t*s,i=Math.ceil(Math.sqrt(n));if(i>o){if(i=Math.ceil(Math.cbrt(n)),i>o)throw new Error("Total dispatch size exceeds WebGPU maximum.");return[i,i,i]}else return[i,i,1]}}}),Lw,zw,Bw,_b,Rx=Ve(()=>{Ms(),mt(),Us(),Sy(),qv(),zx(),Bx(),Lw=(e,r)=>{if(r.length!==e.length)throw new Error(`inputDependencies length ${r.length} is not equal to inputTensors length ${e.length}.`);let t=[];for(let s=0;s{var o,n;let s=e.name;return(o=e.shaderCache)!=null&&o.hint&&(s+="["+e.shaderCache.hint+"]"),s+=":"+t+`:${Lw(r,((n=e.shaderCache)==null?void 0:n.inputDependencies)??new Array(r.length).fill("dims"))}`,s},Bw=class{constructor(e){e&&(this.architecture=e.architecture,this.vendor=e.vendor)}isArchitecture(e){return this.architecture===e}isVendor(e){return this.vendor===e}},_b=class{constructor(){this.currentSessionId=null,this.currentKernelId=null,this.commandEncoder=null,this.computePassEncoder=null,this.maxDispatchNumber=16,this.pendingDispatchNumber=0,this.pendingKernels=[],this.pendingQueries=new Map,this.sessionStatus="default",this.capturedCommandList=new Map,this.capturedPendingKernels=new Map,this.sessionExternalDataMapping=new Map}get currentKernelCustomData(){if(this.currentKernelId===null)throw new Error("currentKernelCustomData(): currentKernelId is null. (should not happen)");let e=this.kernelCustomData.get(this.currentKernelId);return e||(e={},this.kernelCustomData.set(this.currentKernelId,e)),e}async initialize(e,r){this.env=e;let t=[],s={requiredLimits:{maxComputeWorkgroupStorageSize:r.limits.maxComputeWorkgroupStorageSize,maxComputeWorkgroupsPerDimension:r.limits.maxComputeWorkgroupsPerDimension,maxStorageBufferBindingSize:r.limits.maxStorageBufferBindingSize,maxBufferSize:r.limits.maxBufferSize,maxComputeInvocationsPerWorkgroup:r.limits.maxComputeInvocationsPerWorkgroup,maxComputeWorkgroupSizeX:r.limits.maxComputeWorkgroupSizeX,maxComputeWorkgroupSizeY:r.limits.maxComputeWorkgroupSizeY,maxComputeWorkgroupSizeZ:r.limits.maxComputeWorkgroupSizeZ},requiredFeatures:t},o=n=>r.features.has(n)&&t.push(n)&&!0;o("chromium-experimental-timestamp-query-inside-passes")||o("timestamp-query"),o("shader-f16"),o("subgroups"),this.device=await r.requestDevice(s),this.adapterInfo=new Bw(r.info||await r.requestAdapterInfo()),this.gpuDataManager=$y(this),this.programManager=new fb(this),this.kernels=new Map,this.kernelPersistentData=new Map,this.kernelCustomData=new Map,_u(e.logLevel,!!e.debug),this.device.onuncapturederror=n=>{n.error instanceof GPUValidationError&&console.error(`An uncaught WebGPU validation error was raised: ${n.error.message}`)},Object.defineProperty(this.env.webgpu,"device",{value:this.device,writable:!1,enumerable:!0,configurable:!1}),Object.defineProperty(this.env.webgpu,"adapter",{value:r,writable:!1,enumerable:!0,configurable:!1}),this.setQueryType()}dispose(){typeof this.querySet<"u"&&this.querySet.destroy(),this.gpuDataManager.dispose()}getCommandEncoder(){return this.commandEncoder||(this.commandEncoder=this.device.createCommandEncoder()),this.commandEncoder}getComputePassEncoder(){if(!this.computePassEncoder){let e=this.getCommandEncoder(),r={};this.queryType==="at-passes"&&(r.timestampWrites={querySet:this.querySet,beginningOfPassWriteIndex:this.pendingDispatchNumber*2,endOfPassWriteIndex:this.pendingDispatchNumber*2+1}),this.computePassEncoder=e.beginComputePass(r)}return this.computePassEncoder}endComputePass(){this.computePassEncoder&&(this.computePassEncoder.end(),this.computePassEncoder=null)}flush(){if(!this.commandEncoder)return;ys(),this.endComputePass();let e;this.queryType!=="none"&&(this.commandEncoder.resolveQuerySet(this.querySet,0,this.pendingDispatchNumber*2,this.queryResolveBuffer,0),e=this.device.createBuffer({size:this.pendingDispatchNumber*2*8,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST}),this.pendingQueries.set(e,this.pendingKernels),this.pendingKernels=[],this.commandEncoder.copyBufferToBuffer(this.queryResolveBuffer,0,e,0,this.pendingDispatchNumber*2*8)),this.device.queue.submit([this.commandEncoder.finish()]),this.gpuDataManager.refreshPendingBuffers(),this.commandEncoder=null,this.pendingDispatchNumber=0,this.queryType!=="none"&&e.mapAsync(GPUMapMode.READ).then(()=>{var s;let r=new BigUint64Array(e.getMappedRange()),t=this.pendingQueries.get(e);for(let o=0;o"u"&&(this.queryTimeBase=g);let E=Number(g-this.queryTimeBase),I=Number(_-this.queryTimeBase);if(!Number.isSafeInteger(E)||!Number.isSafeInteger(I))throw new RangeError("incorrect timestamp range");if((s=this.env.webgpu.profiling)!=null&&s.ondata)this.env.webgpu.profiling.ondata({version:1,inputsMetadata:u.map(M=>({dims:M.dims,dataType:In(M.dataType)})),outputsMetadata:h.map(M=>({dims:M.dims,dataType:In(M.dataType)})),kernelId:i,kernelType:l,kernelName:c,programName:p,startTime:E,endTime:I});else{let M="";u.forEach(($,P)=>{M+=`input[${P}]: [${$.dims}] | ${In($.dataType)}, `});let y="";h.forEach(($,P)=>{y+=`output[${P}]: [${$.dims}] | ${In($.dataType)}, `}),console.log(`[profiling] kernel "${i}|${l}|${c}|${p}" ${M}${y}execution time: ${I-E} ns`)}pa("GPU",`${p}::${g}::${_}`)}e.unmap(),this.pendingQueries.delete(e)}),Zr()}run(e,r,t,s,o,n){ys(e.name);let i=[];for(let y=0;y$):t;if(p.length!==a.length)throw new Error(`Output size ${p.length} must be equal to ${a.length}.`);let u=[],h=[];for(let y=0;y=n)throw new Error(`Invalid output index: ${p[y]}`);if(p[y]===-3)continue;let $=p[y]===-1,P=p[y]===-2,b=$||P?o(a[y].dataType,a[y].dims):s(p[y],a[y].dataType,a[y].dims);if(u.push(b),b.data===0)continue;let w=this.gpuDataManager.get(b.data);if(!w)throw new Error(`no GPU data for output: ${b.data}`);if($&&this.temporaryData.push(w),P){let T=this.kernelPersistentData.get(this.currentKernelId);T||(T=[],this.kernelPersistentData.set(this.currentKernelId,T)),T.push(w)}h.push(w)}if(i.length!==r.length||h.length!==u.length){if(h.length===0)return Zr(e.name),u;throw new Error(`Program ${e.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`)}let g;if(c){let y=0,$=[];c.forEach(T=>{let k=typeof T.data=="number"?[T.data]:T.data;if(k.length===0)return;let z=T.type===10?2:4,R,Q;T.type===10?(Q=k.length>4?16:k.length>2?8:k.length*z,R=k.length>4?16:z*k.length):(Q=k.length<=2?k.length*z:16,R=16),y=Math.ceil(y/Q)*Q,$.push(y);let q=T.type===10?8:4;y+=k.length>4?Math.ceil(k.length/q)*R:k.length*z});let P=16;y=Math.ceil(y/P)*P;let b=new ArrayBuffer(y);c.forEach((T,k)=>{let z=$[k],R=typeof T.data=="number"?[T.data]:T.data;if(T.type===6)new Int32Array(b,z,R.length).set(R);else if(T.type===12)new Uint32Array(b,z,R.length).set(R);else if(T.type===10)new Uint16Array(b,z,R.length).set(R);else if(T.type===1)new Float32Array(b,z,R.length).set(R);else throw new Error(`Unsupported uniform type: ${In(T.type)}`)});let w=this.gpuDataManager.create(y,GPUBufferUsage.COPY_DST|GPUBufferUsage.UNIFORM);this.device.queue.writeBuffer(w.buffer,0,b,0,y),this.gpuDataManager.release(w.id),g={offset:0,size:y,buffer:w.buffer}}let _=this.programManager.normalizeDispatchGroupSize(l),E=_[1]===1&&_[2]===1,I=zw(e,r,E),M=this.programManager.getArtifact(I);if(M||(M=this.programManager.build(e,_),this.programManager.setArtifact(I,M),St("info",()=>`[artifact] key: ${I}, programName: ${e.name}`)),c&&M.uniformVariablesInfo){if(c.length!==M.uniformVariablesInfo.length)throw new Error(`Uniform variables count mismatch: expect ${M.uniformVariablesInfo.length}, got ${c.length} in program "${M.programInfo.name}".`);for(let y=0;y