diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index abe74be9..ed9306e0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -173,6 +173,10 @@ class MistralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)

     def forward(
         self,
@@ -232,7 +236,7 @@ class MistralAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index ebdf3793..f85c7722 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -3,12 +3,12 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional

-from transformers.models.cohere import AutoTokenizer, CohereConfig
+from transformers import AutoTokenizer, AutoConfig

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
     FlashCohereForCausalLM,
-    CohereConfig,
+)
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -45,7 +45,7 @@ class FlashCohere(FlashCausalLM):
             from_slow=False,
         )

-        config = CohereConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
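Note for reviewers (not part of the patch): the Mistral hunks swap the scalar `num_key_value_heads` argument for an explicit per-query-head index tensor, which is the form the paged attention kernel consumes under grouped-query attention. A minimal standalone sketch of the mapping being built, assuming Mistral-7B's layout of 32 query heads and 8 KV heads:

```python
# Standalone sketch of the kv_head_mapping construction added above.
# Head counts are example values, assuming Mistral-7B (32 query / 8 KV heads).
import torch

num_heads = 32
num_key_value_heads = 8

# With grouped-query attention, each KV head serves a contiguous
# group of num_groups query heads.
num_groups = num_heads // num_key_value_heads  # 4 query heads per KV head
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

print(kv_head_mapping)
# tensor([0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7], dtype=torch.int32)
# i.e. query head i reads from KV-cache head kv_head_mapping[i].
```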