mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix mistral
This commit is contained in:
parent
d4da0d4d97
commit
26da6bfb2d
@ -173,6 +173,10 @@ class MistralAttention(torch.nn.Module):
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -232,7 +236,7 @@ class MistralAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
self.num_key_value_heads,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
|
@ -3,12 +3,12 @@ import torch.distributed
|
|||||||
|
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from transformers.models.cohere import AutoTokenizer, CohereConfig
|
from transformers import AutoTokenizer, AutoConfig
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
from text_generation_server.models import FlashCausalLM
|
||||||
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||||
FlashCohereForCausalLM,
|
FlashCohereForCausalLM,
|
||||||
CohereConfig,
|
)
|
||||||
from text_generation_server.utils import (
|
from text_generation_server.utils import (
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
@ -45,7 +45,7 @@ class FlashCohere(FlashCausalLM):
|
|||||||
from_slow=False,
|
from_slow=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = CohereConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
Loading…
Reference in New Issue
Block a user