Fix cohere.

Nicolas Patry 2024-05-29 17:41:15 +00:00
parent daddd2e90b
commit a76e650283


@@ -25,7 +25,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-from text_generation_server.layers.attention import paged_attention, attention
+from text_generation_server.layers.attention import (
+    paged_attention,
+    attention,
+    reshape_and_cache,
+)
 from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
@@ -283,7 +287,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.rotary_emb(query, key, cos, sin)
 
-        paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
         # output tensor
         attn_output = torch.empty_like(query)
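
For context: paged_attention appears to be imported as a plain function alongside attention, so the old attribute access paged_attention.reshape_and_cache(...) would fail at runtime; the fix imports reshape_and_cache directly and calls it by name. Below is a minimal runnable sketch of the corrected call shape. The stub body and the tensor shapes are illustrative assumptions, not TGI's actual kernel.

    import torch


    def reshape_and_cache(key, value, key_cache, value_cache, slots):
        # Stub with the same call signature as the helper in the diff: scatter
        # each new token's key/value vectors into its assigned KV-cache slot.
        # (Assumption for illustration; the real helper dispatches to a kernel.)
        key_cache[slots] = key
        value_cache[slots] = value


    # Illustrative shapes (assumptions): 4 new tokens, 8 KV heads, head size 64,
    # and a cache with 16 slots.
    key = torch.randn(4, 8, 64)
    value = torch.randn(4, 8, 64)
    kv_cache = (torch.zeros(16, 8, 64), torch.zeros(16, 8, 64))
    slots = torch.tensor([0, 1, 2, 3])

    # Direct call, mirroring the "+" line in the second hunk.
    reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)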