fix mistral

OlivierDehaene 2024-04-09 19:31:16 +02:00
parent d4da0d4d97
commit 26da6bfb2d
2 changed files with 8 additions and 4 deletions

View File

@@ -173,6 +173,10 @@ class MistralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -232,7 +236,7 @@ class MistralAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
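Note on the change above: instead of handing the paged-attention kernel the raw KV-head count, the layer now precomputes an explicit query-head to KV-head index map for grouped-query attention. A minimal sketch of what that mapping contains, using hypothetical head counts rather than values from any real Mistral config:

import torch

# Hypothetical GQA layout: 32 query heads share 8 KV heads, i.e. groups of 4.
num_heads = 32
num_key_value_heads = 8
num_groups = num_heads // num_key_value_heads  # 4

# One KV-head index per query head: [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

assert kv_head_mapping.shape == (num_heads,)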

View File

@@ -3,12 +3,12 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
-from transformers.models.cohere import AutoTokenizer, CohereConfig
+from transformers import AutoTokenizer, AutoConfig
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
     FlashCohereForCausalLM,
     CohereConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -45,7 +45,7 @@ class FlashCohere(FlashCausalLM):
             from_slow=False,
         )
 
-        config = CohereConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
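The hunk above swaps CohereConfig.from_pretrained for transformers' generic AutoConfig.from_pretrained, which resolves the concrete config class from the model_type field in the checkpoint's config.json. A rough usage sketch, with a placeholder model id rather than anything taken from this commit:

from transformers import AutoConfig, AutoTokenizer

model_id = "CohereForAI/c4ai-command-r-v01"  # placeholder Cohere checkpoint

# AutoConfig inspects config.json and returns the matching config class,
# so the loader no longer needs to name CohereConfig directly.
config = AutoConfig.from_pretrained(model_id, revision=None, trust_remote_code=False)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)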