Skip to content

Commit 5e2d57b

Browse files
BB-fatalexju
and
alexju
authored
metal : simplify kernel arguments using a struct (#3229) (#12194)
* metal : refactor im2col parameters into a struct
* metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets
* metal : refactor sum_rows parameters into a struct
* metal : refactor soft_max parameters into a struct
* metal : refactor diag_mask_inf parameters into a struct
* metal : refactor ssm_conv parameters into a struct
* metal : refactor ssm_scan parameters into a struct
* metal : refactor get_rows parameters into a struct
* metal : refactor group_norm parameters into a struct
* metal : refactor conv_transpose_1d parameters into a struct
* metal : refactor upscale parameters into a struct
* metal : refactor pad parameters into a struct
* metal : refactor pad_reflect_1d parameters into a struct
* metal : refactor arange parameters into a struct
* metal : refactor timestep_embedding parameters into a struct
* metal : refactor argsort parameters into a struct
* metal : refactor leaky_relu parameters into a struct
* metal : refactor pool_2d parameters into a struct
* metal : fix trailing whitespace

---------

Co-authored-by: alexju <[email protected]>
1 parent f1648e9 commit 5e2d57b

File tree

3 files changed

+685
-643
lines changed

3 files changed

+685
-643
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

+235
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,239 @@ typedef struct {
285285
float eps;
286286
} ggml_metal_kargs_rms_norm;
287287

288+
// Arguments for the group_norm Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0x / nb0x presumably follow ggml's "number of elements" /
// "stride in bytes" naming for the source tensor — confirm against the kernel.
typedef struct {
289+
int64_t ne00;
290+
int64_t ne01;
291+
int64_t ne02;
292+
uint64_t nb00;
293+
uint64_t nb01;
294+
uint64_t nb02;
295+
int32_t n_groups; // presumably the number of normalization groups — confirm
296+
float eps; // presumably the epsilon added for numerical stability — confirm
297+
} ggml_metal_kargs_group_norm;
298+
299+
// Arguments for the conv_transpose_1d Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): IC / IL / K / s0 look like input channels, input length,
// kernel size and stride; nb0 / nb1 look like byte strides — confirm against the kernel.
typedef struct {
300+
int32_t IC;
301+
int32_t IL;
302+
int32_t K;
303+
int32_t s0;
304+
uint64_t nb0;
305+
uint64_t nb1;
306+
} ggml_metal_kargs_conv_transpose_1d;
307+
308+
// Arguments for the im2col Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// The two offsets are 64-bit: the commit message states they were widened from
// int32_t to uint64_t to support larger memory offsets.
// NOTE(review): s* / p* / d* presumably are stride, padding and dilation per axis,
// and IW / IH / KW / KH input and kernel width/height — confirm against the kernel.
typedef struct {
309+
uint64_t ofs0;
310+
uint64_t ofs1;
311+
int32_t IW;
312+
int32_t IH;
313+
int32_t CHW;
314+
int32_t s0;
315+
int32_t s1;
316+
int32_t p0;
317+
int32_t p1;
318+
int32_t d0;
319+
int32_t d1;
320+
int32_t N;
321+
int32_t KH;
322+
int32_t KW;
323+
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
324+
} ggml_metal_kargs_im2col;
325+
326+
// Arguments for the sum_rows Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): the three field groups (ne0x/nb0x, ne1x/nb1x, ne0..3/nb0..3)
// presumably hold element counts and byte strides for src0, src1 and dst
// respectively, following ggml naming — confirm against the kernel.
typedef struct {
327+
int64_t ne00;
328+
int64_t ne01;
329+
int64_t ne02;
330+
int64_t ne03;
331+
uint64_t nb00;
332+
uint64_t nb01;
333+
uint64_t nb02;
334+
uint64_t nb03;
335+
int64_t ne10;
336+
int64_t ne11;
337+
int64_t ne12;
338+
int64_t ne13;
339+
uint64_t nb10;
340+
uint64_t nb11;
341+
uint64_t nb12;
342+
uint64_t nb13;
343+
int64_t ne0;
344+
int64_t ne1;
345+
int64_t ne2;
346+
int64_t ne3;
347+
uint64_t nb0;
348+
uint64_t nb1;
349+
uint64_t nb2;
350+
uint64_t nb3;
351+
} ggml_metal_kargs_sum_rows;
352+
353+
// Arguments for the soft_max Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0x presumably are source element counts; max_bias / m0 / m1 /
// n_head_log2 look like ALiBi-style bias parameters — confirm against the kernel.
typedef struct {
354+
int64_t ne00;
355+
int64_t ne01;
356+
int64_t ne02;
357+
float scale;
358+
float max_bias;
359+
float m0;
360+
float m1;
361+
uint32_t n_head_log2;
362+
} ggml_metal_kargs_soft_max;
363+
364+
// Arguments for the diag_mask_inf Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne00 / ne01 presumably are source element counts following
// ggml naming — confirm against the kernel.
typedef struct {
    int64_t ne00;
    int64_t ne01;
    // Fixed-width type, like every other integer field in the kargs structs:
    // the layout is shared between host and GPU code, so the width should not
    // depend on the platform's `int`.
    int32_t n_past;
} ggml_metal_kargs_diag_mask_inf;
369+
370+
// Arguments for the ssm_conv Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0x/nb0x, ne1x/nb1x and ne0..2/nb0..2 presumably hold element
// counts and byte strides for src0, src1 and dst, following ggml naming —
// confirm against the kernel.
typedef struct {
371+
int64_t ne00;
372+
int64_t ne01;
373+
int64_t ne02;
374+
uint64_t nb00;
375+
uint64_t nb01;
376+
uint64_t nb02;
377+
int64_t ne10;
378+
int64_t ne11;
379+
uint64_t nb10;
380+
uint64_t nb11;
381+
int64_t ne0;
382+
int64_t ne1;
383+
int64_t ne2;
384+
uint64_t nb0;
385+
uint64_t nb1;
386+
uint64_t nb2;
387+
} ggml_metal_kargs_ssm_conv;
388+
389+
// Arguments for the ssm_scan Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): nbXY presumably is the byte stride of dimension Y of source
// tensor X (sources 0 through 5), following ggml naming — confirm against the kernel.
typedef struct {
390+
int64_t d_state;
391+
int64_t d_inner;
392+
int64_t n_seq_tokens;
393+
int64_t n_seqs;
394+
uint64_t nb00;
395+
uint64_t nb01;
396+
uint64_t nb02;
397+
uint64_t nb10;
398+
uint64_t nb11;
399+
uint64_t nb12;
400+
uint64_t nb13;
401+
uint64_t nb20;
402+
uint64_t nb21;
403+
uint64_t nb22;
404+
uint64_t nb30;
405+
uint64_t nb31;
406+
uint64_t nb40;
407+
uint64_t nb41;
408+
uint64_t nb42;
409+
uint64_t nb50;
410+
uint64_t nb51;
411+
uint64_t nb52;
412+
} ggml_metal_kargs_ssm_scan;
413+
414+
// Arguments for the get_rows Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne00/nb01/nb02 presumably describe src0, ne10/nb10/nb11 src1,
// and nb1/nb2 the destination byte strides, following ggml naming — confirm
// against the kernel.
typedef struct {
415+
int64_t ne00;
416+
uint64_t nb01;
417+
uint64_t nb02;
418+
int64_t ne10;
419+
uint64_t nb10;
420+
uint64_t nb11;
421+
uint64_t nb1;
422+
uint64_t nb2;
423+
} ggml_metal_kargs_get_rows;
424+
425+
// Arguments for the upscale Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0x/nb0x presumably describe the source and ne0..3/nb0..3 the
// destination (element counts / byte strides, ggml naming); sf0..sf3 look like
// per-dimension scale factors — confirm against the kernel.
typedef struct {
426+
int64_t ne00;
427+
int64_t ne01;
428+
int64_t ne02;
429+
int64_t ne03;
430+
uint64_t nb00;
431+
uint64_t nb01;
432+
uint64_t nb02;
433+
uint64_t nb03;
434+
int64_t ne0;
435+
int64_t ne1;
436+
int64_t ne2;
437+
int64_t ne3;
438+
uint64_t nb0;
439+
uint64_t nb1;
440+
uint64_t nb2;
441+
uint64_t nb3;
442+
float sf0;
443+
float sf1;
444+
float sf2;
445+
float sf3;
446+
} ggml_metal_kargs_upscale;
447+
448+
// Arguments for the pad Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0x/nb0x presumably describe the source and ne0..3/nb0..3 the
// destination (element counts / byte strides, ggml naming) — confirm against
// the kernel.
typedef struct {
449+
int64_t ne00;
450+
int64_t ne01;
451+
int64_t ne02;
452+
int64_t ne03;
453+
uint64_t nb00;
454+
uint64_t nb01;
455+
uint64_t nb02;
456+
uint64_t nb03;
457+
int64_t ne0;
458+
int64_t ne1;
459+
int64_t ne2;
460+
int64_t ne3;
461+
uint64_t nb0;
462+
uint64_t nb1;
463+
uint64_t nb2;
464+
uint64_t nb3;
465+
} ggml_metal_kargs_pad;
466+
467+
// Arguments for the pad_reflect_1d Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0x/nb0x presumably describe the source and ne0..3/nb0..3 the
// destination (element counts / byte strides, ggml naming); p0 / p1 look like
// the padding amounts on each side — confirm against the kernel.
typedef struct {
468+
int64_t ne00;
469+
int64_t ne01;
470+
int64_t ne02;
471+
int64_t ne03;
472+
uint64_t nb00;
473+
uint64_t nb01;
474+
uint64_t nb02;
475+
uint64_t nb03;
476+
int64_t ne0;
477+
int64_t ne1;
478+
int64_t ne2;
479+
int64_t ne3;
480+
uint64_t nb0;
481+
uint64_t nb1;
482+
uint64_t nb2;
483+
uint64_t nb3;
484+
int32_t p0;
485+
int32_t p1;
486+
} ggml_metal_kargs_pad_reflect_1d;
487+
488+
// Arguments for the timestep_embedding Metal kernel, bundled into a single
// struct (per the commit message, these were previously separate kernel
// parameters).
// NOTE(review): nb1 presumably is a destination byte stride following ggml
// naming — confirm against the kernel.
typedef struct {
    uint64_t nb1;
    // Fixed-width types, like every other integer field in the kargs structs:
    // the layout is shared between host and GPU code, so the width should not
    // depend on the platform's `int`.
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
493+
494+
// Arguments for the leaky_relu Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
typedef struct {
495+
float slope; // presumably the slope applied to negative inputs — confirm against the kernel
496+
} ggml_metal_kargs_leaky_relu;
497+
498+
// Arguments for the argsort Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ncols_pad presumably is ncols rounded up to a size the kernel
// requires — confirm against the host code that fills it.
typedef struct {
499+
int64_t ncols;
500+
int64_t ncols_pad;
501+
} ggml_metal_kargs_argsort;
502+
503+
// Arguments for the arange Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): ne0 presumably is the number of output elements, with start and
// step defining the generated sequence — confirm against the kernel.
typedef struct {
504+
int64_t ne0;
505+
float start;
506+
float step;
507+
} ggml_metal_kargs_arange;
508+
509+
// Arguments for the pool_2d Metal kernel, bundled into a single struct
// (per the commit message, these were previously separate kernel parameters).
// NOTE(review): k* / s* / p* presumably are kernel size, stride and padding per
// axis; IH / IW / OH / OW input and output height/width; parallel_elements the
// total number of work items — confirm against the kernel.
typedef struct {
510+
int32_t k0;
511+
int32_t k1;
512+
int32_t s0;
513+
int32_t s1;
514+
int32_t p0;
515+
int32_t p1;
516+
int64_t IH;
517+
int64_t IW;
518+
int64_t OH;
519+
int64_t OW;
520+
int64_t parallel_elements;
521+
} ggml_metal_kargs_pool_2d;
522+
288523
#endif // GGML_METAL_IMPL

0 commit comments

Comments
 (0)