From 93b94a2d41e880cb2abfb708535d5b04ad05b7a5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 Jun 2023 21:10:24 +0300 Subject: [PATCH 01/15] ggml : sync llama.cpp (NUMA + thread improvements + k-quants) --- include/ggml/ggml.h | 3 + src/ggml-cuda.cu | 369 ++++++++++++++++++++++++++----- src/ggml-metal.m | 66 +++--- src/ggml-metal.metal | 412 +++++++++++++++++++++++++++------- src/ggml.c | 510 ++++++++++++++++++++++++------------------- 5 files changed, 976 insertions(+), 384 deletions(-) diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 5ebd9c46c..6b106b1c3 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -469,6 +469,9 @@ extern "C" { GGML_API int64_t ggml_cycles(void); GGML_API int64_t ggml_cycles_per_ms(void); + GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems + GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node + GGML_API void ggml_print_object (const struct ggml_object * obj); GGML_API void ggml_print_objects(const struct ggml_context * ctx); diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 010682edb..c34e96abf 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo //================================= k-quants +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else #define QK_K 256 +#define K_SCALE_SIZE 12 +#endif typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits @@ -128,13 +134,25 @@ typedef struct { static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); typedef struct { - uint8_t hmask[QK_K/8]; - uint8_t qs[QK_K/4]; // nibbles / quants - uint8_t scales[3*QK_K/64]; - half d; + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + uint8_t scales[2]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale } block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); +#ifdef GGML_QKK_64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins @@ -142,15 +160,26 @@ typedef struct { uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif +#ifdef GGML_QKK_64 typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + half d; // super-block scale + int8_t scales[QK_K/16]; // block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + half d; // super-block scale for 
quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + const int tid = threadIdx.x; +#if QK_K == 256 const int n = tid/32; const int l = tid - 32*n; const int is = 8*n + l/16; - const block_q2_K * x = (const block_q2_K *) vx; - const uint8_t q = x[i].qs[32*n + l]; float * y = yy + i*QK_K + 128*n; @@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + float * y = yy + i*QK_K + 16*is + il; + float dall = x[i].d; + float dmin = x[i].dmin; + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif } static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { - int r = threadIdx.x/4; - int i = blockIdx.x; - int tid = r/2; - int is0 = r%2; - int l0 = 16*is0 + 4*(threadIdx.x%4); - int n = tid / 4; - int j = tid - 4*n; - + const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; +#if QK_K == 256 + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + uint8_t m = 1 << (4*n + j); int is = 8*n + 2*j + is0; int shift = 2*j; @@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { const uint8_t * hm = x[i].hmask; for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + float * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 
0 : 4)); + } +#endif } +#if QK_K == 256 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } +#endif static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const block_q4_K * x = (const block_q4_K *) vx; const int i = blockIdx.x; - //// assume 64 threads - this is very slightly better than the one below - //const int tid = threadIdx.x; - //const int il = tid/16; - //const int ir = tid%16; - //const int is = 2*il; - //const int n = 2; - +#if QK_K == 256 // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; @@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { y[l + 0] = d1 * (q[l] & 0xF) - m1; y[l +32] = d2 * (q[l] >> 4) - m2; } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + float * y = yy + i*QK_K; + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif } static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { @@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { const int i = blockIdx.x; +#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int tid = threadIdx.x; const int il = tid/16; // il is in 0...3 @@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { hm <<= 1; y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const int is = tid/16; // 0 or 1 + const uint8_t h = x[i].qh[in] >> im; + const float d = x[i].d; + float * y = yy + i*QK_K + tid; + y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); + y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 
0 : 16)); +#endif } static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { const block_q6_K * x = (const block_q6_K *) vx; const int i = blockIdx.x; +#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int tid = threadIdx.x; @@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +#else + + // assume 32 threads + const int tid = threadIdx.x; + const int ip = tid/16; // 0 or 1 + const int il = tid - 16*ip; // 0...15 + + float * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif } static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { @@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const block_q2_K * x = (const block_q2_K *)vx + ib0; + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const int s_offset = 8*im; const int y_offset = 128*im + l0; - float tmp = 0; // partial sum for thread in warp - uint32_t aux[4]; const uint8_t * d = (const uint8_t *)aux; const uint8_t * m = (const uint8_t *)(aux + 2); @@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float tmp += dall * sum1 - dmin * sum2; } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const half2 * dh = (const half2 *)&x[i].d; + + const float2 dall = __half22float2(dh[0]); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x * sum1 - dall.y * sum2; + } +#endif // sum up partial sums and write back result __syncthreads(); @@ -573,16 +706,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 
0x0303; - const uint16_t kmask2 = 0x0f0f; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float const block_q3_K * x = (const block_q3_K *)vx + ib0; + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float const uint16_t s_shift = 4*im; - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + y_offset; @@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float tmp += d * sum; } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 
0 : 4)); + } + tmp += sum; + } +#endif // sum up partial sums and write back result __syncthreads(); @@ -648,22 +811,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q4_K * x = (const block_q4_K *)vx + ib0; - float tmp = 0; // partial sum for thread in warp for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { @@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif // sum up partial sums and write back result __syncthreads(); @@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - //const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/2; // 0...15 const int ix = threadIdx.x%2; @@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q5_K * x = (const 
block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { const uint8_t * ql1 = x[i].qs + q_offset; @@ -793,8 +987,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const int8_t * s = x[i].scales; + const float * y = yy + i*QK_K + step; + const float d = x[i].d; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) + + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); + } + tmp += sum; } +#endif // sum up partial sums and write back result __syncthreads(); @@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } @@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const block_q6_K * x = (const block_q6_K *)vx + ib0; +#if QK_K == 256 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + // sum up partial sums and write back result __syncthreads(); #pragma unroll @@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); +#else + dequantize_block_q2_K<<>>(vx, y); +#endif } static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); +#else + 
dequantize_block_q3_K<<>>(vx, y); +#endif } static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); +#else + dequantize_block_q5_K<<>>(vx, y); +#endif } static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); +#else + dequantize_block_q6_K<<>>(vx, y); +#endif } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -2553,6 +2819,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { tensor->backend = GGML_BACKEND_GPU; struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) || tensor->op == GGML_OP_VIEW; diff --git a/src/ggml-metal.m b/src/ggml-metal.m index a7e104dc7..7551231b9 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -51,21 +51,21 @@ GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); - GGML_METAL_DECL_KERNEL(get_rows_q2_k); - GGML_METAL_DECL_KERNEL(get_rows_q3_k); - GGML_METAL_DECL_KERNEL(get_rows_q4_k); - GGML_METAL_DECL_KERNEL(get_rows_q5_k); - GGML_METAL_DECL_KERNEL(get_rows_q6_k); + GGML_METAL_DECL_KERNEL(get_rows_q2_K); + GGML_METAL_DECL_KERNEL(get_rows_q3_K); + GGML_METAL_DECL_KERNEL(get_rows_q4_K); + GGML_METAL_DECL_KERNEL(get_rows_q5_K); + GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -132,7 +132,13 @@ @implementation GGMLMetalClass exit(1); } +#ifdef GGML_QKK_64 + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; +#else ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; +#endif if (error) { fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); exit(1); @@ -159,21 +165,21 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); - GGML_METAL_ADD_KERNEL(get_rows_q2_k); - GGML_METAL_ADD_KERNEL(get_rows_q3_k); - GGML_METAL_ADD_KERNEL(get_rows_q4_k); - GGML_METAL_ADD_KERNEL(get_rows_q5_k); - GGML_METAL_ADD_KERNEL(get_rows_q6_k); + GGML_METAL_ADD_KERNEL(get_rows_q2_K); + GGML_METAL_ADD_KERNEL(get_rows_q3_K); + 
GGML_METAL_ADD_KERNEL(get_rows_q4_K); + GGML_METAL_ADD_KERNEL(get_rows_q5_K); + GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -662,7 +668,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { @@ -671,7 +677,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { @@ -680,7 +686,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { @@ -689,7 +695,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { @@ -698,7 +704,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; } break; default: { @@ -750,11 +756,11 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index d1e49222d..e62fe6842 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32( } 
threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } } @@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32( } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } } @@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32( //============================================ k-quants ====================================================== +#ifndef QK_K #define QK_K 256 +#else +static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); +#endif + +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins -} block_q2_k; +} block_q2_K; // 84 bytes / block typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits - half d; // super-block scale -} block_q3_k; -// 110 bytes / block - +#if QK_K == 64 + uint8_t scales[2]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; + +#if QK_K == 64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_K; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_k; -// 144 bytes / block +} block_q4_K; +#endif +#if QK_K == 64 +typedef struct { + half d; // super-block scales/mins + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_k; +} block_q5_K; // 176 bytes / block +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits uint8_t qh[QK_K/4]; // quants, upper 2 bits int8_t scales[QK_K/16]; // scales, quantized with 8 bits half d; // super-block scale -} block_q6_k; +} block_q6_K; // 210 bytes / block static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { @@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //========================================== dequantization ============================= -static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) { +static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i device const uint8_t * q = x[i].qs; +#if QK_K == 256 int is = 0; float dl, ml; for (int n = 0; n < QK_K; n += 128) { @@ -865,14 +895,29 @@ static void 
dequantize_row_q2_k(device const block_q2_k * x, device float * y, i } q += 32; } +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1; + y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2; + y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3; + y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4; + } + y += QK_K; +#endif } } -static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) { +static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 + const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; @@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i } q += 32; } + } +#else + for (int i = 0; i < nb; i++) { + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs; + device const uint8_t * hm = x[i].hmask; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l = 0; l < 8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 
0 : 4)); + } + y += QK_K; } +#endif } -static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { +static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; - for (int i = 0; i < nb; i++) { + device const uint8_t * q = x[i].qs; + +#if QK_K == 256 const float d = x[i].d; const float min = x[i].dmin; - device const uint8_t * q = x[i].qs; device const uint8_t * scales = x[i].scales; int is = 0; @@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } +#else + device const uint8_t * s = x[i].scales; + device const half2 * dh = (device const half2 *)x[i].d; + const float2 d = (float2)dh[0]; + const float d1 = d[0] * (s[0] & 0xF); + const float d2 = d[0] * (s[1] & 0xF); + const float m1 = d[1] * (s[0] >> 4); + const float m2 = d[1] * (s[1] >> 4); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif } } -static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) { +static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 for (int i = 0; i < nb; i++) { const float d = (float)(x[i].d); @@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i u1 <<= 2; u2 <<= 2; } } +#else + for (int i = 0; i < nb; i++) { + + const float d = (float)x[i].d; + + device const uint8_t * ql = x[i].qs; + device const uint8_t * qh = x[i].qh; + device const int8_t * sc = x[i].scales; + + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); + y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); + y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); + y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); + y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); + y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); + y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); + y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 
0 : 16)); + } + y += QK_K; + } +#endif } -static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { +static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i const float d = x[i].d; +#if QK_K == 256 for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i qh += 32; sc += 8; } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif } } -kernel void kernel_get_rows_q2_k( +kernel void kernel_get_rows_q2_K( device const void * src0, device const int * src1, device float * dst, @@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q2_k( - (device const block_q2_k *) ((device char *) src0 + r*nb01), + dequantize_row_q2_K( + (device const block_q2_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q3_k( +kernel void kernel_get_rows_q3_K( device const void * src0, device const int * src1, device float * dst, @@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q3_k( - (device const block_q3_k *) ((device char *) src0 + r*nb01), + dequantize_row_q3_K( + (device const block_q3_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q4_k( +kernel void kernel_get_rows_q4_K( device const void * src0, device const int * src1, device float * dst, @@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q4_k( - (device const block_q4_k *) ((device char *) src0 + r*nb01), + dequantize_row_q4_K( + (device const block_q4_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q5_k( +kernel void kernel_get_rows_q5_K( device const void * src0, device const int * src1, device float * dst, @@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q5_k( - (device const block_q5_k *) ((device char *) src0 + r*nb01), + dequantize_row_q5_K( + (device const block_q5_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q6_k( +kernel void kernel_get_rows_q6_K( device const void * src0, device const int * src1, device float * dst, @@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q6_k( - (device const block_q6_k *) ((device char *) src0 + r*nb01), + dequantize_row_q6_K( + (device const block_q6_K *) ((device char *) src0 + r*nb01), (device float *) ((device char 
*) dst + i*nb1), ne00); } //====================================== dot products ========================= -kernel void kernel_mul_mat_q2_k_f32( +kernel void kernel_mul_mat_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid%4; // 0...3 @@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32( const int y_offset = 64*il + n*ir; const int q_offset = 32*ip + n*ir; - sum[ith] = 0.0f; - - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q = x[i].qs + q_offset; @@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32( device const float * y = yy + i*QK_K + y_offset; - //float4 s = {0.f, 0.f, 0.f, 0.f}; float2 s = {0.f, 0.f}; float smin = 0; for (int l = 0; l < n; ++l) { @@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32( sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; } - sum[ith] = sumf; +#else + const int il = 4 * tpitg.x; - //int mask1 = (ith%4 == 0); - //int mask2 = (ith%16 == 0); + uint32_t aux[2]; + thread const uint8_t * d = (thread const uint8_t *)aux; + thread const uint8_t * m = (thread const uint8_t *)aux + 4; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i]; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i]; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //if (ith == 0) { - // for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - // dst[r1*ne0 + r0] = sum[0]; - //} + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i*QK_K + il; + + const float dall = (float)x[i].d; + const float dmin = (float)x[i].dmin; + + device const uint32_t * a = (device const uint32_t *)x[i].scales; + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = (a[0] >> 4) & 0x0f0f0f0f; + + for (int l = 0; l < 4; ++l) { + sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) + + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) + + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) + + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + } + } +#endif + + sum[ith] = sumf; // // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. 
// threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { @@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32( } } -kernel void kernel_mul_mat_q3_k_f32( +kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const uint8_t m3 = 3; - const int8_t m4 = 4; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb; + device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; +#if QK_K == 256 + + const uint8_t m3 = 3; + const int8_t m4 = 4; + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + const int tid = tpitg.y; // expecting 16 const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 @@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32( //sum[ith] = sumf; sum[ith] = sumf1 - 32.f*sumf2; +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + float sumf = 0; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].hmask + in; + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l = 0; l < 4; ++l) { + const uint8_t hm = h[l] >> im; + sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) + + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) + + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) + + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 
0 : 4)); + } + + } + + sum[ith] = sumf; + +#endif // // Accumulate the sum from all threads in the threadgroup @@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32( } -kernel void kernel_mul_mat_q4_k_f32( +kernel void kernel_mul_mat_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + float sumf = 0; + +#if QK_K == 256 + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32( const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; - sum[ith] = 0.0f; - uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + uint16_t aux16[2]; + thread const uint8_t * scales = (thread const uint8_t *)aux16; + + const int il = 4*tpitg.x; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i * QK_K + il; + + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + + device const uint16_t * a = (device const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + for (int l = 0; l < 4; ++l) { + sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16]) + + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]); + } + } +#endif sum[ith] = sumf; @@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32( //} } -kernel void kernel_mul_mat_q5_k_f32( +kernel void kernel_mul_mat_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb; + device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ 
-1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32( uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + const float d = (float)x[i].d; + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const int8_t * s = x[i].scales; + device const float * y = yy + i*QK_K + il; + + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16)) + + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16)) + + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16)) + + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16)); + } + } +#endif sum[ith] = sumf; // @@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32( } -kernel void kernel_mul_mat_q6_k_f32( +kernel void kernel_mul_mat_q6_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; + device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! 
const int iqs = 16 * tpitg.y; const int ip = iqs / 128; // 0 or 1 @@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32( const int q_offset_l = 64*ip + l0; const int q_offset_h = 32*ip + l0; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * ql = x[i].ql + q_offset_l; @@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32( sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); } +#else + const int il = 4*tpitg.x; // 0, 4, 8, 12 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + device const float * y = yy + i * QK_K + il; + device const uint8_t * ql = x[i].ql + il; + device const uint8_t * qh = x[i].qh + il; + device const int8_t * s = x[i].scales; + + const float d = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); + } + +#endif sum[ith] = sumf; diff --git a/src/ggml.c b/src/ggml.c index 1a441eb98..c179bee93 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1,3 +1,4 @@ +#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows #include "ggml.h" @@ -90,6 +91,11 @@ static int sched_yield (void) { #include typedef void* thread_ret_t; + +#include +#include +#include + #endif // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 @@ -118,6 +124,30 @@ typedef void* thread_ret_t; #define GGML_SOFT_MAX_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2 +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) 
printf(__VA_ARGS__) + #ifdef GGML_USE_ACCELERATE // uncomment to use vDSP for soft max computation // note: not sure if it is actually faster @@ -458,7 +488,6 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) { } } - // // timing // @@ -521,6 +550,7 @@ int64_t ggml_cycles_per_ms(void) { #define ggml_perf_cycles_per_ms() 0 #endif + // // cache line // @@ -3842,12 +3872,31 @@ struct ggml_context_container { struct ggml_context context; }; +// +// NUMA support +// + +#define GGML_NUMA_MAX_NODES 8 +#define GGML_NUMA_MAX_CPUS 512 + +struct ggml_numa_node { + uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct ggml_numa_nodes { + struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system +}; + // // ggml state // struct ggml_state { struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; + struct ggml_numa_nodes numa; }; // global state @@ -3872,6 +3921,75 @@ inline static void ggml_critical_section_end(void) { atomic_fetch_sub(&g_state_barrier, 1); } +void ggml_numa_init(void) { + if (g_state.numa.n_nodes > 0) { + fprintf(stderr, "ggml_numa_init: NUMA already initialized\n"); + + return; + } + +#ifdef __linux__ + struct stat st; + char path[256]; + int rv; + + // enumerate nodes + while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.n_nodes; + } + + // enumerate CPUs + while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.total_cpus; + } + + GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + + if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) { + g_state.numa.n_nodes = 0; + return; + } + + for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { + struct ggml_numa_node * node = &g_state.numa.nodes[n]; + GGML_PRINT_DEBUG("CPUs on node %u:", n); + node->n_cpus = 0; + for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) == 0) { + node->cpus[node->n_cpus++] = c; + GGML_PRINT_DEBUG(" %u", c); + } + } + GGML_PRINT_DEBUG("\n"); + } + + if (ggml_is_numa()) { + FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); + if (fptr != NULL) { + char buf[42]; + if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { + GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + } + fclose(fptr); + } + } +#else + // TODO +#endif +} + +bool ggml_is_numa(void) { + return g_state.numa.n_nodes > 1; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_print_object(const struct ggml_object * obj) { @@ -4128,6 +4246,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { g_state = (struct ggml_state) { /*.contexts =*/ { { 0 } }, + /*.numa =*/ { + .n_nodes = 0, + .total_cpus = 0, + }, }; for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { @@ -16502,68 +16624,172 @@ typedef pthread_t ggml_thread_t; #endif +#ifdef __linux__ +void 
set_numa_thread_affinity(int thread_n, int n_threads) { + if (!ggml_is_numa()) { + return; + } + + // run thread on node_num thread_n / (threads per node) + const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes); + struct ggml_numa_node * node = &g_state.numa.nodes[node_num]; + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (size_t i = 0; i < node->n_cpus; ++i) { + CPU_SET_S(node->cpus[i], setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} + +void clear_numa_thread_affinity(void) { + if (!ggml_is_numa()) { + return; + } + + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { + CPU_SET_S(i, setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} +#else +// TODO: Windows etc. +// (the linux implementation may also work on BSD, someone should test) +void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } +void clear_numa_thread_affinity(void) {} +#endif + struct ggml_compute_state_shared { - ggml_lock_t spin; + struct ggml_cgraph * cgraph; + + int64_t perf_node_start_cycles; + int64_t perf_node_start_time_us; int n_threads; // synchronization primitives - atomic_int n_ready; - atomic_bool has_work; - atomic_bool stop; // stop all threads + atomic_int n_active; // num active threads + atomic_int node_n; // active graph node }; struct ggml_compute_state { ggml_thread_t thrd; - - struct ggml_compute_params params; - struct ggml_tensor * node; - + int ith; struct ggml_compute_state_shared * shared; }; +static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { + int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; + int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += cycles_cur; + node->perf_time_us += time_us_cur; +} + static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; + struct ggml_cgraph * cgraph = state->shared->cgraph; const int n_threads = state->shared->n_threads; + set_numa_thread_affinity(state->ith, n_threads); + + int node_n = -1; while (true) { - if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { - atomic_store(&state->shared->has_work, false); - } else { - while (atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_lock_lock (&state->shared->spin); - ggml_lock_unlock(&state->shared->spin); + if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { + // all other threads are finished and spinning + // do finalize and init here so we don't have synchronize again + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_FINALIZE, + /*.ith =*/ 0, + /*.nth =*/ 0, + /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? 
cgraph->work->data : NULL, + }; + + if (node_n != -1) { + /* FINALIZE */ + struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; + params.nth = node->n_tasks; + ggml_compute_forward(¶ms, node); + ggml_graph_compute_perf_stats_node(node, state->shared); } - } - atomic_fetch_sub(&state->shared->n_ready, 1); + // distribute new work or execute it direct if 1T + while (++node_n < cgraph->n_nodes) { + GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); + + struct ggml_tensor * node = cgraph->nodes[node_n]; + + state->shared->perf_node_start_cycles = ggml_perf_cycles(); + state->shared->perf_node_start_time_us = ggml_perf_time_us(); + + /* INIT */ + params.type = GGML_TASK_INIT; + params.nth = node->n_tasks; + ggml_compute_forward(¶ms, node); - // wait for work - while (!atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; + if (node->n_tasks == 1) { + // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, + // they do something more efficient than spinning (?) + params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(¶ms, node); + + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + ggml_graph_compute_perf_stats_node(node, state->shared); + } else { + break; + } } - ggml_lock_lock (&state->shared->spin); - ggml_lock_unlock(&state->shared->spin); + + atomic_store(&state->shared->n_active, n_threads); + atomic_store(&state->shared->node_n, node_n); + } else { + // wait for other threads to finish + const int last = node_n; + do { + sched_yield(); + node_n = atomic_load(&state->shared->node_n); + } while (node_n == last); } // check if we should stop - if (atomic_load(&state->shared->stop)) { - break; - } + if (node_n >= cgraph->n_nodes) break; - if (state->node) { - if (state->params.ith < state->params.nth) { - ggml_compute_forward(&state->params, state->node); - } + /* COMPUTE */ + struct ggml_tensor * node = cgraph->nodes[node_n]; - state->node = NULL; - } else { - break; + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_COMPUTE, + /*.ith =*/ state->ith, + /*.nth =*/ node->n_tasks, + /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, + }; + + if (state->ith < node->n_tasks) { + ggml_compute_forward(¶ms, node); } } @@ -16574,39 +16800,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) const int n_threads = cgraph->n_threads; struct ggml_compute_state_shared state_shared = { - /*.spin =*/ GGML_LOCK_INITIALIZER, - /*.n_threads =*/ n_threads, - /*.n_ready =*/ 0, - /*.has_work =*/ false, - /*.stop =*/ false, + /*.cgraph =*/ cgraph, + /*.perf_node_start_cycles =*/ 0, + /*.perf_node_start_time_us =*/ 0, + /*.n_threads =*/ n_threads, + /*.n_active =*/ n_threads, + /*.node_n =*/ -1, }; - struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; - - // create thread pool - if (n_threads > 1) { - ggml_lock_init(&state_shared.spin); - - atomic_store(&state_shared.has_work, true); - - for (int j = 0; j < n_threads - 1; j++) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .params = { - .type = GGML_TASK_COMPUTE, - .ith = j + 1, - .nth = n_threads, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? 
cgraph->work->data : NULL, - }, - .node = NULL, - .shared = &state_shared, - }; - - int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } - } + struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); // initialize tasks + work buffer { @@ -16750,7 +16951,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_SCALE: { - node->n_tasks = n_threads; + node->n_tasks = 1; } break; case GGML_OP_SET: case GGML_OP_CONT: @@ -16954,166 +17155,37 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } } - const int64_t perf_start_cycles = ggml_perf_cycles(); - const int64_t perf_start_time_us = ggml_perf_time_us(); - - for (int i = 0; i < cgraph->n_nodes; i++) { - GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); - - struct ggml_tensor * node = cgraph->nodes[i]; - - // TODO: this could be used to avoid unnecessary computations, but it needs to be improved - //if (node->grad == NULL && node->perf_runs > 0) { - // continue; - //} - - const int64_t perf_node_start_cycles = ggml_perf_cycles(); - const int64_t perf_node_start_time_us = ggml_perf_time_us(); - - // INIT - struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_INIT, - /*.ith =*/ 0, - /*.nth =*/ node->n_tasks, - /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, - /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, - }; - - ggml_compute_forward(¶ms, node); - - // COMPUTE - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_compute_params) { - .type = GGML_TASK_COMPUTE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; - workers[j].node = node; - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); - } - - params.type = GGML_TASK_COMPUTE; - ggml_compute_forward(¶ms, node); - - // wait for thread pool - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - } - - // FINALIZE - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_compute_params) { - .type = GGML_TASK_FINALIZE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? 
ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; - workers[j].node = node; - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } + // create thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; ++j) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + }; - atomic_store(&state_shared.has_work, true); + const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + GGML_ASSERT(rc == 0); } + } + workers[0].ith = 0; + workers[0].shared = &state_shared; - params.type = GGML_TASK_FINALIZE; - ggml_compute_forward(¶ms, node); - - // wait for thread pool - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - } + const int64_t perf_start_cycles = ggml_perf_cycles(); + const int64_t perf_start_time_us = ggml_perf_time_us(); - // performance stats (node) - { - int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles; - int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us; + // this is a work thread too + ggml_graph_compute_thread(&workers[0]); - node->perf_runs++; - node->perf_cycles += perf_cycles_cur; - node->perf_time_us += perf_time_us_cur; - } - } + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); // join thread pool if (n_threads > 1) { - atomic_store(&state_shared.stop, true); - atomic_store(&state_shared.has_work, true); - - for (int j = 0; j < n_threads - 1; j++) { - int rc = ggml_thread_join(workers[j].thrd, NULL); + for (int j = 1; j < n_threads; j++) { + const int rc = ggml_thread_join(workers[j].thrd, NULL); GGML_ASSERT(rc == 0); - UNUSED(rc); } - - ggml_lock_destroy(&state_shared.spin); } // performance stats (graph) From c0b546bfd6ccf9900212ca5ede6f6f7f60862155 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 Jun 2023 23:26:37 +0300 Subject: [PATCH 02/15] ggml : increase max name size to 48 --- include/ggml/ggml.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 6b106b1c3..08025e57a 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -198,7 +198,7 @@ #define GGML_MAX_PARAMS 256 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_OPT 4 -#define GGML_MAX_NAME 32 +#define GGML_MAX_NAME 48 #define GGML_DEFAULT_N_THREADS 4 #define GGML_ASSERT(x) \ From 2d95223a3f983c3aaad73aef4cd743ad4af4eff7 Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Tue, 27 Jun 2023 04:47:31 +0800 Subject: [PATCH 03/15] ggml : support ChatGLM-style RoPE (#305) --- examples/dolly-v2/main.cpp | 4 +- examples/gpt-j/main.cpp | 4 +- examples/gpt-neox/main.cpp | 4 +- include/ggml/ggml.h | 7 +++- src/ggml.c | 82 +++++++++++++++++++++++++++++++++----- tests/test-grad0.c | 4 +- 6 files changed, 84 insertions(+), 21 deletions(-) diff --git a/examples/dolly-v2/main.cpp b/examples/dolly-v2/main.cpp index 85faa7076..ce7d3d4fa 100644 --- a/examples/dolly-v2/main.cpp +++ b/examples/dolly-v2/main.cpp @@ -523,8 +523,8 @@ 
bool dollyv2_eval( struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head)); // using mode = 2 for GPT-NeoX mode - Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2); - Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2); + Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0); + Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0); // store key and value to memory { diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp index 3d956ffe8..516e0998a 100644 --- a/examples/gpt-j/main.cpp +++ b/examples/gpt-j/main.cpp @@ -452,8 +452,8 @@ bool gptj_eval( // self-attention { - struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); - struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); // store key and value to memory { diff --git a/examples/gpt-neox/main.cpp b/examples/gpt-neox/main.cpp index 1cd64a227..649b51c0e 100644 --- a/examples/gpt-neox/main.cpp +++ b/examples/gpt-neox/main.cpp @@ -517,8 +517,8 @@ bool gpt_neox_eval( struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head)); // using mode = 2 for GPT-NeoX mode - Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2); - Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2); + Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0); + Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0); // store key and value to memory { diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 08025e57a..459913222 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -1036,13 +1036,15 @@ extern "C" { // rotary position embedding // if mode & 1 == 1, skip n_past elements // if mode & 2 == 1, GPT-NeoX style + // if mode & 4 == 1, ChatGLM style // TODO: avoid creating a new tensor every time GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_inplace( @@ -1050,7 +1052,8 @@ extern "C" { struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // rotary position embedding backward, i.e compute dx from dy // a - dy diff --git a/src/ggml.c b/src/ggml.c index c179bee93..92faf03f7 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -6778,6 +6778,7 @@ struct ggml_tensor * ggml_rope_impl( int n_past, int n_dims, int mode, + int n_ctx, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6790,11 +6791,12 @@ struct ggml_tensor * ggml_rope_impl( ggml_scratch_save(ctx); - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_dims; ((int32_t *) b->data)[2] = mode; + ((int32_t *) 
b->data)[3] = n_ctx; ggml_scratch_load(ctx); @@ -6811,8 +6813,9 @@ struct ggml_tensor * ggml_rope( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false); + int mode, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false); } struct ggml_tensor * ggml_rope_inplace( @@ -6820,8 +6823,9 @@ struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true); + int mode, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true); } // ggml_rope_back @@ -12440,7 +12444,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); + GGML_ASSERT(ggml_nelements(src1) == 4); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12449,6 +12453,7 @@ static void ggml_compute_forward_rope_f32( const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; assert(n_past >= 0); @@ -12493,6 +12498,7 @@ static void ggml_compute_forward_rope_f32( const float theta_scale = powf(10000.0, -2.0f/n_dims); const bool is_neox = mode & 2; + const bool is_glm = mode & 4; for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { @@ -12503,7 +12509,32 @@ static void ggml_compute_forward_rope_f32( float theta = (float)p; - if (!is_neox) { + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + const float x2 = src[n_dims]; + const float x3 = src[n_dims/2*3]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; + dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; + } + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -12553,7 +12584,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); + GGML_ASSERT(ggml_nelements(src1) == 4); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12562,6 +12593,7 @@ static void ggml_compute_forward_rope_f16( const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; assert(n_past >= 0); @@ -12606,6 +12638,7 @@ static void ggml_compute_forward_rope_f16( const float theta_scale = powf(10000.0, -2.0f/n_dims); const bool is_neox = mode & 2; + const bool is_glm = mode & 4; for 
(int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { @@ -12616,7 +12649,32 @@ static void ggml_compute_forward_rope_f16( float theta = (float)p; - if (!is_neox) { + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + const float x2 = GGML_FP16_TO_FP32(src[n_dims]); + const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); + dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); + } + } if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -16189,17 +16247,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { if (src0->grad) { assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); + assert(ggml_nelements(src1) == 4); const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope(ctx, tensor->grad, n_past, n_dims, - mode), + mode, + n_ctx), inplace); } if (src1->grad) { diff --git a/tests/test-grad0.c b/tests/test-grad0.c index b5a499c1d..ce5dc32bf 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -1139,7 +1139,7 @@ int main(int argc, const char ** argv) { int n_rot = ne2[0]; for (int ndims = 3; ndims <= 4; ++ndims) { - for (int mode = 0; mode < 4; ++mode) { + for (int mode = 0; mode < 8; ++mode) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); @@ -1154,7 +1154,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); From ee1b3727e60403012dd2b57d35b60558f4db66d8 Mon Sep 17 00:00:00 2001 From: sjinzh Date: Tue, 27 Jun 2023 04:48:31 +0800 Subject: [PATCH 04/15] zig : add tests by zig (#307) * update build.zig * zig : add tests by zig --- .gitignore | 4 +- build.zig | 61 +++++-- tests/test0.zig | 43 +++++ tests/test1.zig | 459 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 548 insertions(+), 19 deletions(-) create mode 100644 tests/test0.zig create mode 100644 tests/test1.zig diff --git a/.gitignore b/.gitignore index 093031907..698b04a98 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,6 @@ src/arm_neon.h tests/arm_neon.h zig-out/ -zig-cache/ \ No newline 
at end of file +zig-cache/ + +*.dot \ No newline at end of file diff --git a/build.zig b/build.zig index 34582cec6..87151a885 100644 --- a/build.zig +++ b/build.zig @@ -2,24 +2,26 @@ const std = @import("std"); // Zig Version: 0.11.0-dev.3798+a5e15eced // Zig Build Command: zig build -// Zig Run Command: -// zig build run_dolly-v2 -// zig build run_gpt-2 -// zig build run_gpt-j -// zig build run_gpt-neox -// zig build run_mnist -// zig build run_mpt -// zig build run_replit -// zig build run_starcoder -// zig build run_test-grad0 -// zig build run_test-mul-mat0 -// zig build run_test-mul-mat2 -// zig build run_test-opt -// zig build run_test-vec1 -// zig build run_test0 -// zig build run_test1 -// zig build run_test2 -// zig build run_test3 +// Zig Run Command: zig build -h +// zig build run_dolly-v2 +// zig build run_gpt-2 +// zig build run_gpt-j +// zig build run_gpt-neox +// zig build run_mnist +// zig build run_mpt +// zig build run_replit +// zig build run_starcoder +// zig build run_test-grad0 +// zig build run_test-mul-mat0 +// zig build run_test-mul-mat2 +// zig build run_test-opt +// zig build run_test-vec1 +// zig build run_test0 +// zig build run_test1 +// zig build run_test2 +// zig build run_test3 +// zig build run_zig_test0 +// zig build run_zig_test1 pub fn build(b: *std.build.Builder) void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); @@ -110,4 +112,27 @@ pub fn build(b: *std.build.Builder) void { const run_step = b.step("run_" ++ name, "Run tests"); run_step.dependOn(&run_cmd.step); } + + // zig_tests + const zig_tests = .{ + "test0", + "test1", + }; + inline for (zig_tests) |name| { + const exe = b.addExecutable(.{ + .name = name, + .root_source_file = .{ .path = std.fmt.comptimePrint("tests/{s}.zig", .{name}) }, + .target = target, + .optimize = optimize, + }); + exe.addIncludePath("./include"); + exe.addIncludePath("./include/ggml"); + exe.linkLibrary(lib); + b.installArtifact(exe); + const run_cmd = b.addRunArtifact(exe); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| run_cmd.addArgs(args); + const run_step = b.step("run_zig_" ++ name, "Run zig_tests"); + run_step.dependOn(&run_cmd.step); + } } \ No newline at end of file diff --git a/tests/test0.zig b/tests/test0.zig new file mode 100644 index 000000000..8246a978b --- /dev/null +++ b/tests/test0.zig @@ -0,0 +1,43 @@ +const std = @import("std"); +const c = @cImport({ + @cInclude("stdio.h"); + @cInclude("stdlib.h"); + @cInclude("ggml/ggml.h"); +}); + +pub fn main() !void { + const params = .{ + .mem_size = 128*1024*1024, + .mem_buffer = null, + .no_alloc = false, + }; + + const ctx0 = c.ggml_init(params); + defer c.ggml_free(ctx0); + + const t1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 10); + const t2 = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_I16, 10, 20); + const t3 = c.ggml_new_tensor_3d(ctx0, c.GGML_TYPE_I32, 10, 20, 30); + + try std.testing.expect(t1.*.n_dims == 1); + try std.testing.expect(t1.*.ne[0] == 10); + try std.testing.expect(t1.*.nb[1] == 10*@sizeOf(f32)); + + try std.testing.expect(t2.*.n_dims == 2); + try std.testing.expect(t2.*.ne[0] == 10); + try std.testing.expect(t2.*.ne[1] == 20); + try std.testing.expect(t2.*.nb[1] == 10*@sizeOf(i16)); + try std.testing.expect(t2.*.nb[2] == 10*20*@sizeOf(i16)); + + try std.testing.expect(t3.*.n_dims == 3); + try std.testing.expect(t3.*.ne[0] == 10); + try std.testing.expect(t3.*.ne[1] == 20); + try std.testing.expect(t3.*.ne[2] == 30); + try std.testing.expect(t3.*.nb[1] == 10*@sizeOf(i32)); + 
try std.testing.expect(t3.*.nb[2] == 10*20*@sizeOf(i32)); + try std.testing.expect(t3.*.nb[3] == 10*20*30*@sizeOf(i32)); + + c.ggml_print_objects(ctx0); + + _ = try std.io.getStdIn().reader().readByte(); +} diff --git a/tests/test1.zig b/tests/test1.zig new file mode 100644 index 000000000..e40c37e81 --- /dev/null +++ b/tests/test1.zig @@ -0,0 +1,459 @@ +const std = @import("std"); +const c = @cImport({ + @cInclude("stdio.h"); + @cInclude("stdlib.h"); + @cInclude("ggml/ggml.h"); +}); + +pub fn main() !void { + const params = .{ + .mem_size = 128*1024*1024, + .mem_buffer = null, + .no_alloc = false, + }; + + const ctx0 = c.ggml_init(params); + defer c.ggml_free(ctx0); + + { + const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + + c.ggml_set_param(ctx0, x); + + const a = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + const b = c.ggml_mul(ctx0, x, x); + const f = c.ggml_mul(ctx0, b, a); + + // a*x^2 + // 2*a*x + + c.ggml_print_objects(ctx0); + + const gf = c.ggml_build_forward(f); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x, 2.0); + _ = c.ggml_set_f32(a, 3.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(f.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)}); + std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)}); + + try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 12.0); + try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 12.0); + + _ = c.ggml_set_f32(x, 3.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(f.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)}); + std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)}); + + try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 27.0); + try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 18.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-1-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-1-backward.dot"); + } + + ///////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + + _ = c.ggml_set_f32(x1, 3.0); + _ = c.ggml_set_f32(x2, 1.0); + _ = c.ggml_set_f32(x3, 0.0); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + + const y = c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)}); + std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)}); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 7.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); + + const g1 = x1.*.grad; + const g2 = x2.*.grad; + + const gbb = c.ggml_build_backward(ctx0, @constCast(&gb), true); + + c.ggml_graph_reset(@constCast(&gb)); + _ = c.ggml_set_f32(g1.*.grad, 1.0); + _ = c.ggml_set_f32(g2.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gbb)); + + std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", 
.{c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0)}); + + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-2-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-2-backward.dot"); + } + + /////////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + + const y = c.ggml_mul(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)), x1); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x1, 3.0); + _ = c.ggml_set_f32(x2, 4.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)}); + std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)}); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 63.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 51.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 9.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-3-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-3-backward.dot"); + } + + /////////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + c.ggml_set_param(ctx0, x3); + + const y = c.ggml_mul(ctx0, c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x2, x2)), x3); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x1, 1.0); + _ = c.ggml_set_f32(x2, 2.0); + _ = c.ggml_set_f32(x3, 3.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)}); + std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)}); + std.debug.print("df/dx3 = {d:.6}\n", .{c.ggml_get_f32_1d(x3.*.grad, 0)}); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 24.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 12.0); + try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0); + + const g1 = x1.*.grad; + const g2 = x2.*.grad; + const g3 = x3.*.grad; + + const gbb = c.ggml_build_backward(ctx0, @constCast(&gb), true); + + c.ggml_graph_reset(@constCast(&gb)); + _ = c.ggml_set_f32(g1.*.grad, 1.0); + _ = c.ggml_set_f32(g2.*.grad, 1.0); + _ = c.ggml_set_f32(g3.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gbb)); + + std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", + .{ + c.ggml_get_f32_1d(x1.*.grad, 0), + c.ggml_get_f32_1d(x2.*.grad, 0), + c.ggml_get_f32_1d(x3.*.grad, 0), + }); + + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 56.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 34.0); + try 
std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-4-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-4-backward.dot"); + } + + /////////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + + const y = c.ggml_sum(ctx0, c.ggml_mul(ctx0, x1, x2)); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x1, 3.0); + _ = c.ggml_set_f32(x2, 5.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x1.*.grad, 0), + c.ggml_get_f32_1d(x1.*.grad, 1), + c.ggml_get_f32_1d(x1.*.grad, 2), + }); + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x2.*.grad, 0), + c.ggml_get_f32_1d(x2.*.grad, 1), + c.ggml_get_f32_1d(x2.*.grad, 2), + }); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 45.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 5.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 5.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 5.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-5-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-5-backward.dot"); + } + + /////////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + + const y = + c.ggml_sum(ctx0, + c.ggml_add(ctx0, + c.ggml_mul(ctx0, x1, x2), + c.ggml_mul(ctx0, + c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1), + c.ggml_mul(ctx0, x1, x1) + ) + ) + ); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x1, 3.0); + _ = c.ggml_set_f32(x2, 5.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x1.*.grad, 0), + c.ggml_get_f32_1d(x1.*.grad, 1), + c.ggml_get_f32_1d(x1.*.grad, 2), + }); + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x2.*.grad, 0), + c.ggml_get_f32_1d(x2.*.grad, 1), + c.ggml_get_f32_1d(x2.*.grad, 2), + }); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == -9.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -7.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -7.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -7.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-6-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-6-backward.dot"); + } + + 
/////////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + + const y = + c.ggml_sum(ctx0, + c.ggml_sub(ctx0, + c.ggml_mul(ctx0, x1, x2), + c.ggml_mul(ctx0, + c.ggml_mul(ctx0, x1, x1), + c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1) + ) + ) + ); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x1, 3.0); + _ = c.ggml_set_f32(x2, 5.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x1.*.grad, 0), + c.ggml_get_f32_1d(x1.*.grad, 1), + c.ggml_get_f32_1d(x1.*.grad, 2), + }); + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x2.*.grad, 0), + c.ggml_get_f32_1d(x2.*.grad, 1), + c.ggml_get_f32_1d(x2.*.grad, 2), + }); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 99.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 17.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 17.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 17.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-7-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-7-backward.dot"); + } + + /////////////////////////////////////////////////////////////// + + { + const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); + + c.ggml_set_param(ctx0, x1); + c.ggml_set_param(ctx0, x2); + + const y = + c.ggml_abs(ctx0, + c.ggml_sub(ctx0, x1, x2) + ); + + const gf = c.ggml_build_forward(y); + const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false); + + _ = c.ggml_set_f32(x1, 3.0); + _ = c.ggml_set_f32(x2, 5.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x1.*.grad, 0), + c.ggml_get_f32_1d(x1.*.grad, 1), + c.ggml_get_f32_1d(x1.*.grad, 2), + }); + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x2.*.grad, 0), + c.ggml_get_f32_1d(x2.*.grad, 1), + c.ggml_get_f32_1d(x2.*.grad, 2), + }); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -1.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -1.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -1.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 1.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 1.0); + + _ = c.ggml_set_f32(x1, 7.0); + _ = c.ggml_set_f32(x2, 5.0); + + c.ggml_graph_reset(@constCast(&gf)); + _ = c.ggml_set_f32(y.*.grad, 1.0); + + c.ggml_graph_compute(ctx0, @constCast(&gb)); + + std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + .{ + 
c.ggml_get_f32_1d(x1.*.grad, 0), + c.ggml_get_f32_1d(x1.*.grad, 1), + c.ggml_get_f32_1d(x1.*.grad, 2), + }); + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + .{ + c.ggml_get_f32_1d(x2.*.grad, 0), + c.ggml_get_f32_1d(x2.*.grad, 1), + c.ggml_get_f32_1d(x2.*.grad, 2), + }); + + try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 1.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 1.0); + try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 1.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == -1.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == -1.0); + try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == -1.0); + + c.ggml_graph_dump_dot(&gf, null, "test1-8-forward.dot"); + c.ggml_graph_dump_dot(&gb, &gf, "test1-8-backward.dot"); + } + + _ = try std.io.getStdIn().reader().readByte(); +} \ No newline at end of file From cad56f5e6b04e0d5939e2198f181af03dd02f6fc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 2 Jul 2023 17:33:57 +0300 Subject: [PATCH 05/15] ggml : sync latest llama.cpp (ggml_task_type changes + GPU backends) --- include/ggml/ggml.h | 3 + src/ggml-cuda.cu | 128 ++++++++--- src/ggml-cuda.h | 5 +- src/ggml-metal.m | 4 +- src/ggml-opencl.cpp | 527 +++++++++++++++++++++++++++++--------------- src/ggml.c | 65 +++++- 6 files changed, 508 insertions(+), 224 deletions(-) diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 459913222..11b51f8bd 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -444,6 +444,9 @@ extern "C" { // compute types + + // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. + // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. 
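As context for the note above: these task types are what ggml_compute_forward receives from the graph scheduler reworked in the first patch. The following minimal sketch is illustrative only and not taken from any of the patches; ggml_compute_forward_my_op and its row split are hypothetical, but the early return on INIT/FINALIZE and the ith/nth partitioning mirror real handlers such as ggml_compute_forward_rope_f32 shown earlier.

// illustrative sketch only, not part of the patch series
static void ggml_compute_forward_my_op(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
    // most ops have nothing to prepare or reduce, so INIT and FINALIZE return early
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    // GGML_TASK_COMPUTE: split the rows of src0 across the threads working on this node
    const int ith = params->ith;               // index of this thread
    const int nth = params->nth;               // number of threads (node->n_tasks)

    const int64_t nr  = ggml_nrows(src0);
    const int64_t dr  = (nr + nth - 1)/nth;    // rows per thread
    const int64_t ir0 = dr*ith;                // first row handled by this thread
    const int64_t ir1 = MIN(ir0 + dr, nr);     // one past the last row

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // ... read row ir of src0, write the corresponding row of dst ...
    }
    (void) dst; // skeleton only
}

With the reworked scheduler, nth here equals node->n_tasks, and nodes with n_tasks == 1 are executed inline (INIT, COMPUTE, FINALIZE in one go) by whichever thread is currently distributing work, as seen in ggml_graph_compute_thread above.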
enum ggml_task_type { GGML_TASK_INIT = 0, GGML_TASK_COMPUTE, diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index c34e96abf..50df20edd 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif +struct ggml_tensor_extra_gpu { + void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors + cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs +}; + static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -223,6 +228,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co dst[i] = x[i] + y[i]; } +static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = __hadd(x[i], __float2half(y[i])); +} + static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -1235,7 +1249,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, } static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { - const half * x = (half *) vx; + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; @@ -1283,9 +1297,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int nchannels_x, const int channel_stride_x) { + const int row_stride_x, const int channel_stride_x) { - const half * x = (half *) vx; + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; @@ -1328,14 +1342,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (float *) cxi; + const float * xi = (const float *) cxi; float * dsti = (float *) cdsti; *dsti = *xi; } static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (float *) cxi; + const float * xi = (const float *) cxi; half * dsti = (half *) cdsti; *dsti = __float2half(*xi); @@ -1459,6 +1473,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in add_f32<<>>(x, y, dst, k); } +static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + add_f16_f32_f16<<>>(x, y, dst, k); +} + static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; mul_f32<<>>(x, y, dst, kx, ky); @@ -1684,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda( const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 
block_dims(WARP_SIZE, 1, 1); mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x); + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); } static void ggml_cpy_f32_f32_cuda( @@ -1941,7 +1960,7 @@ inline void ggml_cuda_op_add( float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t & cudaStream_main){ - GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr); GGML_ASSERT(dst_ddf_i != nullptr); @@ -1949,8 +1968,13 @@ inline void ggml_cuda_op_add( const int64_t i01_diff = i01_high - i01_low; // compute - add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main); + } else { + GGML_ASSERT(false); + } (void) src1; (void) dst; @@ -1982,7 +2006,6 @@ inline void ggml_cuda_op_mul( // compute mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); } (void) dst; @@ -2003,7 +2026,6 @@ inline void ggml_cuda_op_silu( // compute silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2026,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm( // compute rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2105,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); #ifdef GGML_CUDA_DMMV_F16 if (src1_convert_f16) { @@ -2182,7 +2202,6 @@ inline void ggml_cuda_op_rope( // compute rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) dst; (void) src0_ddq_i; @@ -2206,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf( // compute diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) dst; (void) src0_ddq_i; @@ -2228,7 +2246,6 @@ inline void ggml_cuda_op_soft_max( // compute soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2324,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0}; - // if multiple GPUs are used they need to wait for the main GPU to finish + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signifies that the main device has finished calculating the input data if (split && g_device_count > 1) { CUDA_CHECK(cudaSetDevice(g_main_device)); - CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device])); } for (int id = 0; id < g_device_count; ++id) { @@ -2353,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm int64_t row_diff = row_high - row_low; cudaSetDevice(id); + cudaStream_t 
cudaStream_main = g_cudaStreams_main[id]; + + // wait for main GPU data if necessary + if (split && id != g_main_device) { + CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device])); + } if (src0_on_device && src0_is_contiguous) { if (src0_is_f32) { @@ -2428,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } const int64_t i11 = i13*ne12 + i12; - cudaStream_t cudaStream_main = g_cudaStreams_main[id]; - // for split tensors the data begins at i0 == i0_offset_low char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; @@ -2489,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm // do the computation op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main); + CUDA_CHECK(cudaGetLastError()); // copy dst to host or other device if necessary if (!dst_on_device) { @@ -2518,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main)); } } + + // signify to main device that other device is done + if (split && g_device_count > 1 && id != g_main_device) { + CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main)); + } } } } @@ -2529,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } CUDA_CHECK(cudaSetDevice(id)); - CUDA_CHECK(cudaDeviceSynchronize()); if (src0_asq[id] > 0) { ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]); @@ -2544,11 +2571,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]); } } + + // main device waits for all other devices to be finished + if (split && g_device_count > 1) { + CUDA_CHECK(cudaSetDevice(g_main_device)); + for (int id = 0; id < g_device_count; ++id) { + if (id != g_main_device) { + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id])); + } + } + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaSetDevice(g_main_device)); + CUDA_CHECK(cudaDeviceSynchronize()); + } } void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true); + // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op. + // Due to flatten_rows == true this does in practice not make a difference however. + // Better solution would be nice but right now that would require disproportionate changes. 
+ GGML_ASSERT( + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + src1->type == GGML_TYPE_F32 && + (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16)); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true); } void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2777,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); extra->data_device[id] = buf; + + if (backend == GGML_BACKEND_GPU_SPLIT) { + CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming)); + } } tensor->extra = extra; @@ -2790,18 +2842,21 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; for (int id = 0; id < g_device_count; ++id) { - if (extra->data_device[id] == nullptr) { - continue; + if (extra->data_device[id] != nullptr) { + CUDA_CHECK(cudaSetDevice(id)); + CUDA_CHECK(cudaFree(extra->data_device[id])); } - CUDA_CHECK(cudaSetDevice(id)); - CUDA_CHECK(cudaFree(extra->data_device[id])); + if (extra->events[id] != nullptr) { + CUDA_CHECK(cudaSetDevice(id)); + CUDA_CHECK(cudaEventDestroy(extra->events[id])); + } } delete extra; } -void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { +void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) { if (scratch && g_scratch_size == 0) { return; } @@ -2810,11 +2865,11 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) { const ggml_op src0_op = tensor->src0->op; if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) { - ggml_cuda_assign_buffers_impl(tensor->src0, scratch); + ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace); } } if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) { - ggml_cuda_assign_buffers_impl(tensor->src1, scratch); + ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace); } tensor->backend = GGML_BACKEND_GPU; @@ -2822,11 +2877,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { memset(extra, 0, sizeof(*extra)); const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) || - tensor->op == GGML_OP_VIEW; + tensor->op == GGML_OP_VIEW || + force_inplace; const size_t size = ggml_nbytes(tensor); CUDA_CHECK(cudaSetDevice(g_main_device)); - if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) { + if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t offset = 0; @@ -2865,11 +2921,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { } void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, true); + ggml_cuda_assign_buffers_impl(tensor, true, false); } void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, false); + ggml_cuda_assign_buffers_impl(tensor, false, false); +} + +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, true); } void 
ggml_cuda_set_main_device(int main_device) { diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h index d32b44842..3c1e8deb6 100644 --- a/src/ggml-cuda.h +++ b/src/ggml-cuda.h @@ -8,10 +8,6 @@ extern "C" { #define GGML_CUDA_MAX_DEVICES 16 -struct ggml_tensor_extra_gpu { - void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors -}; - void ggml_init_cublas(void); void ggml_cuda_set_tensor_split(const float * tensor_split); @@ -29,6 +25,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor); void ggml_cuda_free_data(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor); void ggml_cuda_set_main_device(int main_device); void ggml_cuda_set_scratch_size(size_t scratch_size); void ggml_cuda_free_scratch(void); diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 7551231b9..fd69c41fe 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -202,7 +202,9 @@ @implementation GGMLMetalClass void ggml_metal_free(struct ggml_metal_context * ctx) { fprintf(stderr, "%s: deallocating\n", __func__); - + for (int i = 0; i < ctx->n_buffers; ++i) { + [ctx->buffers[i].metal release]; + } free(ctx); } diff --git a/src/ggml-opencl.cpp b/src/ggml-opencl.cpp index 95f4cec6d..fed4ffb0c 100644 --- a/src/ggml-opencl.cpp +++ b/src/ggml-opencl.cpp @@ -21,11 +21,19 @@ #define CL_DMMV_BLOCK_SIZE 32 +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 1 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + #define MULTILINE_QUOTE(...) #__VA_ARGS__ static std::string program_source = MULTILINE_QUOTE( typedef char int8_t; typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; typedef int int32_t; typedef uint uint32_t; @@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float *v0 = vload_half(0, &x[ib + 0]); *v1 = vload_half(0, &x[ib + 1]); } +); +static std::string k_quants_source = MULTILINE_QUOTE( inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m) { if (j < 4) @@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa const int is = 8 * n + l / 16; const uint8_t q = x[i].qs[32 * n + l]; - __global float *y = yy + i * 256 + 128 * n; + __global float *y = yy + i * QK_K + 128 * n; const float dall = vload_half(0, &x[i].d); const float dmin = vload_half(0, &x[i].dmin); @@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa float d_all = vload_half(0, &x[i].d); float dl = d_all * (us - 32); - __global float *y = yy + i * 256 + 128 * n + 32 * j; + __global float *y = yy + i * QK_K + 128 * n + 32 * j; const __global uint8_t *q = x[i].qs + 32 * n; const __global uint8_t *hm = x[i].hmask; @@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa const int is = 2 * il; const int n = 4; - __global float *y = yy + i * 256 + 64 * il + n * ir; + __global float *y = yy + i * QK_K + 64 * il + n * ir; const float dall = vload_half(0, &x[i].d); const float dmin = vload_half(0, &x[i].dmin); @@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa const int ir = tid % 16; const int is = 2 * il; - __global float *y = yy + i * 256 + 64 * il + 2 * ir; + 
__global float *y = yy + i * QK_K + 64 * il + 2 * ir; const float dall = vload_half(0, &x[i].d); const float dmin = vload_half(0, &x[i].dmin); @@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa const int il = tid - 32 * ip; const int is = 8 * ip + il / 16; - __global float *y = yy + i * 256 + 128 * ip + il; + __global float *y = yy + i * QK_K + 128 * ip + il; const float d = vload_half(0, &x[i].d); @@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); } +__kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) { -void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + const int row = get_group_id(0); - int n = iqs / 128; - int r = iqs - 128 * n; - int l = r / 8; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; - __global const float *y = yy + 128 * n + l; - __global const uint8_t *q = x[ib].qs + 32 * n + l; - __global const uint8_t *s = x[ib].scales + 8 * n; + __global const struct block_q2_K * x = xx + ib0; - const float dall = vload_half(0, &x[ib].d); - const float dmin = vload_half(0, &x[ib].dmin); + const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1 - float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4)) - + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4)) - + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) - + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) - + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) - + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) - + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) - + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); + const int step = 16/K_QUANTS_PER_ITERATION; - *result = sum; -} + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
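// Decomposition used by this kernel (inferred from the index ranges in the comments,
// i.e. a 32-wide work-group): each work-group reduces one matrix row. im selects the
// first or second 128-value half of a QK_K = 256 super-block, the lane index in
// (computed next) selects the position inside that half, and each work-item handles
// K_QUANTS_PER_ITERATION adjacent quant positions per super-block while striding over
// the row's super-blocks by K_QUANTS_PER_ITERATION starting at ix. The per-lane
// partial sums are accumulated in the __local array tmp and combined by the tree
// reduction at the end of the kernel before dst[row] is written.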
+ const int in = tid - step*im; // 0...15 or 0...7 -void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; + tmp[16 * ix + tid] = 0; - uint32_t aux[3]; - uint32_t utmp[4]; + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); - int n = iqs/128; - int r = iqs - 128*n; - int l = r/8; + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - __global const float * y = yy + 128*n + l; - __global const uint8_t * q = x[ib].qs + 32*n + l; - __global const uint8_t * hm = x[ib].hmask + l; - const int8_t * s = (const int8_t *)utmp + 8*n; + __global const float * y = yy + i * QK_K + y_offset; + __global const uint8_t * q = x[i].qs + q_offset; - aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24; - aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24; - aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24; + const float dall = vload_half(0, &x[i].d); + const float dmin = vload_half(0, &x[i].dmin); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; - const float dall = vload_half(0, &x[ib].d); - const uint8_t m = 1 << (4*n); + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; - float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) - + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) - + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) - + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) - + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) - + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) - + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) - + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 
0 : 4)); + } + tmp[16 * ix + tid] += dall * sum1 - dmin * sum2; - *result = sum * dall; + } + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=16; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } } -void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) { +__kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) { + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int row = get_group_id(0); + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; - const int j = iqs / 64; // j is in 0...3 - const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 - const int is = 2*j; // is is in 0...6 in steps of 2 + __global const struct block_q3_K * x = xx + ib0; - __global const float * y = yy + 64*j + ir; - __global const uint8_t * q = x[ib].qs + 32*j + ir; + const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1 - const float dall = vload_half(0, &x[ib].d); - const float dmin = vload_half(0, &x[ib].dmin); + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 - uint8_t sc, m; - get_scale_min_k4(is + 0, x[ib].scales, &sc, &m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[ib].scales, &sc, &m); - const float d2 = dall * sc; - const float m2 = dmin * m; + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + tmp[16 * ix + tid] = 0; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + __global const float * y = yy + i * QK_K + y_offset; + __global const uint8_t * q = x[i].qs + q_offset; + __global const uint8_t * h = x[i].hmask + l0; + + __global const uint16_t * a = (__global const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = vload_half(0, &x[i].d); + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 
0 : 4)); + } + tmp[16 * ix + tid] += d * sum; - float sum = 0; - for (int k = 0; k < 4; ++k) { - sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1); - sum += y[k + 32] * (d2 * (q[k] >> 4) - m2); } - *result = sum; + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=16; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } } -void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) { +__kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) { - const int j = iqs / 64; - const int ir = (iqs - 64*j)/2; - const int is = 2*j; + //to rename it later, just to test now + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; - __global const float * y = yy + 64*j + ir; - __global const uint8_t * ql = x[ib].qs + 32*j + ir; - __global const uint8_t * qh = x[ib].qh + ir; + const int row = get_group_id(0); + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; - const float dall = vload_half(0, &x[ib].d); - const float dmin = vload_half(0, &x[ib].dmin); + const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15 + const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; - uint8_t sc, m; - get_scale_min_k4(is + 0, x[ib].scales, &sc, &m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[ib].scales, &sc, &m); - const float d2 = dall * sc; - const float m2 = dmin * m; + const int step = 8/K_QUANTS_PER_ITERATION; + + const int il = tid/step; // 0...3 + const int ir = tid - step*il;// 0...3 + const int n = 2*K_QUANTS_PER_ITERATION; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + __global const struct block_q4_K * x = xx + ib0; + + tmp[16 * ix + tid] = 0; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + __global const uint8_t * q1 = x[i].qs + q_offset; + __global const uint8_t * q2 = q1 + 64; + __global const float * y1 = yy + i*QK_K + y_offset; + __global const float * y2 = y1 + 128; + + const float dall = vload_half(0, &x[i].d); + const float dmin = vload_half(0, &x[i].dmin); + + __global const uint16_t * a = (__global const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 s = (float4)(0.f); + float smin = 0; + for (int l = 0; l < n; ++l) { + s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); + s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; - uint8_t hm = 1 << is; - float sum = 0; - for (int k = 0; k < 4; ++k) { - sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1); } - hm <<= 1; - for (int k = 0; k < 4; ++k) { - sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 
16 : 0)) - m2); + + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=16; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } +} + +__kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int row = get_group_id(0); + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const int tid = get_local_id(0)/2; // 0...15 + const int ix = get_local_id(0)%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 2; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + __global const struct block_q5_K * x = xx + ib0; + + tmp[16 * ix + tid] = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + __global const uint8_t * ql1 = x[i].qs + q_offset; + __global const uint8_t * ql2 = ql1 + 64; + __global const uint8_t * qh = x[i].qh + l0; + __global const float * y1 = yy + i*QK_K + y_offset; + __global const float * y2 = y1 + 128; + + const float dall = vload_half(0, &x[i].d); + const float dmin = vload_half(0, &x[i].dmin); + + __global const uint16_t * a = (__global const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 sum = (float4)(0.f); + float smin = 0; + for (int l = 0; l < n; ++l) { + sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) + + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 
16 : 0)); + smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; + } + tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + } - *result = sum; + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=16; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } } -void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) { +__kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) { + + const int row = get_group_id(0); + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; - const int ip = iqs / 128; // 0 or 1 - const int il = (iqs - 128*ip)/8; // 0...15 - const int is = 8*ip; + __global const struct block_q6_K * x = xx + ib0; - __global const float * y = yy + 128*ip + il; + const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - const float d = vload_half(0, &x[ib].d); + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - __global const uint8_t * ql = x[ib].ql + 64*ip + il; - __global const uint8_t * qh = x[ib].qh + 32*ip + il; - __global const int8_t * sc = x[ib].scales + is; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + tmp[16 * ix + tid] = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + __global const float * y = yy + i * QK_K + y_offset; + __global const uint8_t * ql = x[i].ql + ql_offset; + __global const uint8_t * qh = x[i].qh + qh_offset; + __global const int8_t * s = x[i].scales + s_offset; + + const float d = vload_half(0, &x[i].d); + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp[16 * ix + tid] += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp[16 * ix + 
tid] += sum; +#endif - *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) - + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32) - + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32) - + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32) - + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32) - + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32) - + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32) - + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32); + } + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=16; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } } ); @@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float } ); -std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE( -__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); - const int row = get_group_id(0); - const int tid = get_local_id(0); - - const int iter_stride = 256; - const int vals_per_iter = iter_stride / block_size; - const int num_blocks_per_row = ncols / 256; - const int ib0 = row*num_blocks_per_row; - - tmp[tid] = 0; - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int ib = ib0 + col/256; // x block index - const int iqs = col%256; // x quant index - const int iybs = col - col%256; // y block start index - - // dequantize - float v; - DOT_KERNEL(x, ib, iqs, y + iybs, &v); - tmp[tid] += v; - } - - // sum up partial sums and write back result - barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} -); std::string mul_template = MULTILINE_QUOTE( __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) { @@ -649,18 +843,6 @@ std::array mul_str_values = { "mul_f32", "float" }; -std::array dmmv_k_str_keys = { - "KERNEL_NAME", "X_TYPE", "DOT_KERNEL" -}; - -std::array dmmv_k_str_values = { - "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K", - "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K", - "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K", - "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K", - "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K", -}; - std::string& replace(std::string& s, const std::string& from, const std::string& to) { size_t pos = 0; while ((pos = s.find(from, pos)) != std::string::npos) { @@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string& std::string generate_kernels() { std::stringstream src; src << program_source << '\n'; + src << k_quants_source << '\n'; for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) { std::string dequant_kernel = dequant_template; std::string dmmv_kernel = dequant_mul_mat_vec_template; @@ -690,13 +873,6 @@ std::string generate_kernels() { } src << mul_kernel << '\n'; } - for 
(size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) { - std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template; - for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) { - replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]); - } - src << dmmv_k_kernel << '\n'; - } return src.str(); } @@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co exit(1); } - const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math " - "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1"; + std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math " + "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 " + "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION); - err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL); + err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL); if(err < 0) { clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); diff --git a/src/ggml.c b/src/ggml.c index 92faf03f7..afeb72ff0 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3846,6 +3846,41 @@ static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); +// WARN: +// Mis-confguration can lead to problem that's hard to reason about: +// * At best it crash or talks nosense. +// * At worst it talks slightly difference but hard to perceive. +// +// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. +// Take care about compile options (e.g., GGML_USE_xxx). 
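// A condensed, self-contained sketch of the gating these two tables drive (the toy_*
// names below are illustrative and not part of ggml): the worker thread dispatches the
// INIT and FINALIZE passes only for ops whose flag is set, so a missing flag means that
// pass is silently skipped for the op -- which is exactly the hard-to-debug failure mode
// described above.

#include <stdbool.h>
#include <stdio.h>

enum task_type { TASK_INIT, TASK_COMPUTE, TASK_FINALIZE };
enum { TOY_OP_COUNT = 4, TOY_OP_MUL_MAT = 2 };

static bool toy_op_has_init[TOY_OP_COUNT]     = { false };
static bool toy_op_has_finalize[TOY_OP_COUNT] = { false };

static void toy_forward(int op, enum task_type type) {
    printf("op %d: pass %d\n", op, (int) type); // stand-in for ggml_compute_forward()
}

static void toy_run_node(int op) {
    if (toy_op_has_init[op])     toy_forward(op, TASK_INIT);     // e.g. zero an accumulator
    toy_forward(op, TASK_COMPUTE);                                // the main pass always runs
    if (toy_op_has_finalize[op]) toy_forward(op, TASK_FINALIZE);  // e.g. merge per-thread results
}

// Usage: flag the op once at setup, the same way ggml_setup_op_has_task_pass() does
// right below, e.g.
//     toy_op_has_init[TOY_OP_MUL_MAT] = true;
//     toy_run_node(TOY_OP_MUL_MAT);   // runs INIT and COMPUTE, skips FINALIZE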
+static bool GGML_OP_HAS_INIT [GGML_OP_COUNT] = { 0 }; +static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 }; + +static void ggml_setup_op_has_task_pass(void) { + { // INIT + bool * p = GGML_OP_HAS_INIT; + + p[GGML_OP_ACC ] = true; + p[GGML_OP_MUL_MAT ] = true; + p[GGML_OP_OUT_PROD ] = true; + p[GGML_OP_SET ] = true; + p[GGML_OP_GET_ROWS_BACK ] = true; + p[GGML_OP_DIAG_MASK_INF ] = true; + p[GGML_OP_DIAG_MASK_ZERO ] = true; + p[GGML_OP_CONV_1D_S1_PH ] = true; + p[GGML_OP_CONV_1D_S2_PH ] = true; + p[GGML_OP_CONV_2D_SK_P0 ] = true; + p[GGML_OP_FLASH_ATTN_BACK ] = true; + p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } + + { // FINALIZE + bool * p = GGML_OP_HAS_FINALIZE; + + p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } +} + // // ggml context // @@ -4267,6 +4302,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { ggml_cl_init(); #endif + ggml_setup_op_has_task_pass(); + is_first_call = false; } @@ -16684,7 +16721,8 @@ typedef pthread_t ggml_thread_t; #endif -#ifdef __linux__ +// Android's libc implementation "bionic" does not support setting affinity +#if defined(__linux__) && !defined(__BIONIC__) void set_numa_thread_affinity(int thread_n, int n_threads) { if (!ggml_is_numa()) { return; @@ -16790,9 +16828,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (node_n != -1) { /* FINALIZE */ struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; - params.nth = node->n_tasks; - ggml_compute_forward(¶ms, node); - ggml_graph_compute_perf_stats_node(node, state->shared); + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.nth = node->n_tasks; + ggml_compute_forward(¶ms, node); + ggml_graph_compute_perf_stats_node(node, state->shared); + } } // distribute new work or execute it direct if 1T @@ -16804,10 +16844,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { state->shared->perf_node_start_cycles = ggml_perf_cycles(); state->shared->perf_node_start_time_us = ggml_perf_time_us(); + params.nth = node->n_tasks; + /* INIT */ - params.type = GGML_TASK_INIT; - params.nth = node->n_tasks; - ggml_compute_forward(¶ms, node); + if (GGML_OP_HAS_INIT[node->op]) { + params.type = GGML_TASK_INIT; + ggml_compute_forward(¶ms, node); + } if (node->n_tasks == 1) { // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, @@ -16815,9 +16858,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { params.type = GGML_TASK_COMPUTE; ggml_compute_forward(¶ms, node); - params.type = GGML_TASK_FINALIZE; - ggml_compute_forward(¶ms, node); - ggml_graph_compute_perf_stats_node(node, state->shared); + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + ggml_graph_compute_perf_stats_node(node, state->shared); + } } else { break; } From dfef9c61a0bfd62ad8710af7ea59f1995b53287b Mon Sep 17 00:00:00 2001 From: the-crypt-keeper <84680712+the-crypt-keeper@users.noreply.github.com> Date: Sun, 2 Jul 2023 10:52:52 -0400 Subject: [PATCH 06/15] starcoder : add repeat penalty (#311) * implement repeat penalty processing for starcoder * show effective parameters at starcoder startup --------- Co-authored-by: Mike Ravkine --- examples/common.cpp | 6 ++++++ examples/common.h | 2 ++ examples/starcoder/main.cpp | 23 +++++++++++++++++++++-- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index fe00278c2..7b01089b0 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params 
& params) { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { params.temp = std::stof(argv[++i]); + } else if (arg == "--repeat-last-n") { + params.repeat_last_n = std::stof(argv[++i]); + } else if (arg == "--repeat-penalty") { + params.repeat_penalty = std::stof(argv[++i]); } else if (arg == "-b" || arg == "--batch_size") { params.n_batch = std::stoi(argv[++i]); } else if (arg == "-m" || arg == "--model") { @@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); + fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n); + fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); diff --git a/examples/common.h b/examples/common.h index 7e9b867d3..12b2b339d 100644 --- a/examples/common.h +++ b/examples/common.h @@ -23,6 +23,8 @@ struct gpt_params { int32_t top_k = 40; float top_p = 0.9f; float temp = 0.9f; + int32_t repeat_last_n = 64; + float repeat_penalty = 1.00f; int32_t n_batch = 8; // batch size for prompt processing diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp index 2016f8974..5c6065980 100644 --- a/examples/starcoder/main.cpp +++ b/examples/starcoder/main.cpp @@ -782,6 +782,16 @@ int main(int argc, char ** argv) { test_gpt_tokenizer(vocab, params.token_test); } + if (params.repeat_last_n == -1) { + params.repeat_last_n = model.hparams.n_ctx; + } + printf("\n"); + printf("%s: temp = %.3f\n", __func__, params.temp); + printf("%s: top_k = %d\n", __func__, params.top_k); + printf("%s: top_p = %.3f\n", __func__, params.top_p); + printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n); + printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty); + int n_past = 0; int64_t t_sample_us = 0; @@ -789,6 +799,9 @@ int main(int argc, char ** argv) { std::vector logits; + std::vector last_n_tokens(model.hparams.n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + // tokenize the prompt std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); @@ -847,17 +860,23 @@ int main(int argc, char ** argv) { { const int64_t t_start_sample_us = ggml_time_us(); - id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); - + id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, params.repeat_last_n, params.repeat_penalty, rng); t_sample_us += ggml_time_us() - t_start_sample_us; } // add it to the context embd.push_back(id); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); } else { // if here, it means we are still processing the input prompt for (int k = i; k < embd_inp.size(); k++) { embd.push_back(embd_inp[k]); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(embd_inp[k]); + if (embd.size() >= params.n_batch) { break; } From 66ea4034165120cfc254cd2b7fd38a35ccef49b4 Mon Sep 17 00:00:00 2001 From: 
goerch Date: Sun, 2 Jul 2023 17:13:23 +0200 Subject: [PATCH 07/15] ggml : add GGML_TENSOR_LOCALS helper macros (#309) * [WIP] ref #292 * Further code reduction * ggml : minor style fixes * ggml : hide op locals in source file --------- Co-authored-by: Georgi Gerganov --- include/ggml/ggml.h | 26 ++ src/ggml.c | 1010 +++++-------------------------------------- 2 files changed, 132 insertions(+), 904 deletions(-) diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 11b51f8bd..fb3fec29c 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -201,6 +201,8 @@ #define GGML_MAX_NAME 48 #define GGML_DEFAULT_N_THREADS 4 +#define GGML_UNUSED(x) (void)(x) + #define GGML_ASSERT(x) \ do { \ if (!(x)) { \ @@ -209,6 +211,30 @@ } \ } while (0) +// used to copy the number of elements and stride in bytes of tensors into local variables. +// main purpose is to reduce code duplication and improve readability. +// +// example: +// +// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); +// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); +// +#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ + const type prefix##0 = (pointer)->array[0]; \ + GGML_UNUSED(prefix##0); +#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ + const type prefix##1 = (pointer)->array[1]; \ + GGML_UNUSED(prefix##1); +#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ + const type prefix##2 = (pointer)->array[2]; \ + GGML_UNUSED(prefix##2); +#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ + const type prefix##3 = (pointer)->array[3]; \ + GGML_UNUSED(prefix##3); + #ifdef __cplusplus extern "C" { #endif diff --git a/src/ggml.c b/src/ggml.c index afeb72ff0..7eaff45cc 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -220,9 +220,27 @@ inline static void* ggml_aligned_malloc(size_t size) { #define GGML_ALIGNED_FREE(ptr) free(ptr) #endif -#define UNUSED(x) (void)(x) +#define UNUSED GGML_UNUSED #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) +// +// tensor access macros +// + +#define GGML_TENSOR_UNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + +#define GGML_TENSOR_BINARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + #if defined(GGML_USE_ACCELERATE) #include #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions @@ -7603,25 +7621,7 @@ static void ggml_compute_forward_dup_f16( return; } - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; const int ith = params->ith; // 
thread index const int nth = params->nth; // number of threads @@ -7892,25 +7892,7 @@ static void ggml_compute_forward_dup_f32( return; } - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8208,24 +8190,8 @@ static void ggml_compute_forward_add_f32( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -8294,28 +8260,12 @@ static void ggml_compute_forward_add_f16_f32( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); @@ -8364,24 +8314,8 @@ static void ggml_compute_forward_add_f16_f16( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); @@ -8431,25 +8365,8 @@ static void ggml_compute_forward_add_q_f32( } const int nr = ggml_nrows(src0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - //const int64_t ne03 = src0->ne[3]; - - const size_t nb00 = 
src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -8570,19 +8487,8 @@ static void ggml_compute_forward_add1_f32( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -8636,23 +8542,12 @@ static void ggml_compute_forward_add1_f16_f32( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); @@ -8697,23 +8592,12 @@ static void ggml_compute_forward_add1_f16_f16( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); @@ -8758,19 +8642,8 @@ static void ggml_compute_forward_add1_q_f32( const int nth = params->nth; const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; const enum ggml_type type = src0->type; dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; @@ -8902,15 +8775,8 @@ static void ggml_compute_forward_acc_f32( const int nr = ggml_nrows(src1); const int nc = src1->ne[0]; - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const 
int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); // src0 and dst as viewed during acc const size_t nb0 = ggml_element_size(src0); @@ -8999,24 +8865,8 @@ static void ggml_compute_forward_sub_f32( } const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9106,29 +8956,7 @@ static void ggml_compute_forward_mul_f32( const int64_t nr = ggml_nrows(src0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9216,24 +9044,8 @@ static void ggml_compute_forward_div_f32( } const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9440,14 +9252,8 @@ static void ggml_compute_forward_sum_f32( assert(ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); ggml_float sum = 0; ggml_float row_sum = 0; @@ -9496,29 +9302,13 @@ static void ggml_compute_forward_sum_rows_f32( GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(dst->nb[0] == sizeof(float)); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const int64_t ne0 = 
dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; + GGML_TENSOR_UNARY_OP_LOCALS; GGML_ASSERT(ne0 == 1); GGML_ASSERT(ne1 == ne01); GGML_ASSERT(ne2 == ne02); GGML_ASSERT(ne3 == ne03); - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - for (int64_t i3 = 0; i3 < ne03; i3++) { for (int64_t i2 = 0; i2 < ne02; i2++) { for (int64_t i1 = 0; i1 < ne01; i1++) { @@ -9562,19 +9352,7 @@ static void ggml_compute_forward_mean_f32( assert(src0->nb[0] == sizeof(float)); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; + GGML_TENSOR_UNARY_OP_LOCALS; assert(ne0 == 1); assert(ne1 == ne01); @@ -9586,10 +9364,6 @@ static void ggml_compute_forward_mean_f32( UNUSED(ne2); UNUSED(ne3); - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { @@ -9632,25 +9406,7 @@ static void ggml_compute_forward_repeat_f32( return; } - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne0/ne00); @@ -9711,25 +9467,7 @@ static void ggml_compute_forward_repeat_back_f32( return; } - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne00/ne0); @@ -10260,18 +9998,7 @@ static void ggml_compute_forward_norm_f32( const int ith = params->ith; const int nth = params->nth; - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; const float eps = 1e-5f; // TODO: make this a 
parameter @@ -10337,18 +10064,7 @@ static void ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; const float eps = 1e-6f; // TODO: make this a parameter @@ -10413,22 +10129,7 @@ static void ggml_compute_forward_rms_norm_back_f32( const int ith = params->ith; const int nth = params->nth; - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const float eps = 1e-6f; // TODO: make this a parameter @@ -10624,41 +10325,7 @@ static void ggml_compute_forward_mul_mat_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - const int64_t ne10 = src1->ne[0]; -#endif - const int64_t ne11 = src1->ne[1]; -#ifndef NDEBUG - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int nb00 = src0->nb[0]; -#endif - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - -#ifndef NDEBUG - const int nb10 = src1->nb[0]; -#endif - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -10795,37 +10462,10 @@ static void ggml_compute_forward_mul_mat_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; + GGML_TENSOR_BINARY_OP_LOCALS; - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; //const int64_t ne = ne0*ne1*ne2*ne3; - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - const int ith = params->ith; const int nth = params->nth; @@ -10995,35 +10635,7 @@ static void ggml_compute_forward_mul_mat_q_f32( int64_t t0 = 
ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -11039,7 +10651,7 @@ static void ggml_compute_forward_mul_mat_q_f32( enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type; // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); + GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]); GGML_ASSERT(nb10 == sizeof(float)); // dst cannot be transposed or permuted @@ -11233,35 +10845,7 @@ static void ggml_compute_forward_out_prod_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const int64_t ne10 = src1->ne[0]; - //const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -11496,15 +11080,8 @@ static void ggml_compute_forward_set_f32( const int nr = ggml_nrows(src1); const int nc = src1->ne[0]; - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - const size_t nb12 = src1->nb[2]; - const size_t nb13 = src1->nb[3]; + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); // src0 and dst as viewed during set const size_t nb0 = ggml_element_size(src0); @@ -11895,29 +11472,14 @@ static void ggml_compute_forward_diag_f32( // TODO: handle transposed/permuted matrices - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; + GGML_TENSOR_UNARY_OP_LOCALS; + GGML_ASSERT(ne00 == ne0); GGML_ASSERT(ne00 == ne1); GGML_ASSERT(ne01 == 1); GGML_ASSERT(ne02 == ne2); GGML_ASSERT(ne03 == ne3); - const int nb00 = src0->nb[0]; - //const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - const 
int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float)); @@ -12494,20 +12056,7 @@ static void ggml_compute_forward_rope_f32( assert(n_past >= 0); - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12634,20 +12183,7 @@ static void ggml_compute_forward_rope_f16( assert(n_past >= 0); - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; + GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12800,21 +12336,7 @@ static void ggml_compute_forward_rope_back_f32( assert(n_past >= 0); - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - + GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12913,21 +12435,7 @@ static void ggml_compute_forward_rope_back_f16( assert(n_past >= 0); - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - + GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -13039,36 +12547,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - //const int64_t ne03 = src0->ne[3]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - //const int64_t ne12 = src1->ne[2]; - //const int64_t ne13 = src1->ne[3]; - - //const int64_t ne0 = dst->ne[0]; - //const int64_t ne1 = dst->ne[1]; - //const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - //const int64_t ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - 
const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -13159,36 +12638,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - //const int64_t ne03 = src0->ne[3]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - //const int64_t ne12 = src1->ne[2]; - //const int64_t ne13 = src1->ne[3]; - - //const int64_t ne0 = dst->ne[0]; - //const int64_t ne1 = dst->ne[1]; - //const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - //const int64_t ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -13302,36 +12752,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - //const int64_t ne03 = src0->ne[3]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - //const int64_t ne12 = src1->ne[2]; - //const int64_t ne13 = src1->ne[3]; - - //const int64_t ne0 = dst->ne[0]; - //const int64_t ne1 = dst->ne[1]; - //const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - //const int64_t ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -13422,36 +12843,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - //const int64_t ne03 = src0->ne[3]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - //const int64_t ne12 = src1->ne[2]; - //const int64_t ne13 = src1->ne[3]; - - //const int64_t ne0 = dst->ne[0]; - //const int64_t ne1 = dst->ne[1]; - //const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - //const int64_t ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; + 
GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -13565,36 +12957,7 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - //const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - //const int nb01 = src0->nb[1]; - //const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - //const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - //const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -13699,45 +13062,14 @@ static void ggml_compute_forward_flash_attn_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t neq0 = q->ne[0]; - const int64_t neq1 = q->ne[1]; - const int64_t neq2 = q->ne[2]; - const int64_t neq3 = q->ne[3]; - - const int64_t nek0 = k->ne[0]; - const int64_t nek1 = k->ne[1]; - //const int64_t nek2 = k->ne[2]; - //const int64_t nek3 = k->ne[3]; - - //const int64_t nev0 = v->ne[0]; - const int64_t nev1 = v->ne[1]; - //const int64_t nev2 = v->ne[2]; - //const int64_t nev3 = v->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - //const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; - const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); const int ith = params->ith; const int nth = params->nth; @@ -13908,45 +13240,14 @@ static void ggml_compute_forward_flash_attn_f16( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t neq0 = q->ne[0]; - const int64_t neq1 = q->ne[1]; - const int64_t neq2 = q->ne[2]; - const int64_t neq3 = q->ne[3]; - - const int64_t nek0 = k->ne[0]; - const int64_t nek1 = k->ne[1]; - //const int64_t nek2 = k->ne[2]; - //const int64_t nek3 = k->ne[3]; - - //const int64_t nev0 = v->ne[0]; - const int64_t nev1 = v->ne[1]; - //const int64_t nev2 = v->ne[2]; - //const int64_t nev3 = v->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - //const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; 
- const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); const int ith = params->ith; const int nth = params->nth; @@ -14180,65 +13481,18 @@ static void ggml_compute_forward_flash_ff_f16( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t nea0 = a->ne[0]; - const int64_t nea1 = a->ne[1]; - const int64_t nea2 = a->ne[2]; - const int64_t nea3 = a->ne[3]; - - const int64_t neb00 = b0->ne[0]; - const int64_t neb01 = b0->ne[1]; - //const int64_t neb02 = b0->ne[2]; - //const int64_t neb03 = b0->ne[3]; - - const int64_t neb10 = b1->ne[0]; - const int64_t neb11 = b1->ne[1]; - //const int64_t neb12 = b1->ne[2]; - //const int64_t neb13 = b1->ne[3]; - - const int64_t nec00 = c0->ne[0]; - const int64_t nec01 = c0->ne[1]; - //const int64_t nec02 = c0->ne[2]; - //const int64_t nec03 = c0->ne[3]; - - const int64_t nec10 = c1->ne[0]; - const int64_t nec11 = c1->ne[1]; - //const int64_t nec12 = c1->ne[2]; - //const int64_t nec13 = c1->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - //const int64_t ne3 = dst->ne[3]; - - const int nba0 = a->nb[0]; - const int nba1 = a->nb[1]; - const int nba2 = a->nb[2]; - const int nba3 = a->nb[3]; - - const int nbb00 = b0->nb[0]; - const int nbb01 = b0->nb[1]; - const int nbb02 = b0->nb[2]; - const int nbb03 = b0->nb[3]; - - const int nbb10 = b1->nb[0]; - //const int nbb11 = b1->nb[1]; - //const int nbb12 = b1->nb[2]; - //const int nbb13 = b1->nb[3]; - - const int nbc00 = c0->nb[0]; - const int nbc01 = c0->nb[1]; - const int nbc02 = c0->nb[2]; - const int nbc03 = c0->nb[3]; - - const int nbc10 = c1->nb[0]; - //const int nbc11 = c1->nb[1]; - //const int nbc12 = c1->nb[2]; - //const int nbc13 = c1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_LOCALS(int64_t, nea, a, ne); + GGML_TENSOR_LOCALS(size_t, nba, a, nb); + GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne); + GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb); + GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne); + GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb); + GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne); + GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb); + GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne); + GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); const int ith = params->ith; const int nth = params->nth; @@ -14386,55 +13640,16 @@ static void ggml_compute_forward_flash_attn_back_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - const int64_t neq0 = q->ne[0]; - const int64_t neq1 = q->ne[1]; - const int64_t neq2 = q->ne[2]; - const int64_t neq3 = q->ne[3]; - - const int64_t nek0 = k->ne[0]; - const int64_t nek1 = k->ne[1]; - //const int64_t nek2 = k->ne[2]; - //const int64_t nek3 = k->ne[3]; - - const int64_t nev0 = v->ne[0]; - const int64_t nev1 = v->ne[1]; - //const int64_t nev2 = v->ne[2]; - //const int64_t nev3 = v->ne[3]; - - const int64_t ned0 = d->ne[0]; 
- const int64_t ned1 = d->ne[1]; - //const int64_t ned2 = d->ne[2]; - //const int64_t ned3 = d->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; - const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nbd0 = d->nb[0]; - const int nbd1 = d->nb[1]; - const int nbd2 = d->nb[2]; - const int nbd3 = d->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + GGML_TENSOR_LOCALS(int64_t, ned, d, ne); + GGML_TENSOR_LOCALS(size_t, nbd, d, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); const int ith = params->ith; const int nth = params->nth; @@ -14792,15 +14007,8 @@ static void ggml_compute_forward_win_part_f32( return; } - const int64_t ne00 = src0->ne[0]; UNUSED(ne00); - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; UNUSED(ne03); - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; UNUSED(ne3); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); const int32_t nep0 = ((const int32_t *)(opt0->data))[0]; const int32_t nep1 = ((const int32_t *)(opt0->data))[1]; @@ -14863,14 +14071,8 @@ static void ggml_compute_forward_win_unpart_f32( return; } - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - //const int64_t ne03 = src0->ne[3]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); const int32_t w = ((const int32_t *)(opt0->data))[0]; From bba0dafe5ea79cc8c5ab60de038b72bbb461c3d9 Mon Sep 17 00:00:00 2001 From: PAB Date: Sun, 2 Jul 2023 17:25:37 +0200 Subject: [PATCH 08/15] ggml : add `ELU`, `TANH`, `ARGMAX` (#316) * add: `elu` activation * add: `tanh` activation * add: `argmax` * ggml : rearrange ops - put "tanh" after "step" --------- Co-authored-by: Georgi Gerganov --- include/ggml/ggml.h | 24 ++++ src/ggml.c | 284 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 298 insertions(+), 10 deletions(-) diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index fb3fec29c..05d7f0a19 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -321,17 +321,20 @@ extern "C" { GGML_OP_SUM, GGML_OP_SUM_ROWS, GGML_OP_MEAN, + GGML_OP_ARGMAX, GGML_OP_REPEAT, GGML_OP_REPEAT_BACK, GGML_OP_ABS, GGML_OP_SGN, GGML_OP_NEG, GGML_OP_STEP, + GGML_OP_ELU, GGML_OP_RELU, GGML_OP_GELU, GGML_OP_GELU_QUICK, GGML_OP_SILU, GGML_OP_SILU_BACK, + GGML_OP_TANH, GGML_OP_NORM, // normalize GGML_OP_RMS_NORM, GGML_OP_RMS_NORM_BACK, @@ -716,6 +719,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // argmax along rows + GGML_API struct ggml_tensor * 
ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a); + // if a is the same shape as b, and a is not parameter, return a // otherwise, return a new tensor: repeat(a) to fit in b GGML_API struct ggml_tensor * ggml_repeat( @@ -760,6 +768,22 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_relu( struct ggml_context * ctx, struct ggml_tensor * a); diff --git a/src/ggml.c b/src/ggml.c index 7eaff45cc..a1b2b1b5f 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3465,6 +3465,8 @@ inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 
x[i] : 0.f; } static const float GELU_COEF_A = 0.044715f; @@ -3616,6 +3618,16 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x *s = 1.f/(*s); } +inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { + float max = -INFINITY; + int idx = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + if (max == x[i]) { idx = i; } + } + *s = idx; +} + // // data types // @@ -3725,12 +3737,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SUM", "SUM_ROWS", "MEAN", + "ARGMAX", "REPEAT", "REPEAT_BACK", "ABS", "SGN", "NEG", "STEP", + "TANH", + "ELU", "RELU", "GELU", "GELU_QUICK", @@ -3783,7 +3798,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64"); +static_assert(GGML_OP_COUNT == 67, "GGML_OP_COUNT != 67"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3801,12 +3816,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "Σx", "Σx_k", "Σx/n", + "argmax(x)", "repeat(x)", "repeat_back(x)", "abs(x)", "sgn(x)", "-x", "step(x)", + "tanh(x)", + "elu(x)", "relu(x)", "gelu(x)", "gelu_quick(x)", @@ -3859,7 +3877,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64"); +static_assert(GGML_OP_COUNT == 67, "GGML_OP_COUNT != 67"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -5458,6 +5476,30 @@ struct ggml_tensor * ggml_mean( return result; } +// ggml_argmax + +struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_matrix(a)); + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); + is_node = true; + } + + int64_t ne[GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, ne); + + result->op = GGML_OP_ARGMAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + // ggml_repeat struct ggml_tensor * ggml_repeat( @@ -5651,6 +5693,74 @@ struct ggml_tensor * ggml_step_inplace( return ggml_step_impl(ctx, a, true); } +// ggml_tanh + +struct ggml_tensor * ggml_tanh_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_TANH; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_tanh_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_tanh_impl(ctx, a, true); +} + +// ggml_elu + +struct ggml_tensor * ggml_elu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ELU; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_elu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_elu_impl(ctx, a, true); +} + // ggml_relu struct ggml_tensor * ggml_relu_impl( @@ -9393,6 +9503,52 @@ static void ggml_compute_forward_mean( } } +// ggml_compute_forward_argmax + +static void ggml_compute_forward_argmax_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; + + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; + } +} + +static void ggml_compute_forward_argmax( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argmax_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_repeat static void ggml_compute_forward_repeat_f32( @@ -9697,6 +9853,90 @@ static void ggml_compute_forward_step( } } +// ggml_compute_forward_tanh + +static void ggml_compute_forward_tanh_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_tanh_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_tanh( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tanh_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_elu + +static void ggml_compute_forward_elu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_elu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_elu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch 
(src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_elu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_relu static void ggml_compute_forward_relu_f32( @@ -14670,6 +14910,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mean(params, tensor->src0, tensor); } break; + case GGML_OP_ARGMAX: + { + ggml_compute_forward_argmax(params, tensor->src0, tensor); + } break; case GGML_OP_REPEAT: { ggml_compute_forward_repeat(params, tensor->src0, tensor); @@ -14694,6 +14938,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_step(params, tensor->src0, tensor); } break; + case GGML_OP_TANH: + { + ggml_compute_forward_tanh(params, tensor->src0, tensor); + } break; + case GGML_OP_ELU: + { + ggml_compute_forward_elu(params, tensor->src0, tensor); + } break; case GGML_OP_RELU: { ggml_compute_forward_relu(params, tensor->src0, tensor); @@ -15069,6 +15321,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } } break; case GGML_OP_MEAN: + case GGML_OP_ARGMAX: { GGML_ASSERT(false); // TODO: implement } break; @@ -15122,6 +15375,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // noop } } break; + case GGML_OP_TANH: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_ELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_RELU: { if (src0->grad) { @@ -15141,14 +15402,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CLAMP: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_SILU: { // necessary for llama @@ -15505,6 +15758,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // noop } } break; + case GGML_OP_ALIBI: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CLAMP: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_CONV_1D_S1_PH: { GGML_ASSERT(false); // TODO: not implemented @@ -16170,12 +16431,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: + case GGML_OP_ARGMAX: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_ABS: case GGML_OP_SGN: case GGML_OP_NEG: case GGML_OP_STEP: + case GGML_OP_TANH: + case GGML_OP_ELU: case GGML_OP_RELU: { node->n_tasks = 1; From daf63b0d65df42a79a280395d0736933f7db594e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 2 Jul 2023 18:26:26 +0300 Subject: [PATCH 09/15] ggml : fix enum order for TANH (#316) --- include/ggml/ggml.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 05d7f0a19..0b3c31817 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -328,13 +328,13 @@ extern "C" { GGML_OP_SGN, GGML_OP_NEG, GGML_OP_STEP, + GGML_OP_TANH, GGML_OP_ELU, GGML_OP_RELU, GGML_OP_GELU, GGML_OP_GELU_QUICK, GGML_OP_SILU, GGML_OP_SILU_BACK, - GGML_OP_TANH, GGML_OP_NORM, // normalize GGML_OP_RMS_NORM, GGML_OP_RMS_NORM_BACK, From 50462bbb72412ef25e0a34e878e178a321e54d76 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 2 Jul 2023 18:33:41 +0300 Subject: [PATCH 10/15] ggml : remove tensor ptr from 
export for now (close #267) Not used for now --- src/ggml.c | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/ggml.c b/src/ggml.c index a1b2b1b5f..60b195b57 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -16945,13 +16945,6 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { fwrite(&nb, sizeof(uint64_t), 1, fout); } - // store the pointer address - { - const uint64_t ptr = (uint64_t) tensor->data; - - fwrite(&ptr, sizeof(uint64_t), 1, fout); - } - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); // dump the data @@ -16985,13 +16978,6 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { fwrite(&nb, sizeof(uint64_t), 1, fout); } - // store the pointer address - { - const uint64_t ptr = (uint64_t) tensor->data; - - fwrite(&ptr, sizeof(uint64_t), 1, fout); - } - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); // output the op arguments @@ -17176,8 +17162,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** tensor->op = (enum ggml_op) op; - uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); - memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; tensor->data = (void *) ptr; @@ -17223,8 +17207,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** nb[j] = nb_cur; } - uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used - const char * ptr_name = ptr; ptr += GGML_MAX_NAME; const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t); From 97f72d8459923409679e61b5eb1cc976bdaed17c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 2 Jul 2023 18:53:42 +0300 Subject: [PATCH 11/15] ggml : disable ggml_rope_back for ChatGLM --- src/ggml.c | 4 +++- tests/test-grad0.c | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ggml.c b/src/ggml.c index 60b195b57..aed3fcf50 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -7002,6 +7002,8 @@ struct ggml_tensor * ggml_rope_back( int n_dims, int mode) { GGML_ASSERT(n_past >= 0); + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); + bool is_node = false; if (a->grad) { @@ -15718,7 +15720,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); + assert(ggml_nelements(src1) == 4); const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; diff --git a/tests/test-grad0.c b/tests/test-grad0.c index ce5dc32bf..a3e25214b 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -1139,7 +1139,7 @@ int main(int argc, const char ** argv) { int n_rot = ne2[0]; for (int ndims = 3; ndims <= 4; ++ndims) { - for (int mode = 0; mode < 8; ++mode) { + for (int mode = 0; mode < 4; ++mode) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); From 6e7a1ca48a8e7d266a102d7316f15243a6f55b53 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Sun, 2 Jul 2023 18:54:16 +0300 Subject: [PATCH 12/15] ggml : generalize interface for 1d and 2d convolutions (#313) * conv_1d wip * conv_1d opt * conv_1d done * conv_1 improve alias func name * conv_2d wip * conv size to separate func * conv2d done --------- Co-authored-by: Georgi Gerganov --- examples/whisper/whisper.cpp | 4 +- include/ggml/ggml.h | 68 
++++-------- src/ggml.c | 202 ++++++++++++++++++++++------------- 3 files changed, 153 insertions(+), 121 deletions(-) diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 74cfd7b24..813a0272d 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -1472,7 +1472,7 @@ static bool whisper_encode_internal( { wstate.use_buf(ctx0, 1); - cur = ggml_conv_1d_s1_ph(ctx0, model.e_conv_1_w, mel); + cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, model.e_conv_1_b, @@ -1483,7 +1483,7 @@ static bool whisper_encode_internal( wstate.use_buf(ctx0, 0); - cur = ggml_conv_1d_s2_ph(ctx0, model.e_conv_2_w, cur); + cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, model.e_conv_2_b, diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 0b3c31817..0af96c76b 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -361,9 +361,8 @@ extern "C" { GGML_OP_ROPE_BACK, GGML_OP_ALIBI, GGML_OP_CLAMP, - GGML_OP_CONV_1D_S1_PH, - GGML_OP_CONV_1D_S2_PH, - GGML_OP_CONV_2D_SK_P0, + GGML_OP_CONV_1D, + GGML_OP_CONV_2D, GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, @@ -1134,58 +1133,33 @@ extern "C" { float min, float max); - // TODO: implement general-purpose convolutions - // GGML_API struct ggml_tensor * ggml_conv_1d( - // struct ggml_context * ctx, - // struct ggml_tensor * a, - // struct ggml_tensor * b, - // int s0 - // int p0, - // int d0); - // - // GGML_API struct ggml_tensor * ggml_conv_2d( - // struct ggml_context * ctx, - // struct ggml_tensor * a, - // struct ggml_tensor * b, - // int s0, - // int s1, - // int p0, - // int p1, - // int d0, - // int d1); - - // padding = half - // TODO: we don't support extra parameters for now - // that's why we are hard-coding the stride, padding, and dilation - // not great .. 
- // example: - // a: 3 80 768 1 - // b: 3000 80 1 1 - // res: 3000 768 1 1 - // used in whisper - GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph( + GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + int s0, // stride + int p0, // padding + int d0); // dilation - // used in whisper - GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph( + GGML_API struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1); - // kernel size is a->ne[0] x a->ne[1] - // stride is equal to kernel size - // padding is zero - // example: - // a: 16 16 3 768 - // b: 1024 1024 3 1 - // res: 64 64 768 1 - // used in sam - GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( + // conv_1d with padding = half + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) + GGML_API struct ggml_tensor* ggml_conv_1d_ph( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + int s, + int d); GGML_API struct ggml_tensor * ggml_flash_attn( struct ggml_context * ctx, diff --git a/src/ggml.c b/src/ggml.c index aed3fcf50..88cbed7d5 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3777,9 +3777,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ROPE_BACK", "ALIBI", "CLAMP", - "CONV_1D_S1_PH", - "CONV_1D_S2_PH", - "CONV_2D_SK_P0", + "CONV_1D", + "CONV_2D", "FLASH_ATTN", "FLASH_FF", @@ -3798,7 +3797,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 67, "GGML_OP_COUNT != 67"); +static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3856,9 +3855,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope_back(x)", "alibi(x)", "clamp(x)", - "conv_1d_s1_ph(x)", - "conv_1d_s2_ph(x)", - "conv_2d_sk_p0(x)", + "conv_1d(x)", + "conv_2d(x)", "flash_attn(x)", "flash_ff(x)", @@ -3877,7 +3875,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 67, "GGML_OP_COUNT != 67"); +static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -3903,9 +3901,8 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_GET_ROWS_BACK ] = true; p[GGML_OP_DIAG_MASK_INF ] = true; p[GGML_OP_DIAG_MASK_ZERO ] = true; - p[GGML_OP_CONV_1D_S1_PH ] = true; - p[GGML_OP_CONV_1D_S2_PH ] = true; - p[GGML_OP_CONV_2D_SK_P0 ] = true; + p[GGML_OP_CONV_1D ] = true; + p[GGML_OP_CONV_2D ] = true; p[GGML_OP_FLASH_ATTN_BACK ] = true; p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; } @@ -7104,15 +7101,21 @@ struct ggml_tensor * ggml_clamp( return result; } -// ggml_conv_1d_s1_ph +// ggml_conv_1d -struct ggml_tensor * ggml_conv_1d_s1_ph( +static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * b, + int s0, + int p0, + int d0) { GGML_ASSERT(ggml_is_matrix(b)); GGML_ASSERT(a->ne[1] == b->ne[1]); - GGML_ASSERT(a->ne[3] == 1); bool 
is_node = false; if (a->grad || b->grad) { @@ -7120,26 +7123,43 @@ struct ggml_tensor * ggml_conv_1d_s1_ph( is_node = true; } - const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + const int64_t ne[4] = { + ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + a->ne[2], 1, 1, + }; + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + ggml_scratch_save(ctx); + struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((int32_t*)c->data)[0] = s0; + ((int32_t*)c->data)[1] = p0; + ((int32_t*)c->data)[2] = d0; + ggml_scratch_load(ctx); - result->op = GGML_OP_CONV_1D_S1_PH; + result->op = GGML_OP_CONV_1D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; + result->opt[0] = c; return result; } -// ggml_conv_1d_s2_ph +// ggml_conv_2d -struct ggml_tensor * ggml_conv_1d_s2_ph( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_is_matrix(b)); - GGML_ASSERT(a->ne[1] == b->ne[1]); - GGML_ASSERT(a->ne[3] == 1); +struct ggml_tensor* ggml_conv_2d( + struct ggml_context* ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + + GGML_ASSERT(b->ne[3] == 1); + GGML_ASSERT(a->ne[2] == b->ne[2]); bool is_node = false; if (a->grad || b->grad) { @@ -7147,43 +7167,42 @@ struct ggml_tensor * ggml_conv_1d_s2_ph( is_node = true; } - const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + const int64_t ne[4] = { + ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), + a->ne[3], 1, + }; + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_scratch_save(ctx); + struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6); + ((int32_t*)c->data)[0] = s0; + ((int32_t*)c->data)[1] = s1; + ((int32_t*)c->data)[2] = p0; + ((int32_t*)c->data)[3] = p1; + ((int32_t*)c->data)[4] = d0; + ((int32_t*)c->data)[5] = d1; + ggml_scratch_load(ctx); - result->op = GGML_OP_CONV_1D_S2_PH; + result->op = GGML_OP_CONV_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; + result->opt[0] = c; return result; + } -// ggml_conv_2d_sk_p0 +// ggml_conv_1d_ph -struct ggml_tensor * ggml_conv_2d_sk_p0( +struct ggml_tensor* ggml_conv_1d_ph( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(b->ne[3] == 1); - GGML_ASSERT(a->ne[2] == b->ne[2]); - GGML_ASSERT(b->ne[0] % a->ne[0] == 0); - GGML_ASSERT(b->ne[1] % a->ne[1] == 0); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - result->op = GGML_OP_CONV_2D_SK_P0; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; + struct ggml_tensor * b, + int s, + int d) { + return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } // ggml_flash_attn @@ -12775,7 +12794,7 @@ static void ggml_compute_forward_rope_back( } } -// ggml_compute_forward_conv_1d_s1_ph +// ggml_compute_forward_conv_1d static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( const struct ggml_compute_params * params, @@ -12980,8 +12999,6 @@ static void ggml_compute_forward_conv_1d_s1_ph( } } -// ggml_compute_forward_conv_1d_s2_ph - static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -13185,6 +13202,28 @@ static void ggml_compute_forward_conv_1d_s2_ph( } } +// ggml_compute_forward_conv_1d + +static void ggml_compute_forward_conv_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + const int32_t s0 = ((const int32_t*)(opt0->data))[0]; + const int32_t p0 = ((const int32_t*)(opt0->data))[1]; + const int32_t d0 = ((const int32_t*)(opt0->data))[2]; + GGML_ASSERT(d0 == 1); // dilation not supported + GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported + if (s0 == 1) { + ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst); + } else if (s0 == 2) { + ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); + } else { + GGML_ASSERT(false); // only stride 1 and 2 supported + }; +} + // ggml_compute_forward_conv_2d_sk_p0 static void ggml_compute_forward_conv_2d_sk_p0_f16_f32( @@ -13292,6 +13331,34 @@ static void ggml_compute_forward_conv_2d_sk_p0( } } +// ggml_compute_forward_conv_2d + +static void ggml_compute_forward_conv_2d( + const struct ggml_compute_params* params, + const struct ggml_tensor* src0, + const struct ggml_tensor* src1, + const struct ggml_tensor* opt0, + struct ggml_tensor* dst) { + const int32_t s0 = ((const int32_t*)(opt0->data))[0]; + const int32_t s1 = ((const int32_t*)(opt0->data))[1]; + const int32_t p0 = ((const int32_t*)(opt0->data))[2]; + const int32_t p1 = ((const int32_t*)(opt0->data))[3]; + const int32_t d0 = ((const int32_t*)(opt0->data))[4]; + const int32_t d1 = ((const int32_t*)(opt0->data))[5]; + GGML_ASSERT(d0 == 1); // dilation not supported + GGML_ASSERT(d1 == 1); + GGML_ASSERT(p0 == 0); // padding not supported + GGML_ASSERT(p1 == 0); + + if (s0 == src0->ne[0] && s1 == src0->ne[1]) { + ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); // only stride equal to kernel size is supported + }; +} + + // ggml_compute_forward_flash_attn static void ggml_compute_forward_flash_attn_f32( @@ -15064,17 +15131,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor); } break; - case GGML_OP_CONV_1D_S1_PH: - { - ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_CONV_1D_S2_PH: + case GGML_OP_CONV_1D: { - ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor); + ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); } break; - case GGML_OP_CONV_2D_SK_P0: + case GGML_OP_CONV_2D: { - ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor); + ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], 
tensor); } break; case GGML_OP_FLASH_ATTN: { @@ -15768,15 +15831,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_CONV_1D_S1_PH: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D_S2_PH: + case GGML_OP_CONV_1D: { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_CONV_2D_SK_P0: + case GGML_OP_CONV_2D: { GGML_ASSERT(false); // TODO: not implemented } break; @@ -16555,8 +16614,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; //TODO } break; - case GGML_OP_CONV_1D_S1_PH: - case GGML_OP_CONV_1D_S2_PH: + case GGML_OP_CONV_1D: { node->n_tasks = n_threads; @@ -16585,7 +16643,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) work_size = MAX(work_size, cur); } break; - case GGML_OP_CONV_2D_SK_P0: + case GGML_OP_CONV_2D: { node->n_tasks = n_threads; From 21d2959ec945632553cc7ee430776fbf3f0baf44 Mon Sep 17 00:00:00 2001 From: Hugo Rosenkranz-Costa Date: Sun, 2 Jul 2023 18:05:24 +0200 Subject: [PATCH 13/15] mpt : convert model weights part by part to save memory (#314) * mpt : update conversion script to load model weights part by part * mpt : add usage README --- examples/mpt/README.md | 27 ++++++ examples/mpt/convert-h5-to-ggml.py | 137 ++++++++++++++++------------- 2 files changed, 101 insertions(+), 63 deletions(-) create mode 100644 examples/mpt/README.md mode change 100644 => 100755 examples/mpt/convert-h5-to-ggml.py diff --git a/examples/mpt/README.md b/examples/mpt/README.md new file mode 100644 index 000000000..39f46bae3 --- /dev/null +++ b/examples/mpt/README.md @@ -0,0 +1,27 @@ +# MPT + +Ref: https://github.com/mosaicml/llm-foundry#mpt + +## Usage + +```bash +# get the repo and build it +git clone https://github.com/ggerganov/ggml +cd ggml +mkdir build && cd build +cmake .. +make -j + +# get the model from HuggingFace +# be sure to have git-lfs installed +git clone https://huggingface.co/mosaicml/mpt-30b + +# convert model to FP16 +python3 ../examples/mpt/convert-h5-to-ggml.py ./mpt-30b 1 + +# run inference using FP16 precision +./bin/mpt -m ./mpt-30b/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64 + +# quantize the model to 5-bits using Q5_0 quantization +./bin/mpt-quantize ./mpt-30b/ggml-model-f16.bin ./mpt-30b/ggml-model-q5_0.bin q5_0 +``` diff --git a/examples/mpt/convert-h5-to-ggml.py b/examples/mpt/convert-h5-to-ggml.py old mode 100644 new mode 100755 index 0765011cc..ccd6459fe --- a/examples/mpt/convert-h5-to-ggml.py +++ b/examples/mpt/convert-h5-to-ggml.py @@ -1,13 +1,13 @@ -import sys +import os import struct -import json -import numpy as np -from transformers import AutoModelForCausalLM, AutoTokenizer -import sentencepiece.sentencepiece_model_pb2 as model +import sys + +import torch +from transformers import AutoConfig, AutoTokenizer + # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py def bytes_to_unicode(): - """ Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. @@ -17,19 +17,36 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) - cs.append(2**8+n) + cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) + +def count_model_parts(dir_model: str) -> int: + """Returns the number of model parts in the model directory.""" + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print(f"Found {num_parts} model parts in {dir_model}") + return num_parts + + if len(sys.argv) < 3: print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n") print(" ftype == 0 -> float32") @@ -39,11 +56,8 @@ def bytes_to_unicode(): # output in the same directory as the model dir_model = sys.argv[1] -fname_out = sys.argv[1] + "/ggml-model.bin" - - -with open(dir_model + "/config.json", "r", encoding="utf-8") as f: - hparams = json.load(f) +# get number of model parts +num_parts = count_model_parts(dir_model) # possible data types # ftype == 0 -> float32 @@ -58,25 +72,15 @@ def bytes_to_unicode(): if ftype < 0 or ftype > 1: print("Invalid ftype: " + str(ftype)) sys.exit(1) - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + fname_out = dir_model + "/ggml-model-" + ftype_str[ftype] + ".bin" tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) -model = AutoModelForCausalLM.from_pretrained( - dir_model, low_cpu_mem_usage=True, trust_remote_code=True -) -# print (model) - -# print(tokenizer.encode('I believe the meaning of life is')) - -list_vars = model.state_dict() -for name in list_vars.keys(): - print(name, list_vars[name].shape, list_vars[name].dtype) +config = AutoConfig.from_pretrained(dir_model, trust_remote_code=True) +hparams = config.to_dict() fout = open(fname_out, "wb") -print(hparams) - fout.write(struct.pack("i", 0x67676D6C)) # magic: ggml in hex fout.write(struct.pack("i", hparams["d_model"])) fout.write(struct.pack("i", hparams["max_seq_len"])) @@ -94,19 +98,19 @@ def bytes_to_unicode(): encoder.update(tokenizer.get_added_vocab()) byte_encoder = bytes_to_unicode() -byte_decoder = {v:k for k, v in byte_encoder.items()} +byte_decoder = {v: k for k, v in byte_encoder.items()} counter = 0 # sort by value for key in sorted(encoder, key=encoder.get): # workaround for key error when c not found - text="" + text = "" for c in key: if c not in byte_decoder: text += c else: - text += chr(byte_decoder[c] ) - text = bytearray( text, encoding="utf-8" ) + text += chr(byte_decoder[c]) + text = bytearray(text, encoding="utf-8") fout.write(struct.pack("i", len(text))) fout.write(text) counter += 1 @@ -117,40 +121,47 @@ def bytes_to_unicode(): fout.write(text) counter += 1 -# assert counter == config.vocab_size - -for name in list_vars.keys(): - data = list_vars[name].squeeze().numpy() - print("Processing variable: " + name + " with shape: ", data.shape) - - n_dims = len(data.shape) - - # ftype == 0 -> float32, ftype == 1 -> float16 - ftype_cur = 0 - if ftype != 0: - if name[-7:] == ".weight" and n_dims == 2: - print(" Converting to float16") - data = data.astype(np.float16) +if num_parts == 0: + part_names = ("pytorch_model.bin",) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + print(f"\n* Loading 
part: {part_name}") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name].squeeze() + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + # default type is fp32 + ftype_cur = 0 + if ftype == 1 and name[-7:] == ".weight" and n_dims > 1: ftype_cur = 1 - else: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 - else: - if data.dtype != np.float32: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 - - # header - str = name.encode("utf-8") - fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) - for i in range(n_dims): - fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) - fout.write(str) - - # data - data.tofile(fout) + data = data.to(dtype=torch.float16 if ftype_cur == 1 else torch.float32).numpy() + + print( + "Processing variable: " + name + " with shape: ", + data.shape, + "->", + data.dtype, + ) + + # header + str = name.encode("utf-8") + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str) + + # data + data.tofile(fout) + + # release memory + del model_part fout.close() From 497d8e586cb0ab226075f4a32d9b2404694627a6 Mon Sep 17 00:00:00 2001 From: sjinzh Date: Mon, 3 Jul 2023 00:36:53 +0800 Subject: [PATCH 14/15] zig : add tests codes using zig (#315) * update build.zig * zig : add tests by zig * zig : add tests codes using zig * zig : add tests codes using zig --- build.zig | 44 +++++++------ tests/test0.zig | 2 - tests/test1.zig | 36 +++++------ tests/test2.zig | 165 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/test3.zig | 102 ++++++++++++++++++++++++++++++ 5 files changed, 308 insertions(+), 41 deletions(-) create mode 100644 tests/test2.zig create mode 100644 tests/test3.zig diff --git a/build.zig b/build.zig index 87151a885..a4f9016e7 100644 --- a/build.zig +++ b/build.zig @@ -1,27 +1,29 @@ const std = @import("std"); -// Zig Version: 0.11.0-dev.3798+a5e15eced +// Zig Version: 0.11.0-dev.3886+0c1bfe271 // Zig Build Command: zig build -// Zig Run Command: zig build -h -// zig build run_dolly-v2 -// zig build run_gpt-2 -// zig build run_gpt-j -// zig build run_gpt-neox -// zig build run_mnist -// zig build run_mpt -// zig build run_replit -// zig build run_starcoder -// zig build run_test-grad0 -// zig build run_test-mul-mat0 -// zig build run_test-mul-mat2 -// zig build run_test-opt -// zig build run_test-vec1 -// zig build run_test0 -// zig build run_test1 -// zig build run_test2 +// Zig Run Command: zig build -h +// zig build run_dolly-v2 +// zig build run_gpt-2 +// zig build run_gpt-j +// zig build run_gpt-neox +// zig build run_mnist +// zig build run_mpt +// zig build run_replit +// zig build run_starcoder +// zig build run_test-grad0 +// zig build run_test-mul-mat0 +// zig build run_test-mul-mat2 +// zig build run_test-opt +// zig build run_test-vec1 +// zig build run_test0 +// zig build run_test1 +// zig build run_test2 // zig build run_test3 // zig build run_zig_test0 -// zig build run_zig_test1 +// zig build run_zig_test1 +// zig build run_zig_test2 +// zig build run_zig_test3 pub fn build(b: *std.build.Builder) void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); @@ -117,6 +119,8 @@ pub fn build(b: *std.build.Builder) void { const zig_tests = .{ "test0", "test1", + "test2", + "test3", }; inline for (zig_tests) |name| { const exe 
= b.addExecutable(.{ @@ -135,4 +139,4 @@ pub fn build(b: *std.build.Builder) void { const run_step = b.step("run_zig_" ++ name, "Run zig_tests"); run_step.dependOn(&run_cmd.step); } -} \ No newline at end of file +} diff --git a/tests/test0.zig b/tests/test0.zig index 8246a978b..e47bf3652 100644 --- a/tests/test0.zig +++ b/tests/test0.zig @@ -1,7 +1,5 @@ const std = @import("std"); const c = @cImport({ - @cInclude("stdio.h"); - @cInclude("stdlib.h"); @cInclude("ggml/ggml.h"); }); diff --git a/tests/test1.zig b/tests/test1.zig index e40c37e81..60fed53d1 100644 --- a/tests/test1.zig +++ b/tests/test1.zig @@ -1,7 +1,5 @@ const std = @import("std"); const c = @cImport({ - @cInclude("stdio.h"); - @cInclude("stdlib.h"); @cInclude("ggml/ggml.h"); }); @@ -114,7 +112,7 @@ pub fn main() !void { c.ggml_graph_dump_dot(&gf, null, "test1-2-forward.dot"); c.ggml_graph_dump_dot(&gb, &gf, "test1-2-backward.dot"); } - + /////////////////////////////////////////////////////////////// { @@ -182,7 +180,7 @@ pub fn main() !void { try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0); try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 24.0); try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 12.0); - try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0); + try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0); const g1 = x1.*.grad; const g2 = x2.*.grad; @@ -197,16 +195,16 @@ pub fn main() !void { c.ggml_graph_compute(ctx0, @constCast(&gbb)); - std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", - .{ - c.ggml_get_f32_1d(x1.*.grad, 0), + std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", + .{ + c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0), c.ggml_get_f32_1d(x3.*.grad, 0), }); try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 56.0); try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 34.0); - try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0); + try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0); c.ggml_graph_dump_dot(&gf, null, "test1-4-forward.dot"); c.ggml_graph_dump_dot(&gb, &gf, "test1-4-backward.dot"); @@ -235,13 +233,13 @@ pub fn main() !void { c.ggml_graph_compute(ctx0, @constCast(&gb)); std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x1.*.grad, 1), c.ggml_get_f32_1d(x1.*.grad, 2), }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x2.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 1), @@ -292,13 +290,13 @@ pub fn main() !void { c.ggml_graph_compute(ctx0, @constCast(&gb)); std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x1.*.grad, 1), c.ggml_get_f32_1d(x1.*.grad, 2), }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x2.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 1), @@ -349,13 +347,13 @@ pub fn main() !void { c.ggml_graph_compute(ctx0, @constCast(&gb)); std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x1.*.grad, 1), 
c.ggml_get_f32_1d(x1.*.grad, 2), }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x2.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 1), @@ -400,13 +398,13 @@ pub fn main() !void { c.ggml_graph_compute(ctx0, @constCast(&gb)); std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x1.*.grad, 1), c.ggml_get_f32_1d(x1.*.grad, 2), }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x2.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 1), @@ -430,13 +428,13 @@ pub fn main() !void { c.ggml_graph_compute(ctx0, @constCast(&gb)); std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x1.*.grad, 1), c.ggml_get_f32_1d(x1.*.grad, 2), }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", + std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ c.ggml_get_f32_1d(x2.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 1), @@ -456,4 +454,4 @@ pub fn main() !void { } _ = try std.io.getStdIn().reader().readByte(); -} \ No newline at end of file +} diff --git a/tests/test2.zig b/tests/test2.zig new file mode 100644 index 000000000..667de967f --- /dev/null +++ b/tests/test2.zig @@ -0,0 +1,165 @@ +const std = @import("std"); +const Thread = std.Thread; +const c = @cImport({ + @cInclude("ggml/ggml.h"); +}); + +fn is_close(a: f32, b: f32, epsilon: f32) bool { + return std.math.fabs(a - b) < epsilon; +} + +pub fn main() !void { + const params = .{ + .mem_size = 128*1024*1024, + .mem_buffer = null, + .no_alloc = false, + }; + + var opt_params = c.ggml_opt_default_params(c.GGML_OPT_LBFGS); + + const nthreads = try Thread.getCpuCount(); + opt_params.n_threads = @intCast(nthreads); + std.debug.print("test2: n_threads:{}\n", .{opt_params.n_threads}); + + const xi = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 }; + const yi = [_]f32{ 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0 }; + + const n = xi.len; + + const ctx0 = c.ggml_init(params); + defer c.ggml_free(ctx0); + + const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n); + const y = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n); + + for (0..n) |i| { + const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data)); + x_data_pointer[i] = xi[i]; + const y_data_pointer: [*]f32 = @ptrCast(@alignCast(y.*.data)); + y_data_pointer[i] = yi[i]; + } + + { + const t0 = c.ggml_new_f32(ctx0, 0.0); + const t1 = c.ggml_new_f32(ctx0, 0.0); + + // initialize auto-diff parameters: + _ = c.ggml_set_param(ctx0, t0); + _ = c.ggml_set_param(ctx0, t1); + + // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n) + const f = + c.ggml_div(ctx0, + c.ggml_sum(ctx0, + c.ggml_sqr(ctx0, + c.ggml_sub(ctx0, + c.ggml_add(ctx0, + c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), + c.ggml_repeat(ctx0, t0, x)), + y) + ) + ), + c.ggml_new_f32(ctx0, @as(f32, 2.0)*n)); + + const res = c.ggml_opt(null, opt_params, f); + + std.debug.print("t0 = {d:.6}\n", .{c.ggml_get_f32_1d(t0, 0)}); + std.debug.print("t1 = {d:.6}\n", .{c.ggml_get_f32_1d(t1, 0)}); + + try std.testing.expect(res == c.GGML_OPT_OK); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-3)); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 
1e-3)); + } + + { + const t0 = c.ggml_new_f32(ctx0, -1.0); + const t1 = c.ggml_new_f32(ctx0, 9.0); + + _ = c.ggml_set_param(ctx0, t0); + _ = c.ggml_set_param(ctx0, t1); + + // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n + const f = + c.ggml_mul(ctx0, + c.ggml_new_f32(ctx0, @as(f32, 1.0)/(2*n)), + c.ggml_sum(ctx0, + c.ggml_abs(ctx0, + c.ggml_sub(ctx0, + c.ggml_add(ctx0, + c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), + c.ggml_repeat(ctx0, t0, x)), + y) + ) + ) + ); + + + const res = c.ggml_opt(null, opt_params, f); + + try std.testing.expect(res == c.GGML_OPT_OK); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-2)); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-2)); + } + + { + const t0 = c.ggml_new_f32(ctx0, 5.0); + const t1 = c.ggml_new_f32(ctx0, -4.0); + + _ = c.ggml_set_param(ctx0, t0); + _ = c.ggml_set_param(ctx0, t1); + + // f = t0^2 + t1^2 + const f = + c.ggml_add(ctx0, + c.ggml_sqr(ctx0, t0), + c.ggml_sqr(ctx0, t1) + ); + + const res = c.ggml_opt(null, opt_params, f); + + try std.testing.expect(res == c.GGML_OPT_OK); + try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3)); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 0.0, 1e-3)); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 0.0, 1e-3)); + } + + ///////////////////////////////////////// + + { + const t0 = c.ggml_new_f32(ctx0, -7.0); + const t1 = c.ggml_new_f32(ctx0, 8.0); + + _ = c.ggml_set_param(ctx0, t0); + _ = c.ggml_set_param(ctx0, t1); + + // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2 + const f = + c.ggml_add(ctx0, + c.ggml_sqr(ctx0, + c.ggml_sub(ctx0, + c.ggml_add(ctx0, + t0, + c.ggml_mul(ctx0, t1, c.ggml_new_f32(ctx0, 2.0))), + c.ggml_new_f32(ctx0, 7.0) + ) + ), + c.ggml_sqr(ctx0, + c.ggml_sub(ctx0, + c.ggml_add(ctx0, + c.ggml_mul(ctx0, t0, c.ggml_new_f32(ctx0, 2.0)), + t1), + c.ggml_new_f32(ctx0, 5.0) + ) + ) + ); + + const res = c.ggml_opt(null, opt_params, f); + + try std.testing.expect(res == c.GGML_OPT_OK); + try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3)); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 1.0, 1e-3)); + try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 3.0, 1e-3)); + } + + _ = try std.io.getStdIn().reader().readByte(); +} diff --git a/tests/test3.zig b/tests/test3.zig new file mode 100644 index 000000000..d676961fc --- /dev/null +++ b/tests/test3.zig @@ -0,0 +1,102 @@ +const std = @import("std"); +const Thread = std.Thread; +const c = @cImport({ + @cInclude("stdlib.h"); + @cInclude("ggml/ggml.h"); +}); + +fn is_close(a: f32, b: f32, epsilon: f32) bool { + return std.math.fabs(a - b) < epsilon; +} + +pub fn main() !void { + const params = .{ + .mem_size = 128*1024*1024, + .mem_buffer = null, + .no_alloc = false, + }; + + var opt_params = c.ggml_opt_default_params(c.GGML_OPT_LBFGS); + + const nthreads = try Thread.getCpuCount(); + opt_params.n_threads = @intCast(nthreads); + + const NP = 1 << 12; + const NF = 1 << 8; + + const ctx0 = c.ggml_init(params); + defer c.ggml_free(ctx0); + + const F = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_F32, NF, NP); + const l = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NP); + + // regularization weight + const lambda = c.ggml_new_f32(ctx0, 1e-5); + + c.srand(0); + + const l_data_pointer: [*]f32 = @ptrCast(@alignCast(l.*.data)); + const f_data_pointer: [*]f32 = @ptrCast(@alignCast(F.*.data)); + for (0..NP) |j| { + const ll = if (j < NP/2) @as(f32, 1.0) else @as(f32, -1.0); + l_data_pointer[j] = ll; + + for (0..NF) |i| { + const c_rand: f32 = 
@floatFromInt(c.rand()); + f_data_pointer[j*NF + i] = + ((if (ll > 0 and i < NF/2) @as(f32, 1.0) else + if (ll < 0 and i >= NF/2) @as(f32, 1.0) else @as(f32, 0.0)) + + (c_rand/c.RAND_MAX - 0.5) * 0.1) / (0.5 * NF); + } + } + + { + // initial guess + const x = c.ggml_set_f32(c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NF), 0.0); + + c.ggml_set_param(ctx0, x); + + // f = sum[(fj*x - l)^2]/n + lambda*|x^2| + const f = + c.ggml_add(ctx0, + c.ggml_div(ctx0, + c.ggml_sum(ctx0, + c.ggml_sqr(ctx0, + c.ggml_sub(ctx0, + c.ggml_mul_mat(ctx0, F, x), + l) + ) + ), + c.ggml_new_f32(ctx0, @as(f32, NP)) + ), + c.ggml_mul(ctx0, + c.ggml_sum(ctx0, c.ggml_sqr(ctx0, x)), + lambda) + ); + + const res = c.ggml_opt(null, opt_params, f); + + try std.testing.expect(res == c.GGML_OPT_OK); + + const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data)); + // print results + for (0..16) |i| { + std.debug.print("x[{d:3}] = {d:.6}\n", .{i, x_data_pointer[i]}); + } + std.debug.print("...\n", .{}); + for (NF - 16..NF) |i| { + std.debug.print("x[{d:3}] = {d:.6}\n", .{i, x_data_pointer[i]}); + } + std.debug.print("\n", .{}); + + for (0..NF) |i| { + if (i < NF/2) { + try std.testing.expect(is_close(x_data_pointer[i], 1.0, 1e-2)); + } else { + try std.testing.expect(is_close(x_data_pointer[i], -1.0, 1e-2)); + } + } + } + + _ = try std.io.getStdIn().reader().readByte(); +} From b22fb03254acbe7b490b21e0d50ae4d96da90e22 Mon Sep 17 00:00:00 2001 From: Hirochika Matsumoto Date: Mon, 3 Jul 2023 01:47:47 +0900 Subject: [PATCH 15/15] examples : use GGML_FILE_MAGIC where possible (#323) --- examples/dolly-v2/main.cpp | 2 +- examples/dolly-v2/quantize.cpp | 2 +- examples/gpt-2/main.cpp | 2 +- examples/gpt-2/quantize.cpp | 2 +- examples/gpt-j/main.cpp | 2 +- examples/gpt-j/quantize.cpp | 2 +- examples/gpt-neox/main.cpp | 2 +- examples/gpt-neox/quantize.cpp | 2 +- examples/mnist/main.cpp | 2 +- examples/mpt/main.cpp | 2 +- examples/mpt/quantize.cpp | 2 +- examples/replit/main.cpp | 2 +- examples/replit/quantize.cpp | 2 +- examples/starcoder/main.cpp | 2 +- examples/starcoder/quantize.cpp | 2 +- examples/whisper/quantize.cpp | 2 +- examples/whisper/whisper.cpp | 2 +- 17 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/dolly-v2/main.cpp b/examples/dolly-v2/main.cpp index ce7d3d4fa..0d511b4e6 100644 --- a/examples/dolly-v2/main.cpp +++ b/examples/dolly-v2/main.cpp @@ -100,7 +100,7 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/dolly-v2/quantize.cpp b/examples/dolly-v2/quantize.cpp index 83f757274..0c0d24ccf 100644 --- a/examples/dolly-v2/quantize.cpp +++ b/examples/dolly-v2/quantize.cpp @@ -47,7 +47,7 @@ bool dollyv2_model_quantize(const std::string & fname_inp, const std::string & f { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp index 103bd388a..c6f2c7b12 100644 --- a/examples/gpt-2/main.cpp +++ b/examples/gpt-2/main.cpp @@ -85,7 +85,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 
0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/gpt-2/quantize.cpp b/examples/gpt-2/quantize.cpp index d80218952..9d8d53a67 100644 --- a/examples/gpt-2/quantize.cpp +++ b/examples/gpt-2/quantize.cpp @@ -45,7 +45,7 @@ bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fnam { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp index 516e0998a..ad99c80bc 100644 --- a/examples/gpt-j/main.cpp +++ b/examples/gpt-j/main.cpp @@ -85,7 +85,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/gpt-j/quantize.cpp b/examples/gpt-j/quantize.cpp index 0c1f795f4..437053b7d 100644 --- a/examples/gpt-j/quantize.cpp +++ b/examples/gpt-j/quantize.cpp @@ -46,7 +46,7 @@ bool gptj_model_quantize(const std::string & fname_inp, const std::string & fnam { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } diff --git a/examples/gpt-neox/main.cpp b/examples/gpt-neox/main.cpp index 649b51c0e..065af4dae 100644 --- a/examples/gpt-neox/main.cpp +++ b/examples/gpt-neox/main.cpp @@ -90,7 +90,7 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_ { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/gpt-neox/quantize.cpp b/examples/gpt-neox/quantize.cpp index ac7d681cb..96208c1e8 100644 --- a/examples/gpt-neox/quantize.cpp +++ b/examples/gpt-neox/quantize.cpp @@ -47,7 +47,7 @@ bool gpt_neox_model_quantize(const std::string & fname_inp, const std::string & { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } diff --git a/examples/mnist/main.cpp b/examples/mnist/main.cpp index 823616fc8..fac641e0e 100644 --- a/examples/mnist/main.cpp +++ b/examples/mnist/main.cpp @@ -48,7 +48,7 @@ bool mnist_model_load(const std::string & fname, mnist_model & model) { { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/mpt/main.cpp b/examples/mpt/main.cpp index e5903c3c3..063d85669 100644 --- a/examples/mpt/main.cpp +++ b/examples/mpt/main.cpp @@ -188,7 +188,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo { uint32_t magic; fin.read((char *)&magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", 
__func__, fname.c_str()); return false; } diff --git a/examples/mpt/quantize.cpp b/examples/mpt/quantize.cpp index 95b83c36f..d0c9dda82 100644 --- a/examples/mpt/quantize.cpp +++ b/examples/mpt/quantize.cpp @@ -48,7 +48,7 @@ bool mpt_model_quantize(const std::string & fname_inp, { uint32_t magic; finp.read((char *)&magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; diff --git a/examples/replit/main.cpp b/examples/replit/main.cpp index 77a38be0c..4b96ddfb7 100644 --- a/examples/replit/main.cpp +++ b/examples/replit/main.cpp @@ -200,7 +200,7 @@ bool replit_model_load(const std::string & fname, replit_model & model, replit_t { uint32_t magic; fin.read((char *)&magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/replit/quantize.cpp b/examples/replit/quantize.cpp index 9a4ec4339..f274074bb 100644 --- a/examples/replit/quantize.cpp +++ b/examples/replit/quantize.cpp @@ -46,7 +46,7 @@ bool mpt_model_quantize(const std::string & fname_inp, { uint32_t magic; finp.read((char *)&magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp index 5c6065980..8a694aa85 100644 --- a/examples/starcoder/main.cpp +++ b/examples/starcoder/main.cpp @@ -86,7 +86,7 @@ bool starcoder_model_load(const std::string & fname, starcoder_model & model, gp { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } diff --git a/examples/starcoder/quantize.cpp b/examples/starcoder/quantize.cpp index 101af5091..d3aee3f26 100644 --- a/examples/starcoder/quantize.cpp +++ b/examples/starcoder/quantize.cpp @@ -45,7 +45,7 @@ bool starcoder_model_quantize(const std::string & fname_inp, const std::string & { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } diff --git a/examples/whisper/quantize.cpp b/examples/whisper/quantize.cpp index 3df7b1c71..64e8f35c3 100644 --- a/examples/whisper/quantize.cpp +++ b/examples/whisper/quantize.cpp @@ -57,7 +57,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 813a0272d..610180ab7 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -812,7 +812,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con { uint32_t magic; read_safe(loader, magic); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__); return false; }
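
For reference, GGML_FILE_MAGIC is expected to expand to the same 0x67676d6c value ("ggml" in little-endian byte order), so this last patch is a readability refactor rather than a file-format change. Below is a minimal sketch of the shared check, assuming the constant is provided by ggml/ggml.h; check_ggml_magic is a hypothetical helper written for illustration, not a function introduced by the patch:

    // Sketch only: validates the leading 4-byte magic of a ggml model file,
    // assuming GGML_FILE_MAGIC (0x67676d6c) is defined in ggml/ggml.h.
    #include <cstdint>
    #include <cstdio>
    #include <fstream>
    #include <string>

    #include "ggml/ggml.h"

    static bool check_ggml_magic(const std::string & fname) {
        std::ifstream fin(fname, std::ios::binary);
        if (!fin) {
            fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
            return false;
        }

        // read the first 4 bytes and compare against the shared constant
        uint32_t magic = 0;
        fin.read((char *) &magic, sizeof(magic));
        if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
            return false;
        }
        return true;
    }

Using the named constant keeps every example loader and quantizer in agreement with the magic written by the ggml file writers, instead of repeating the raw hexadecimal literal in each file.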