Arrcttacsrks commited on
Commit
d70f8dd
·
verified ·
1 Parent(s): 5e1633b

Upload llama.cpp/ggml/src/ggml-cuda/fattn.cu with huggingface_hub

Browse files
Files changed (1) hide show
  1. llama.cpp/ggml/src/ggml-cuda/fattn.cu +345 -0
llama.cpp/ggml/src/ggml-cuda/fattn.cu ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
+ #include "fattn-tile-f16.cuh"
4
+ #include "fattn-tile-f32.cuh"
5
+ #include "fattn-vec-f16.cuh"
6
+ #include "fattn-vec-f32.cuh"
7
+ #include "fattn-wmma-f16.cuh"
8
+ #include "fattn.cuh"
9
+
10
+ #include <cstdint>
11
+
12
+ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
+ const ggml_tensor * KQV = dst;
14
+ const ggml_tensor * Q = dst->src[0];
15
+
16
+ const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
17
+
18
+ if (prec != GGML_PREC_DEFAULT) {
19
+ if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
20
+ constexpr int cols_per_block = 16;
21
+ switch (Q->ne[0]) {
22
+ case 64:
23
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
24
+ break;
25
+ case 80:
26
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
27
+ break;
28
+ case 96:
29
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
30
+ break;
31
+ case 112:
32
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
33
+ break;
34
+ case 128:
35
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
36
+ break;
37
+ case 256:
38
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
39
+ break;
40
+ default:
41
+ GGML_ABORT("fatal error");
42
+ break;
43
+ }
44
+ } else {
45
+ constexpr int cols_per_block = 32;
46
+ switch (Q->ne[0]) {
47
+ case 64:
48
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
49
+ break;
50
+ case 80:
51
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
52
+ break;
53
+ case 96:
54
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
55
+ break;
56
+ case 112:
57
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
58
+ break;
59
+ case 128:
60
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
61
+ break;
62
+ // case 256:
63
+ // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
64
+ // break;
65
+ default:
66
+ GGML_ABORT("fatal error");
67
+ break;
68
+ }
69
+ }
70
+ return;
71
+ }
72
+
73
+ if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
74
+ constexpr int cols_per_block = 8;
75
+ switch (Q->ne[0]) {
76
+ case 64:
77
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
78
+ break;
79
+ case 96:
80
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
81
+ break;
82
+ case 128:
83
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
84
+ break;
85
+ case 256:
86
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
87
+ break;
88
+ default:
89
+ GGML_ABORT("fatal error");
90
+ break;
91
+ }
92
+ return;
93
+ }
94
+
95
+ if (Q->ne[1] <= 32) {
96
+ constexpr int cols_per_block = 16;
97
+ switch (Q->ne[0]) {
98
+ case 64:
99
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
100
+ break;
101
+ case 80:
102
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
103
+ break;
104
+ case 96:
105
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
106
+ break;
107
+ case 112:
108
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
109
+ break;
110
+ case 128:
111
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
112
+ break;
113
+ case 256:
114
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
115
+ break;
116
+ default:
117
+ GGML_ABORT("fatal error");
118
+ break;
119
+ }
120
+ return;
121
+ }
122
+
123
+ constexpr int cols_per_block = 32;
124
+ switch (Q->ne[0]) {
125
+ case 64:
126
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
127
+ break;
128
+ case 80:
129
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
130
+ break;
131
+ case 96:
132
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
133
+ break;
134
+ case 112:
135
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
136
+ break;
137
+ case 128:
138
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
139
+ break;
140
+ case 256:
141
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
142
+ break;
143
+ default:
144
+ GGML_ABORT("fatal error");
145
+ break;
146
+ }
147
+ }
148
+ #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
149
+ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
150
+ ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
151
+ return; \
152
+ } \
153
+
154
+ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
155
+ ggml_tensor * Q = dst->src[0];
156
+ ggml_tensor * K = dst->src[1];
157
+ ggml_tensor * V = dst->src[2];
158
+
159
+ #ifdef GGML_CUDA_FA_ALL_QUANTS
160
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
161
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
162
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
163
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
164
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
165
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
166
+
167
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
168
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
169
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
170
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
171
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
172
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
173
+
174
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
175
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
176
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
177
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
178
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
179
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
180
+
181
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
182
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
183
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
184
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
185
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
186
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
187
+
188
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
189
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
190
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
191
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
192
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
193
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
194
+
195
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
196
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
197
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
198
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
199
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
200
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
201
+
202
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
203
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
204
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
205
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
206
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
207
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
208
+
209
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
210
+ #else
211
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
212
+
213
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
214
+
215
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
216
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
217
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
218
+ #endif // GGML_CUDA_FA_ALL_QUANTS
219
+
220
+ on_no_fattn_vec_case(Q->ne[0]);
221
+ }
222
+
223
+ #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
224
+ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
225
+ ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
226
+ return; \
227
+ } \
228
+
229
+ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
230
+ ggml_tensor * Q = dst->src[0];
231
+ ggml_tensor * K = dst->src[1];
232
+ ggml_tensor * V = dst->src[2];
233
+
234
+ #ifdef GGML_CUDA_FA_ALL_QUANTS
235
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
236
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
237
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
238
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
239
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
240
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
241
+
242
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
243
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
244
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
245
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
246
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
247
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
248
+
249
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
250
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
251
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
252
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
253
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
254
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
255
+
256
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
257
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
258
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
259
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
260
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
261
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
262
+
263
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
264
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
265
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
266
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
267
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
268
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
269
+
270
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
271
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
272
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
273
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
274
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
275
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
276
+
277
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
278
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
279
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
280
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
281
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
282
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
283
+
284
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
285
+ #else
286
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
287
+
288
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
289
+
290
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
291
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
292
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
293
+ #endif // GGML_CUDA_FA_ALL_QUANTS
294
+
295
+ on_no_fattn_vec_case(Q->ne[0]);
296
+ }
297
+
298
+ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
299
+ const ggml_tensor * KQV = dst;
300
+ const ggml_tensor * Q = dst->src[0];
301
+
302
+ ggml_cuda_set_device(ctx.device);
303
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
304
+ const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
305
+
306
+ // On AMD the tile kernels perform poorly, use the vec kernel instead:
307
+ if (cc >= CC_OFFSET_AMD) {
308
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
309
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
310
+ } else {
311
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
312
+ }
313
+ return;
314
+ }
315
+
316
+ if (!fast_fp16_available(cc)) {
317
+ if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
318
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
319
+ } else {
320
+ ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
321
+ }
322
+ return;
323
+ }
324
+
325
+ if (!fp16_mma_available(cc)) {
326
+ if (Q->ne[1] <= 8) {
327
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
328
+ } else {
329
+ ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
330
+ }
331
+ return;
332
+ }
333
+
334
+ if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
335
+ if (prec == GGML_PREC_DEFAULT) {
336
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
337
+ return;
338
+ } else if(Q->ne[0] <= 128) {
339
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
340
+ return;
341
+ }
342
+ }
343
+
344
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
345
+ }