mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-03-06 20:48:53 +01:00
metal : refactor soft_max parameters into a struct
This commit is contained in:
parent
cd3dcdba46
commit
dba23c7e8b
3 changed files with 53 additions and 54 deletions
|
@ -330,4 +330,15 @@ typedef struct {
|
|||
uint64_t nb3;
|
||||
} ggml_metal_kargs_sum_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint32_t n_head_log2;
|
||||
} ggml_metal_kargs_soft_max;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
|
|
@ -2024,8 +2024,17 @@ static void ggml_metal_encode_node(
|
|||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
// TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
|
||||
ggml_metal_kargs_soft_max args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
/*.m1 =*/ m1,
|
||||
/*.n_head_log2 =*/ n_head_log2,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
if (id_src1) {
|
||||
|
@ -2034,14 +2043,7 @@ static void ggml_metal_encode_node(
|
|||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
|
|
|
@ -975,36 +975,29 @@ kernel void kernel_soft_max(
|
|||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = (tgpig) / (ne02*ne01);
|
||||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
||||
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
@ -1012,8 +1005,8 @@ kernel void kernel_soft_max(
|
|||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -1037,8 +1030,8 @@ kernel void kernel_soft_max(
|
|||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
@ -1068,7 +1061,7 @@ kernel void kernel_soft_max(
|
|||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
@ -1078,35 +1071,28 @@ kernel void kernel_soft_max_4(
|
|||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = (tgpig) / (ne02*ne01);
|
||||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
||||
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
@ -1114,8 +1100,8 @@ kernel void kernel_soft_max_4(
|
|||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -1140,8 +1126,8 @@ kernel void kernel_soft_max_4(
|
|||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
@ -1173,7 +1159,7 @@ kernel void kernel_soft_max_4(
|
|||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue