OlivierDehaene 2023-05-30 10:41:10 +02:00
parent 12ab24ae64
commit cbffddcc06
7 changed files with 509 additions and 255 deletions

View File

@@ -286,7 +286,9 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens

View File

@@ -285,7 +285,9 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens

View File

@@ -323,7 +323,9 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None

View File

@@ -31,7 +31,7 @@ try:
)
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_rw import FlashRW
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
FlashLlamaSharded,
@@ -71,6 +71,7 @@ if FLASH_ATTENTION:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
@@ -202,13 +203,15 @@ def get_model(
if FLASH_ATTENTION:
if config.alibi:
raise NotImplementedError("sharded is not supported for this model")
# return FlashRWSharded(
# model_id,
# revision,
# quantize=quantize,
# trust_remote_code=trust_remote_code,
# )
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb"))
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
)
else:
if FLASH_ATTENTION and not config.alibi:
return FlashRW(

View File

@@ -48,7 +48,9 @@ class RWConfig(PretrainedConfig):
**kwargs,
):
if alibi:
raise NotImplementedError("alibi is not supported by this version of the model")
raise NotImplementedError(
"alibi is not supported by this version of the model"
)
self.model_type = model_type
self.alibi = False
@@ -99,15 +101,25 @@ class FlashRWAttention(torch.nn.Module):
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias)
self.query_key_value = FastLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias)
self.query_key_value = FastLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = TensorParallelRowLinear(
hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
def forward(
@@ -125,7 +137,8 @@ class FlashRWAttention(torch.nn.Module):
# Split query from key_value
query, kv = qkv.split(
[self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
[self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
dim=1,
)
# Prepare query and key_value for indexing
@@ -194,11 +207,149 @@ class FlashRWAttention(torch.nn.Module):
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
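FlashRWAttention packs queries, keys and values into a single FastLinear projection of width head_size * (num_heads + 2 * num_heads_kv) and splits the result afterwards, which is what gives the multi-query layout its shared key/value heads. A quick shape check of that split with toy sizes (the numbers are invented, not taken from any real checkpoint):

import torch

# Toy sizes only; real RefinedWeb checkpoints are much larger.
num_heads, num_heads_kv, head_size, tokens = 8, 1, 64, 4
qkv = torch.randn(tokens, head_size * (num_heads + 2 * num_heads_kv))

# Same split widths as in FlashRWAttention.forward: all query heads first,
# then the shared key and value heads.
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_heads_kv], dim=1
)
assert query.shape == (tokens, head_size * num_heads)
assert kv.shape == (tokens, 2 * head_size * num_heads_kv)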
class FlashMLP(nn.Module):
class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self, hidden_size, bias, process_group=None, reduce=True
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
):
super().__init__()
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2)
self.num_heads = num_heads // self.num_groups
self.num_heads_kv = num_heads_kv // self.num_groups
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_groups = self.num_groups // process_group.size()
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
# Split query from key_value
query, kv = qkv.split(
[self.num_heads, 2],
dim=2,
)
# Prepare query and key_value for indexing
query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
kv = kv.transpose(1, 2)
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = kv
k, v = kv.split(1, dim=1)
# Expand to query shape
k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(
-1, self.num_groups * self.num_heads, self.head_size
)
v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(
-1, self.num_groups * self.num_heads, self.head_size
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
k,
v,
attn_output,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
k, v = layer_past.split(1, dim=1)
# Expand to query shape
k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(
-1, self.num_groups * self.num_heads, self.head_size
)
v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(
-1, self.num_groups * self.num_heads, self.head_size
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
k,
v,
attn_output,
cu_seqlens_q,
cu_seqlens,
1,
max_s,
0.0,
self.softmax_scale,
False,
False,
False,
0,
None,
)
return self.dense(
attn_output.view(-1, self.num_heads * self.num_groups * self.head_size)
)
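FlashRWLargeAttention handles the grouped multi-query variant: per the view and split above, each group owns num_heads query heads plus a single key head and a single value head, and the shared K/V are broadcast back to the full query layout before flash_attn_cuda.fwd runs. A shape-only illustration of that broadcast with toy sizes (it mirrors the transpose/split/expand/reshape above, but is not the model code):

import torch

# Toy sizes: 4 groups, 2 query heads per group, head size 8, 3 tokens.
num_groups, num_heads, head_size, tokens = 4, 2, 8, 3

# kv as produced by the fused projection: (tokens, groups, 2, head_size),
# then moved to (tokens, 2, groups, head_size) as in the forward pass above.
kv = torch.randn(tokens, num_groups, 2, head_size).transpose(1, 2)
k, v = kv.split(1, dim=1)  # each (tokens, 1, groups, head_size)

# Broadcast the single key head of each group across that group's query heads.
k = k.expand(-1, num_heads, num_groups, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
assert k.shape == (tokens, num_groups * num_heads, head_size)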
class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
super().__init__()
self.act = torch.nn.functional.gelu
if process_group is None:
@@ -207,12 +358,14 @@ class FlashMLP(nn.Module):
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size, bias=bias,
4 * hidden_size,
bias=bias,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
4 * hidden_size,
hidden_size, bias=bias,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
@@ -231,6 +384,7 @@ class FlashRWLayer(nn.Module):
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
@@ -240,10 +394,16 @@ class FlashRWLayer(nn.Module):
self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.self_attention = FlashRWAttention(num_heads, num_heads_kv, hidden_size, process_group, reduce=False)
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) if not parallel_attn else None
self.self_attention = FlashRWAttention(
num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
)
self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps)
if not parallel_attn
else None
)
self.mlp = FlashMLP(hidden_size, process_group, reduce=False)
self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
self.process_group = process_group
@@ -303,6 +463,68 @@ class FlashRWLayer(nn.Module):
return mlp_output, residual
class FlashRWLargeLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
):
super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.self_attention = FlashRWLargeAttention(
num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
)
self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
self.process_group = process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(hidden_states, residual)
# Self attention.
attn_output = self.self_attention(
ln_attn,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
# MLP.
mlp_output = self.mlp(ln_mlp)
intermediate = attn_output + mlp_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
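Both sub-layers of FlashRWLargeLayer are built with reduce=False, so under tensor parallelism their row-parallel outputs stay partial per rank; adding the attention and MLP outputs first and all-reducing once gives the same result as reducing each separately, because the all-reduce is itself a sum. A minimal sanity check of that arithmetic, using plain tensors to stand in for two ranks (no real process group involved):

import torch

torch.manual_seed(0)
# Pretend partial outputs from two tensor-parallel ranks.
attn_partials = [torch.randn(3, 4) for _ in range(2)]
mlp_partials = [torch.randn(3, 4) for _ in range(2)]

two_reduces = sum(attn_partials) + sum(mlp_partials)                   # reduce each sub-layer
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))   # reduce the sum once
assert torch.allclose(two_reduces, one_reduce)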
class FlashRWPreTrainedModel(PreTrainedModel):
config_class = RWConfig
supports_gradient_checkpointing = False
@@ -328,12 +550,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList(
[
FlashRWLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
@@ -341,14 +565,32 @@ class FlashRWModel(FlashRWPreTrainedModel):
for _ in range(config.num_hidden_layers)
]
)
self.ln_f = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_epsilon
self.kv_size = self.h[0].self_attention.num_heads_kv
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
FlashRWLargeLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
]
)
self.kv_size = self.h[0].self_attention.num_groups
else:
raise NotImplementedError(
f"model_type {config.model_type} is not supported."
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
self.head_size = self.h[0].self_attention.head_size
self.num_heads_kv = self.h[0].self_attention.num_heads_kv
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
@@ -394,7 +636,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads_kv,
self.kv_size,
self.head_size,
)
)
@@ -457,9 +699,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
bias=False,
)
else:
self.lm_head = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)

View File

@@ -53,8 +53,6 @@ class FlashRW(FlashCausalLM):
model_id,
revision=revision,
)
from loguru import logger
logger.error(config.model_type)
# We do not use from_pretrained as we modified the model internal module layout
try:
@@ -114,133 +112,134 @@ class FlashRW(FlashCausalLM):
torch.cuda.empty_cache()
model.post_load_weights(quantize)
#
# class FlashNeoXSharded(FlashNeoX):
# def __init__(
# self,
# model_id: str,
# revision: Optional[str] = None,
# quantize: Optional[str] = None,
# trust_remote_code: bool = False,
# ):
# self.process_group, rank, world_size = initialize_torch_distributed()
# if torch.cuda.is_available():
# device = torch.device(f"cuda:{rank}")
# dtype = torch.float16
# else:
# raise NotImplementedError("FlashNeoX is only available on GPU")
#
# tokenizer = AutoTokenizer.from_pretrained(
# model_id,
# revision=revision,
# padding_side="left",
# truncation_side="left",
# trust_remote_code=trust_remote_code,
# )
#
# config = AutoConfig.from_pretrained(
# model_id, revision=revision, trust_remote_code=trust_remote_code
# )
#
# torch.distributed.barrier(group=self.process_group)
# filenames = weight_files(model_id, revision=revision, extension=".safetensors")
#
# with init_empty_weights():
# model = FlashGPTNeoXForCausalLM(config, self.process_group)
#
# torch.distributed.barrier(group=self.process_group)
# self.load_weights(
# model,
# filenames,
# quantize=quantize,
# device=device,
# dtype=dtype,
# rank=rank,
# world_size=world_size,
# )
# torch.distributed.barrier(group=self.process_group)
# super(FlashCausalLM, self).__init__(
# model=model.to(device),
# tokenizer=tokenizer,
# requires_padding=False,
# dtype=dtype,
# device=device,
# rank=rank,
# world_size=world_size,
# )
#
# @staticmethod
# def load_weights(
# model,
# filenames: List[str],
# quantize: Optional[str],
# device: torch.device,
# dtype: torch.dtype,
# rank: int,
# world_size: int,
# ):
# parameters = dict(model.named_parameters())
# for file in filenames:
# with safe_open(
# file, framework="pt", device=str(device) if quantize is None else "cpu"
# ) as f:
# for name in f.keys():
# module_name, param_name = name.rsplit(".", 1)
# module = model.get_submodule(module_name)
#
# current_parameter_tensor = parameters.get(name, None)
#
# slice_ = f.get_slice(name)
#
# if isinstance(module, TensorParallelColumnLinear):
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# elif isinstance(module, TensorParallelRowLinear):
# if param_name == "weight":
# size = slice_.get_shape()[1]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[:, start:stop]
# else:
# tensor = slice_[:]
# # XXX: Hack for Rowlinear to add the bias only once.
# if rank != 0:
# tensor = torch.zeros_like(tensor)
# elif isinstance(module, TensorParallelEmbedding):
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# else:
# try:
# tensor = slice_[:]
# except:
# tensor = f.get_tensor(name)
#
# if (
# current_parameter_tensor is not None
# and current_parameter_tensor.shape != tensor.shape
# ):
# raise ValueError(
# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
# )
#
# tensor = tensor.contiguous().to(dtype)
#
# if current_parameter_tensor is not None:
# module._parameters[param_name] = tensor
# else:
# module._buffers[param_name] = tensor
#
# model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashRWForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
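load_weights shards column-parallel linears, the embedding and the lm_head by rows (contiguous blocks of size // world_size per rank) and row-parallel weights by columns, then loads each slice into the meta-initialized module. A toy version of the block-slicing arithmetic, with a plain tensor standing in for a safetensors slice (sizes invented):

import torch

world_size = 4
full_weight = torch.randn(16, 8)  # 16 output rows to shard column-parallel style

block_size = full_weight.shape[0] // world_size
shards = [
    full_weight[rank * block_size : (rank + 1) * block_size]
    for rank in range(world_size)
]
# The per-rank blocks tile the original weight exactly.
assert torch.equal(torch.cat(shards, dim=0), full_weight)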

View File

@@ -71,10 +71,16 @@ class RW(CausalLM):
for layer in past_key_values:
past_keys, past_values = layer
reshaped_past_key_values.append(
(past_keys.view(-1, *past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:]))
(
past_keys.view(-1, *past_keys.shape[-2:]),
past_values.view(-1, *past_values.shape[-2:]),
)
)
past_key_values = reshaped_past_key_values
outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask,
past_key_values=past_key_values)
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
return outputs.logits, outputs.past_key_values
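RW.forward first collapses every leading dimension of each cached key/value tensor except the last two, e.g. a (batch, num_heads, seq_len, head_size) cache becomes (batch * num_heads, seq_len, head_size), before handing the cache back to the model. A shape-only illustration with invented sizes:

import torch

# Invented sizes; only the view arithmetic matters here.
batch, num_heads, seq_len, head_size = 2, 4, 5, 8
past_keys = torch.randn(batch, num_heads, seq_len, head_size)

flat = past_keys.view(-1, *past_keys.shape[-2:])
assert flat.shape == (batch * num_heads, seq_len, head_size)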