From a76e650283828d4aa5fd2144d046f878b142b405 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 29 May 2024 17:41:15 +0000 Subject: [PATCH] Fix cohere. --- .../models/custom_modeling/flash_cohere_modeling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index ec73031c..088e3062 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -25,7 +25,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.layers.attention import paged_attention, attention +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( @@ -283,7 +287,7 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query)