mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
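Fix cohere.
Import `reshape_and_cache` directly from `text_generation_server.layers.attention` and call it in place of `paged_attention.reshape_and_cache`.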
This commit is contained in:
parent
daddd2e90b
commit
a76e650283
@@ -25,7 +25,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-from text_generation_server.layers.attention import paged_attention, attention
+from text_generation_server.layers.attention import (
+    paged_attention,
+    attention,
+    reshape_and_cache,
+)
 from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
@@ -283,7 +287,7 @@ class FlashCohereAttention(torch.nn.Module):
 
         self.rotary_emb(query, key, cos, sin)
 
-        paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
         # output tensor
         attn_output = torch.empty_like(query)
|
Loading…
Reference in New Issue
Block a user