fix cohere

OlivierDehaene 2024-04-05 18:42:33 +02:00
parent 91d76a65f5
commit 946bf44242
3 changed files with 68 additions and 9 deletions

Dockerfile

@@ -211,7 +211,7 @@ COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /op
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

 # Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
+RUN pip install einops prometheus_client --no-cache-dir

 # Install server
 COPY proto proto
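
A note on the Dockerfile change: the image now installs prometheus_client alongside einops, presumably because the Python server imports it for metrics. A quick, hypothetical smoke test (the metric name is made up, not from the commit):

from prometheus_client import Counter, generate_latest

requests_total = Counter("requests_total", "Total requests served")  # example metric
requests_total.inc()
print(generate_latest().decode())  # Prometheus text exposition format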

server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py

@@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -34,9 +35,65 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+    FastLayerNorm,
 )

+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+else:
+    dropout_layer_norm = None
+
+
+class CohereLayerNorm(nn.Module):
+    def __init__(self, prefix, weights, eps):
+        super().__init__()
+        weight = weights.get_tensor(f"{prefix}.weight")
+        self.weight = nn.Parameter(weight)
+        # Fake (all-ones) weight so the fused kernel only normalizes; the
+        # real per-head weight is applied separately below.
+        self.ones = weight.new_ones(weight.shape[1])
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            # Eager fallback: group rows by head so the (num_heads, head_size)
+            # weight broadcasts correctly.
+            hidden_states = hidden_states.reshape(-1, *self.weight.shape)
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            mean = hidden_states.mean(-1, keepdim=True)
+            hidden_states_minus_mean = hidden_states - mean
+            variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+            hidden_states = self.weight.to(torch.float32) * hidden_states
+            hidden_states = hidden_states.view(-1, self.weight.shape[1])
+            return hidden_states.to(input_dtype)
+
+        (
+            hidden_states,
+            *rest,
+        ) = dropout_layer_norm.dropout_add_ln_fwd(
+            hidden_states,
+            None,
+            self.ones,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0.0,
+            self.eps,
+            1.0,
+            0,
+            None,
+            False,
+            False,
+        )
+
+        # Required to apply one weight matrix per head
+        hidden_states = hidden_states.view(
+            -1, self.weight.shape[0], self.weight.shape[1]
+        )
+        hidden_states = self.weight * hidden_states
+        hidden_states = hidden_states.view(-1, self.weight.shape[1])
+
+        return hidden_states
+
+
 class CohereConfig(PretrainedConfig):
     def __init__(
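
The class above splits Cohere's per-head LayerNorm in two: the fused CUDA kernel normalizes with the all-ones "fake" weight, and the real (num_heads, head_size) weight is applied afterwards. A minimal standalone sketch (sizes are assumptions, not from the commit) checking that this two-stage computation matches a weighted LayerNorm applied head by head:

import torch

tokens, num_heads, head_size = 4, 8, 64
eps = 1e-5
weight = torch.randn(num_heads, head_size)
x = torch.randn(tokens * num_heads, head_size)

# Stage 1: plain normalization, i.e. what the fused kernel does with self.ones.
mean = x.mean(-1, keepdim=True)
var = (x - mean).pow(2).mean(-1, keepdim=True)
normed = (x - mean) * torch.rsqrt(var + eps)

# Stage 2: one weight vector per head, as in the view/multiply/view above.
out = (weight * normed.view(-1, num_heads, head_size)).view(-1, head_size)

# Reference: a weighted LayerNorm over each head vector.
ref = torch.nn.functional.layer_norm(
    x.view(-1, num_heads, head_size), (head_size,), eps=eps
) * weight
torch.testing.assert_close(out, ref.view(-1, head_size))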
@@ -180,7 +237,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
             rank = weights.process_group.rank()
-            self.q_norm = FastRMSNorm.load(
+            self.q_norm = CohereLayerNorm(
                 prefix=f"{prefix}.q_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
@@ -188,7 +245,7 @@ class FlashCohereAttention(torch.nn.Module):
             self.q_norm.weight.data = self.q_norm.weight[
                 self.num_heads * rank : self.num_heads * (rank + 1)
             ]
-            self.k_norm = FastRMSNorm.load(
+            self.k_norm = CohereLayerNorm(
                 prefix=f"{prefix}.k_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
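
The q_norm slicing shown in this hunk loads the full checkpoint weight and then keeps only the rows for the heads owned by this tensor-parallel rank. A minimal sketch of that slicing (sizes are assumptions):

import torch

total_heads, head_size, world_size = 16, 64, 4
num_heads = total_heads // world_size  # heads per rank
full_weight = torch.randn(total_heads, head_size)

for rank in range(world_size):
    # Same arithmetic as self.num_heads * rank : self.num_heads * (rank + 1)
    shard = full_weight[num_heads * rank : num_heads * (rank + 1)]
    assert shard.shape == (num_heads, head_size)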
@@ -230,8 +287,10 @@ class FlashCohereAttention(torch.nn.Module):
         )

         if self.use_qk_norm:
-            query, _ = self.q_norm(query.contiguous())
-            key, _ = self.k_norm(key.contiguous())
+            query = query.reshape(-1, self.head_size)
+            key = key.reshape(-1, self.head_size)
+            query = self.q_norm(query)
+            key = self.k_norm(key)

         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
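
The reshape before the norm and the view after it are what let CohereLayerNorm see one head vector per row; note also that it returns a plain tensor, so the ", _" tuple-unpacking used with FastRMSNorm is gone. A shape walk with assumed sizes:

import torch

tokens, num_heads, head_size = 3, 8, 64
query = torch.randn(tokens, num_heads * head_size)

flat = query.reshape(-1, head_size)  # [tokens * num_heads, head_size]
normed = torch.nn.functional.layer_norm(flat, (head_size,))  # stand-in for q_norm
query = normed.view(-1, num_heads, head_size)  # layout the attention kernels expect
assert query.shape == (tokens, num_heads, head_size)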
@@ -324,7 +383,7 @@ class FlashCohereLayer(nn.Module):
         )
         self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
@@ -388,7 +447,7 @@ class FlashCohereModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = FastLayerNorm.load_no_bias(
             prefix="model.norm", weights=weights, eps=config.layer_norm_eps
         )
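
The swap from FastRMSNorm to FastLayerNorm.load_no_bias appears to be the heart of the fix: Cohere checkpoints use a true mean-centered LayerNorm without bias, while RMSNorm skips the centering, so the two disagree whenever the input mean is nonzero. A small illustration (values are arbitrary):

import torch

x = torch.randn(2, 8) + 3.0  # deliberately nonzero mean
eps = 1e-5

rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
ln = torch.nn.functional.layer_norm(x, (8,), eps=eps)

print((rms - ln).abs().max())  # clearly nonzero: the two norms differ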

server/text_generation_server/models/flash_cohere.py

@@ -32,7 +32,7 @@ class FlashCohere(FlashCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashCohere is only available on GPU")