Fixing flash rw.

This commit is contained in:
Ubuntu 2023-06-06 10:45:59 +00:00
parent 2a1ecf3863
commit d083d57d0d
5 changed files with 205 additions and 354 deletions

View File

@ -0,0 +1 @@
{"inputs":"Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n<|prompter|>Why is butter a great building material for skyscrapers? Think step by step.</s><|assistant|>","parameters":{"temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1.2, "top_k": 50, "truncate": 1000, "max_new_tokens": 1024}}

View File

@ -292,20 +292,12 @@ class FlashLlamaModel(torch.nn.Module):
super().__init__()
self.config = config
self.tp_embeddings = False
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
else:
self.embed_tokens = Embedding(prefix="model.embed_tokens", weights=weights)
self.layers = nn.ModuleList(
[
FlashLlamaLayer(

View File

@ -12,14 +12,29 @@ from typing import Optional
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm,
PositionRotaryEmbedding,
get_linear
)
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.parallel_attn:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
class RWConfig(PretrainedConfig):
attribute_map = {
@ -85,44 +100,26 @@ class RWConfig(PretrainedConfig):
class FlashRWAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
config, prefix, weights,
# num_heads,
# num_heads_kv,
# hidden_size,
# bias,
# process_group=None,
reduce=True,
):
super().__init__()
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = config.n_head
self.num_heads_kv = config.n_head_kv
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device)
self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads //weights.process_group.size()
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
def forward(
self,
@ -224,7 +221,8 @@ class FlashRWLargeAttention(torch.nn.Module):
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
# self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2)
@ -359,28 +357,12 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
def __init__(self, config, prefix, weights, reduce=True):
super().__init__()
self.act = torch.nn.functional.gelu
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size,
bias=bias,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
4 * hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
@ -392,38 +374,62 @@ class FlashMLP(nn.Module):
class FlashRWLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
layer_id,
config,
weights,
# num_heads,
# num_heads_kv,
# hidden_size,
# bias,
# layer_norm_eps,
# parallel_attn,
# process_group=None,
):
super().__init__()
n_head = config.n_head
n_head_kv = config.n_head_kv
hidden_size = config.hidden_size
bias = config.bias
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
# num_heads,
# num_heads_kv,
# hidden_size,
# bias,
# process_group=process_group,
config,
prefix=f"{prefix}.self_attention",
weights=weights,
reduce=False,
)
self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps)
if not parallel_attn
FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
) if not parallel_attn
else None
)
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
# hidden_size, bias, process_group=process_group, reduce=False
config,
prefix=f"{prefix}.mlp",
weights=weights,
reduce=False
)
self.process_group = process_group
self.process_group = weights.process_group
def forward(
self,
@ -485,31 +491,30 @@ class FlashRWLayer(nn.Module):
class FlashRWLargeLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
config, prefix, weights
):
super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWLargeAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
config, prefix=f"{prefix}.self_attention", weights=weights,
reduce=False,
)
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
)
self.process_group = process_group
self.process_group = weights.process_group
def forward(
self,
@ -555,37 +560,27 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
prefix="transformer.word_embeddings", weights=weights
)
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList(
[
FlashRWLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
layer_id, config, weights
# config.n_head,
# config.n_head_kv,
# config.hidden_size,
# config.bias,
# config.layer_norm_epsilon,
# config.parallel_attn,
# process_group,
)
for _ in range(config.num_hidden_layers)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
@ -597,14 +592,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
self.h = nn.ModuleList(
[
FlashRWLargeLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
layer_id, config, weights
# config.n_head,
# config.n_head_kv,
# config.hidden_size,
# config.bias,
# config.layer_norm_epsilon,
# process_group,
)
for _ in range(config.num_hidden_layers)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
@ -617,31 +613,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
f"model_type {config.model_type} is not supported."
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.head_size = self.h[0].self_attention.head_size
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
self.head_size = self.h[0].self_attention.head_size
def forward(
self,
@ -708,40 +686,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, weights)
self.transformer = FlashRWModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
@ -766,12 +718,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present

View File

@ -21,99 +21,14 @@ from text_generation_server.utils import (
weight_files,
download_weights,
weight_hub_files,
Weights,
LocalEntryNotFoundError,
)
tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as it is too slow
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
@ -142,20 +57,12 @@ class FlashRWSharded(FlashRW):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights():
model = FlashRWForCausalLM(config, self.process_group)
config.quantize = quantize
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
@ -167,78 +74,78 @@ class FlashRWSharded(FlashRW):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
# @staticmethod
# def load_weights(
# model,
# filenames: List[str],
# quantize: Optional[str],
# device: torch.device,
# dtype: torch.dtype,
# rank: int,
# world_size: int,
# ):
# parameters = dict(model.named_parameters())
# for file in filenames:
# with safe_open(
# file, framework="pt", device=str(device) if quantize is None else "cpu"
# ) as f:
# for name in f.keys():
# module_name, param_name = name.rsplit(".", 1)
# module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
# current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
# slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
# if isinstance(module, TensorParallelColumnLinear):
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# elif isinstance(module, TensorParallelRowLinear):
# if param_name == "weight":
# size = slice_.get_shape()[1]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[:, start:stop]
# else:
# tensor = slice_[:]
# # XXX: Hack for Rowlinear to add the bias only once.
# if rank != 0:
# tensor = torch.zeros_like(tensor)
# elif isinstance(module, TensorParallelEmbedding):
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# elif name == "lm_head.weight" and model.transformer.tp_embeddings:
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# else:
# try:
# tensor = slice_[:]
# except:
# tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
# if (
# current_parameter_tensor is not None
# and current_parameter_tensor.shape != tensor.shape
# ):
# raise ValueError(
# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
# )
tensor = tensor.contiguous().to(dtype)
# tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
# if current_parameter_tensor is not None:
# module._parameters[param_name] = tensor
# else:
# module._buffers[param_name] = tensor
model.post_load_weights(quantize)
# model.post_load_weights(quantize)

View File

@ -308,6 +308,13 @@ try:
self._cos_k_cached = None
self._sin_k_cached = None
@staticmethod
def static(dim, base, device):
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
dtype=torch.float32) / dim))
return PositionRotaryEmbedding(inv_freq)
@staticmethod
def load(prefix, weights):
# XXX: Always load this in float32 !