mirror of https://github.com/huggingface/text-generation-inference.git

commit 946bf44242
parent 91d76a65f5

    fix cohere
@@ -211,7 +211,7 @@ COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

 # Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
+RUN pip install einops prometheus_client --no-cache-dir

 # Install server
 COPY proto proto
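
This hunk appears to be from the repository's Dockerfile (judging by the surrounding COPY/RUN context): prometheus_client is now installed next to einops, presumably because the Python server process imports it to export metrics. A minimal sketch of the library being pulled in (illustrative names only, not TGI's actual metrics code):

    from prometheus_client import Counter, start_http_server

    # Hypothetical counter, named only for this sketch.
    REQUESTS = Counter("sketch_requests_total", "Requests handled (illustrative)")

    start_http_server(9090)  # serve a /metrics endpoint on port 9090
    REQUESTS.inc()           # bump the counter per handled request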
@@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
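
The remaining Python hunks appear to target the flash Cohere modeling code (the imports and class names match flash_cohere_modeling.py). The new import feeds the device branching introduced below: the fused dropout_layer_norm kernel is only imported on CUDA systems, while ROCm gets a plain PyTorch fallback.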
@@ -34,9 +35,65 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+    FastLayerNorm,
 )

+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+else:
+    dropout_layer_norm = None
+
+
+class CohereLayerNorm(nn.Module):
+    def __init__(self, prefix, weights, eps):
+        super().__init__()
+        weight = weights.get_tensor(f"{prefix}.weight")
+        self.weight = nn.Parameter(weight)
+        # Fake weights
+        self.ones = weight.new_ones(weight.shape[1])
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            mean = hidden_states.mean(-1, keepdim=True)
+            hidden_states_minus_mean = hidden_states - mean
+            variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+            hidden_states = self.weight.to(torch.float32) * hidden_states
+            return hidden_states.to(input_dtype)
+
+        (
+            hidden_states,
+            *rest,
+        ) = dropout_layer_norm.dropout_add_ln_fwd(
+            hidden_states,
+            None,
+            self.ones,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0.0,
+            self.eps,
+            1.0,
+            0,
+            None,
+            False,
+            False,
+        )
+
+        # Required to apply one weight matrix per head
+        hidden_states = hidden_states.view(
+            -1, self.weight.shape[0], self.weight.shape[1]
+        )
+        hidden_states = self.weight * hidden_states
+        hidden_states = hidden_states.view(-1, self.weight.shape[1])
+
+        return hidden_states
+

 class CohereConfig(PretrainedConfig):
     def __init__(
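
This is the heart of the fix. Cohere checkpoints use a bias-free LayerNorm rather than RMSNorm, and for query/key normalization they store one weight vector per attention head, so FastRMSNorm is swapped out for FastLayerNorm plus this custom CohereLayerNorm. The ROCm branch (and hidden sizes above 8192) computes LayerNorm by hand in float32; the CUDA branch runs the fused kernel with "fake" all-ones weights, then applies the real per-head weights as a separate view/multiply/view. A pure-PyTorch sketch of that per-head trick (illustrative sizes, not the fused kernel):

    import torch

    num_heads, head_size = 8, 64                  # made-up sizes
    weight = torch.randn(num_heads, head_size)    # one gain vector per head
    x = torch.randn(32 * num_heads, head_size)    # rows = tokens * heads

    # Unit-gain LayerNorm over each row -- what the fused kernel computes
    # when handed `self.ones` instead of real weights.
    normed = torch.nn.functional.layer_norm(x, (head_size,), eps=1e-5)

    # Apply the real weights head-by-head, mirroring the diff's
    # view/multiply/view sequence.
    out = (normed.view(-1, num_heads, head_size) * weight).view(-1, head_size)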
@@ -180,7 +237,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
             rank = weights.process_group.rank()
-            self.q_norm = FastRMSNorm.load(
+            self.q_norm = CohereLayerNorm(
                 prefix=f"{prefix}.q_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
@@ -188,7 +245,7 @@ class FlashCohereAttention(torch.nn.Module):
             self.q_norm.weight.data = self.q_norm.weight[
                 self.num_heads * rank : self.num_heads * (rank + 1)
             ]
-            self.k_norm = FastRMSNorm.load(
+            self.k_norm = CohereLayerNorm(
                 prefix=f"{prefix}.k_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
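
Both q_norm and k_norm switch from FastRMSNorm.load to the new CohereLayerNorm; the per-rank slicing of weight.data (tensor-parallel sharding of the per-head weights) is unchanged. The swap matters numerically: RMSNorm skips mean subtraction, so the two operations disagree whenever activations have a nonzero mean. A quick sketch of the difference:

    import torch

    x = torch.randn(2, 64) + 3.0  # nonzero mean on purpose
    eps = 1e-5

    rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    ln = (x - x.mean(-1, keepdim=True)) * torch.rsqrt(
        x.var(-1, unbiased=False, keepdim=True) + eps
    )
    print(torch.allclose(rms, ln))  # False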
@@ -230,8 +287,10 @@ class FlashCohereAttention(torch.nn.Module):
         )

         if self.use_qk_norm:
-            query, _ = self.q_norm(query.contiguous())
-            key, _ = self.k_norm(key.contiguous())
+            query = query.reshape(-1, self.head_size)
+            key = key.reshape(-1, self.head_size)
+            query = self.q_norm(query)
+            key = self.k_norm(key)

         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
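
The call site changes accordingly: CohereLayerNorm returns a plain tensor rather than FastRMSNorm's (output, residual) tuple, so the `query, _ = ...` unpacking goes away, and query/key are first flattened to rows of head_size so each head is normalized independently. Shape bookkeeping with made-up sizes:

    import torch

    tokens, num_heads, head_size = 4, 8, 64
    query = torch.randn(tokens, num_heads * head_size)
    query = query.reshape(-1, head_size)           # (tokens * heads, head_size)
    # ... q_norm applies the per-head LayerNorm here ...
    query = query.view(-1, num_heads, head_size)   # (tokens, heads, head_size)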
@@ -324,7 +383,7 @@ class FlashCohereLayer(nn.Module):
         )
         self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
@@ -388,7 +447,7 @@ class FlashCohereModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = FastLayerNorm.load_no_bias(
             prefix="model.norm", weights=weights, eps=config.layer_norm_eps
         )

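The per-layer input_layernorm and the final model norm get the same correction: bias-free LayerNorm instead of RMSNorm, loaded through FastLayerNorm.load_no_bias. Roughly the operation being loaded (sketch; nn.LayerNorm's bias=False flag requires PyTorch >= 2.1):

    import torch

    hidden = 128  # made-up size
    norm = torch.nn.LayerNorm(hidden, eps=1e-5, bias=False)
    out = norm(torch.randn(4, hidden))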
@@ -32,7 +32,7 @@ class FlashCohere(FlashCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashCohere is only available on GPU")

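Finally, FlashCohere's default dtype falls back to float16 instead of bfloat16 when the caller passes none. A plausible reading, not stated in the commit, is that this aligns Cohere with the float16 default of the other flash models on this code path.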