diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index acbee880..8452a0b6 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -710,11 +710,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
-    if (flash_attn) {
-        // TODO: remove before merge
-        LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    }
-    // is there anything oddly shaped??
+    //if (flash_attn) {
+    //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    //}
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
@@ -725,17 +724,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check
 
     if (mask != nullptr) {
-        // TODO: figure out if we can bend t5 to work too
+        // TODO(Green-Sky): figure out if we can bend t5 to work too
         can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
    }
 
-    // TODO: more pad or disable for funny tensor shapes
+    // TODO(Green-Sky): more pad or disable for funny tensor shapes
 
     ggml_tensor* kqv = nullptr;
     //GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        LOG_DEBUG("using flash attention");
+        //LOG_DEBUG("using flash attention");
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
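
Not part of the diff, just for reviewers: a minimal sketch of how the `can_use_flash_attn && flash_attn` gate typically feeds into ggml's fused attention op. It assumes q/k/v are already reshaped to `[N * n_head, L, d_head]` the way `ggml_nn_attention_ext` does, that no mask is involved (the `mask->ne[2]`/`ne[3]` checks and the F16/padding a real mask would need are left out), that an `L_k % 256 == 0` gate inferred from the assert above sits between the two hunks, and that the `ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap)` signature is available. The helper name `attention_sketch` is made up.

```cpp
// Sketch only (hypothetical helper), not the ggml_extend.hpp implementation.
#include "ggml.h"

#include <cmath>

static struct ggml_tensor* attention_sketch(struct ggml_context* ctx,
                                            struct ggml_tensor* q,  // [N * n_head, L_q, d_head]
                                            struct ggml_tensor* k,  // [N * n_head, L_k, d_head]
                                            struct ggml_tensor* v,  // [N * n_head, L_k, d_head]
                                            bool flash_attn) {
    const int64_t d_head = q->ne[0];
    const int64_t L_k    = k->ne[1];

    const float scale = 1.0f / std::sqrt((float)d_head);

    // same gating idea as the diff: the fused kernel path wants L_k padded to
    // a multiple of 256 (see the assert) and d_head <= 256
    bool can_use_flash_attn = (L_k % 256 == 0) && (d_head <= 256);

    if (can_use_flash_attn && flash_attn) {
        // flash path: K and V go in as F16, Q stays F32
        k = ggml_cast(ctx, k, GGML_TYPE_F16);
        v = ggml_cast(ctx, v, GGML_TYPE_F16);

        // assumed signature: (ctx, q, k, v, mask, scale, max_bias, logit_softcap)
        struct ggml_tensor* kqv = ggml_flash_attn_ext(ctx, q, k, v, nullptr, scale, 0.0f, 0.0f);

        // the op returns [L_q, N * n_head, d_head]; permute back to the
        // same layout the fallback below produces
        return ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));
    }

    // fallback: plain softmax(Q K^T * scale) V
    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);                 // [N * n_head, L_q, L_k]
    kq                     = ggml_soft_max_ext(ctx, kq, nullptr, scale, 0.0f);

    struct ggml_tensor* vt = ggml_cont(ctx, ggml_transpose(ctx, v));  // [N * n_head, d_head, L_k]
    return ggml_mul_mat(ctx, vt, kq);                                 // [N * n_head, L_q, d_head]
}
```

Casting K/V to F16 matches what the fused kernels expect while Q stays F32; the fallback keeps the plain softmax(Q K^T * scale) V graph so the "funny tensor shapes" from the TODO still have a working path.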