feat(server): flash santacoder

This commit is contained in:
OlivierDehaene 2023-04-03 15:25:49 +02:00
parent 5dfc9c7613
commit 05aee8b503
4 changed files with 294 additions and 381 deletions

View File

@ -18,6 +18,7 @@ from text_generation_server.models.t5 import T5Sharded
try:
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_santacoder import FlashSantacoder
FLASH_ATTENTION = (
torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
@ -67,7 +68,11 @@ def get_model(
return Galactica(model_id, revision, quantize=quantize)
if "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
if sharded:
raise NotImplementedError("sharded is not supported for Santacoder")
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type

View File

@ -1,12 +1,8 @@
import torch
import torch.distributed
from torch.nn import functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
# Flash attention imports
import flash_attn_cuda
@ -51,12 +47,12 @@ class FastLayerNorm(nn.LayerNorm):
class FastLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
@ -69,132 +65,12 @@ class FastLinear(nn.Linear):
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear):
class FlashMQAttention(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
class TensorParallelRowLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def add_null_idx(self):
"""Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
# translate for [0, self.max_id - self.min_id[
input = torch.where(
(self.min_id > input) | (input >= self.max_id),
self.null_idx,
input - self.min_id,
)
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class FlashNeoxAttention(torch.nn.Module):
def __init__(
self,
num_heads,
hidden_size,
process_group=None,
reduce=True,
self,
num_heads,
hidden_size,
process_group=None,
):
super().__init__()
self.num_heads = num_heads
@ -204,61 +80,43 @@ class FlashNeoxAttention(torch.nn.Module):
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
self.c_proj = FastLinear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
process_group=process_group,
)
self.c_proj = TensorParallelRowLinear(
hidden_size, hidden_size, process_group=process_group, reduce=reduce
)
def shuffle_qkv_dims(self):
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
)
self.query_key_value.bias = torch.nn.Parameter(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
.permute(1, 0, 2)
.reshape(-1)
)
raise NotImplementedError
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
self,
hidden_states,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
qkv = self.attn(hidden_states)
# Split query from key_value
query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)
# Prepare query and key_value for indexing
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv_rot[:, 1:]
layer_past[...] = key_value
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(qkv[:, 0])
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
query,
key_value[:, 0],
key_value[:, 1],
attn_output,
cu_seqlens,
cu_seqlens,
@ -274,17 +132,18 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv_rot[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
layer_past[layer_past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
key_value[:, 0],
key_value[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
@ -299,226 +158,147 @@ class FlashNeoxAttention(torch.nn.Module):
None,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
class FlashMLP(nn.Module):
class MLP(nn.Module):
def __init__(
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
self, act, hidden_size, intermediate_size, process_group=None
):
super().__init__()
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
else lambda x: torch.nn.functional.gelu(x, approximate="tanh" if act in ["gelu_fast",
"gelu_pytorch_tanh"] else None)
)
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
self.c_fc = FastLinear(hidden_size, intermediate_size)
self.c_proj = FastLinear(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
raise NotImplementedError
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
hidden_states = self.c_proj(hidden_states)
return hidden_states
class FlashNeoXLayer(nn.Module):
class Block(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
self,
num_heads,
act,
hidden_size,
intermediate_size,
layer_norm_eps,
process_group=None,
):
super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.attention = FlashNeoxAttention(
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.attn = FlashMQAttention(
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group,
reduce=not use_parallel_residual,
)
self.mlp = FlashMLP(
self.mlp = MLP(
act,
hidden_size,
intermediate_size,
process_group,
reduce=not use_parallel_residual,
)
self.process_group = process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
self,
hidden_states,
residual,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
hidden_states, residual = self.ln_1(hidden_states, residual)
attn_output = self.attention(
ln1_hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
hidden_states = self.attn(
hidden_states,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.ln_2(
hidden_states, residual
)
mlp_output = self.mlp(ln2_hidden_states)
intermediate = mlp_output + attn_output
mlp_output = self.mlp(hidden_states)
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate + hidden_states, None
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.attention(
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
mlp_output = self.mlp(hidden_states)
return mlp_output, residual
return mlp_output, residual
class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
config_class = GPTNeoXConfig
base_model_prefix = "gpt_neox"
supports_gradient_checkpointing = False
_no_split_modules = None
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashSantacoderModel(nn.Module):
def __init__(self, config, process_group=None):
super().__init__(config)
super().__init__()
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
raise NotImplementedError
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.layers = nn.ModuleList(
self.h = nn.ModuleList(
[
FlashNeoXLayer(
Block(
config.num_attention_heads,
config.hidden_act,
config.activation_function,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
config.n_inner if config.n_inner is not None else 4 * config.hidden_size,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
]
)
self.final_layer_norm = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_eps
self.ln_f = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_epsilon
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
def post_load_weights(self):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.transpose_weight()
layer.attention.dense.transpose_weight()
layer.mlp.dense_h_to_4h.transpose_weight()
layer.mlp.dense_4h_to_h.transpose_weight()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
model.post_load_weights()
return model
for layer in self.h:
layer: Block
layer.attn.attn.transpose_weight()
layer.attn.c_proj.transpose_weight()
layer.mlp.c_fc.transpose_weight()
layer.mlp.c_proj.transpose_weight()
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
):
hidden_states = self.embed_in(input_ids)
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
# Prefill
if past_key_values is None:
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.layers),
len(self.h),
len(hidden_states),
2,
self.num_heads,
1,
self.head_size,
)
)
@ -532,19 +312,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
for i, layer in enumerate(self.h):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
past_key_values[i],
@ -552,54 +324,34 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens_q,
)
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, process_group=None):
super().__init__()
if config.tp_parallel:
process_group = torch.distributed.distributed_c10d._get_default_group()
else:
process_group = None
self.transformer = FlashSantacoderModel(config, process_group)
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
self.lm_head = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
def post_load_weights(self):
self.gpt_neox.post_load_weights()
self.embed_out.transpose_weight()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
model.post_load_weights()
return model
self.transformer.post_load_weights()
self.lm_head.transpose_weight()
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
):
hidden_states, present = self.gpt_neox(
hidden_states, present = self.transformer(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
)
return self.embed_out(hidden_states), present
return self.lm_head(hidden_states), present

View File

@ -0,0 +1,138 @@
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM
)
from text_generation_server.utils import (
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
)
tracer = trace.get_tracer(__name__)
class FlashSantacoder(FlashCausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
if quantize:
raise NotImplementedError("FlashSantacoder does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
)
config = AutoConfig.from_pretrained(
model_id, revision=revision,
trust_remote_code=True # Needed as the config is not part of Transformers
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashSantacoderForCausalLM(config)
self.load_weights(
model,
filenames,
)
self.model = model.eval().to(device).to(dtype)
super(FlashCausalLM, self).__init__(
tokenizer=tokenizer,
device=device,
)
@staticmethod
def load_weights(
model: FlashSantacoderForCausalLM,
filenames: List[Path],
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if "c_fc.weight" in key or "c_proj.weight" in key or "q_attn.weight" in key or "kv_attn.weight" in key:
# Tranpose as we use nn.Linear instead of Conv1D
value = value.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(model.transformer.head_size * (model.transformer.num_heads + 2), value.shape[1])
)
elif "attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(model.transformer.head_size * (model.transformer.num_heads + 2))
)
# Copy to correct slice
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads:
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads:
] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
torch.cuda.empty_cache()
model.post_load_weights()
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
)

View File

@ -6,6 +6,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
@ -22,6 +28,18 @@ class SantaCoder(CausalLM):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
)
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
EOD,
FIM_PREFIX,
FIM_MIDDLE,
FIM_SUFFIX,
FIM_PAD,
],
"pad_token": EOD,
}
)
self.model = (
AutoModelForCausalLM.from_pretrained(