Fixing flash rw.

Ubuntu committed 2023-06-06 10:45:59 +00:00
parent 2a1ecf3863
commit d083d57d0d
5 changed files with 205 additions and 354 deletions


@@ -0,0 +1 @@
+{"inputs":"Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n<|prompter|>Why is butter a great building material for skyscrapers? Think step by step.</s><|assistant|>","parameters":{"temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1.2, "top_k": 50, "truncate": 1000, "max_new_tokens": 1024}}


@@ -292,20 +292,12 @@ class FlashLlamaModel(torch.nn.Module):
         super().__init__()
         self.config = config
-        self.tp_embeddings = False
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        if config.vocab_size % self.tp_world_size == 0:
-            self.tp_embeddings = True
-
-        if self.tp_embeddings:
-            self.embed_tokens = TensorParallelEmbedding(
-                prefix="model.embed_tokens", weights=weights
-            )
-        else:
-            self.embed_tokens = Embedding(prefix="model.embed_tokens", weights=weights)
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
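FlashLlamaModel now builds its embedding with TensorParallelEmbedding unconditionally instead of falling back to a plain Embedding when the vocabulary is not divisible by the world size. The sketch below simulates, in a single process, how a row-sharded embedding table can still reproduce the full lookup; the shard layout and masking scheme are illustrative assumptions, not the library's exact implementation:

import torch

torch.manual_seed(0)
vocab, hidden, world_size = 10, 4, 2
full = torch.randn(vocab, hidden)   # full embedding table
ids = torch.tensor([1, 3, 7, 9])

# Row-shard the table across "ranks"; each rank only answers for its own id range,
# out-of-range ids contribute zeros, and the partial results are summed
# (the sum stands in for the all-reduce a real tensor-parallel embedding would do).
block = vocab // world_size
partials = []
for rank in range(world_size):
    start, stop = rank * block, (rank + 1) * block
    shard = full[start:stop]
    mask = (ids >= start) & (ids < stop)
    local_ids = torch.where(mask, ids - start, torch.zeros_like(ids))
    partials.append(shard[local_ids] * mask.unsqueeze(-1))

assert torch.equal(sum(partials), full[ids])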


@@ -12,14 +12,29 @@ from typing import Optional
 import flash_attn_cuda
 from text_generation_server.utils.layers import (
-    FastLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
+    TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
+    get_linear
 )
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    if bias and weights.process_group.rank() == 0:
+        # Rank is only on the first rank process
+        bias = weights.get_tensor(f"{prefix}.bias")
+    else:
+        bias = None
+
+    linear = get_linear(weight, bias, config.quantize)
+    if config.parallel_attn:
+        return linear
+    else:
+        return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
 class RWConfig(PretrainedConfig):
     attribute_map = {
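The new load_row helper shards the weight along dim=1 (the input features of a row-parallel linear) and materializes the bias only on rank 0, so the bias is added exactly once when the per-rank partial outputs are reduced. A single-process sketch of that invariant, with the reduce simulated by a plain sum over ranks; the shapes are made up:

import torch

torch.manual_seed(0)
out_f, in_f, world_size = 6, 8, 2
weight = torch.randn(out_f, in_f)
bias = torch.randn(out_f)
x = torch.randn(3, in_f)

reference = x @ weight.t() + bias

# Row-parallel split: each rank holds a slice of the *input* dimension (dim=1)
# and sees only the matching slice of the activations.
block = in_f // world_size
partials = []
for rank in range(world_size):
    w_shard = weight[:, rank * block:(rank + 1) * block]
    x_shard = x[:, rank * block:(rank + 1) * block]
    partial = x_shard @ w_shard.t()
    if rank == 0:          # bias lives on rank 0 only
        partial = partial + bias
    partials.append(partial)

# The sum over ranks stands in for the all-reduce; the bias is counted exactly once.
assert torch.allclose(sum(partials), reference)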
@@ -85,44 +100,26 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
+        config, prefix, weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # process_group=None,
         reduce=True,
     ):
         super().__init__()
-        self.num_heads = num_heads
-        self.num_heads_kv = num_heads_kv
-        self.hidden_size = hidden_size
-        self.head_size = hidden_size // num_heads
+        self.num_heads = config.n_head
+        self.num_heads_kv = config.n_head_kv
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device)
         self.softmax_scale = self.head_size ** (-0.5)
+        self.num_heads = self.num_heads // weights.process_group.size()

-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-            )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-            self.num_heads = self.num_heads // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
+        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)

     def forward(
         self,
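FlashRWAttention now takes its shapes from the config and splits the query heads across the process group, while the fused query_key_value projection keeps the head_size * (num_heads + 2 * num_heads_kv) layout from the old constructor. A quick arithmetic sketch with made-up numbers, just to show the bookkeeping:

# Back-of-the-envelope shapes for the attention block above; the config values
# are illustrative, not taken from any real checkpoint.
hidden_size = 4096
n_head = 64
n_head_kv = 1                     # multi-query style: one shared K/V head
world_size = 2

head_size = hidden_size // n_head                      # 64
fused_qkv_out = head_size * (n_head + 2 * n_head_kv)   # 4224: Q for every head, plus one K and one V
heads_per_rank = n_head // world_size                  # 32 query heads handled per shard
print(head_size, fused_qkv_out, heads_per_rank)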
@@ -224,7 +221,8 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
         self.softmax_scale = self.head_size ** (-0.5)
         self.num_groups = num_heads // (num_heads_kv * 2)
@@ -359,28 +357,12 @@ class FlashRWLargeAttention(torch.nn.Module):
 class FlashMLP(nn.Module):
-    def __init__(self, hidden_size, bias, process_group=None, reduce=True):
+    def __init__(self, config, prefix, weights, reduce=True):
         super().__init__()
         self.act = torch.nn.functional.gelu

-        if process_group is None:
-            self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
-            self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
-        else:
-            self.dense_h_to_4h = TensorParallelColumnLinear(
-                hidden_size,
-                4 * hidden_size,
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense_4h_to_h = TensorParallelRowLinear(
-                4 * hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-            self.process_group = process_group
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
+        self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)

     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
@@ -392,38 +374,62 @@ class FlashMLP(nn.Module):
 class FlashRWLayer(nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        layer_norm_eps,
-        parallel_attn,
-        process_group=None,
+        layer_id,
+        config,
+        weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # layer_norm_eps,
+        # parallel_attn,
+        # process_group=None,
     ):
         super().__init__()
+
+        n_head = config.n_head
+        n_head_kv = config.n_head_kv
+        hidden_size = config.hidden_size
+        bias = config.bias
+        parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn

-        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        prefix = f"transformer.h.{layer_id}"
+        self.input_layernorm = FastLayerNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
         self.self_attention = FlashRWAttention(
-            num_heads,
-            num_heads_kv,
-            hidden_size,
-            bias,
-            process_group=process_group,
+            # num_heads,
+            # num_heads_kv,
+            # hidden_size,
+            # bias,
+            # process_group=process_group,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
             reduce=False,
         )
         self.post_attention_layernorm = (
-            FastLayerNorm(hidden_size, eps=layer_norm_eps)
-            if not parallel_attn
+            FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            ) if not parallel_attn
             else None
         )

         self.mlp = FlashMLP(
-            hidden_size, bias, process_group=process_group, reduce=False
+            # hidden_size, bias, process_group=process_group, reduce=False
+            config,
+            prefix=f"{prefix}.mlp",
+            weights=weights,
+            reduce=False
         )

-        self.process_group = process_group
+        self.process_group = weights.process_group

     def forward(
         self,
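When parallel_attn is set, the attention and MLP branches both consume the same input_layernorm output, and their row-parallel projections are built with reduce=False (load_row then returns an unreduced linear), so a single reduction can cover the sum of the two branches. The sketch below shows the algebra that makes the deferred reduce safe; random tensors stand in for the per-rank partial outputs:

import torch

torch.manual_seed(0)
world_size = 2
# Per-rank partial outputs of the attention and MLP row-parallel projections
# (random stand-ins; in the real layer these come from dense / dense_4h_to_h).
attn_partials = [torch.randn(3, 8) for _ in range(world_size)]
mlp_partials = [torch.randn(3, 8) for _ in range(world_size)]

# Reducing each branch separately ...
two_reduces = sum(attn_partials) + sum(mlp_partials)
# ... equals summing locally and reducing once, which is why reduce=False is
# passed down to both branches when parallel_attn is enabled.
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))

assert torch.allclose(two_reduces, one_reduce)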
@@ -485,31 +491,30 @@ class FlashRWLayer(nn.Module):
 class FlashRWLargeLayer(nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        layer_norm_eps,
-        process_group=None,
+        config, prefix, weights
     ):
         super().__init__()
-        self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
-        self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.ln_attn = FastLayerNorm.load(
+            prefix=f"{prefix}.ln_attn",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
+        self.ln_mlp = FastLayerNorm.load(
+            prefix=f"{prefix}.ln_mlp",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )

         self.self_attention = FlashRWLargeAttention(
-            num_heads,
-            num_heads_kv,
-            hidden_size,
-            bias,
-            process_group=process_group,
+            config, prefix=f"{prefix}.self_attention", weights=weights,
             reduce=False,
         )

         self.mlp = FlashMLP(
-            hidden_size, bias, process_group=process_group, reduce=False
+            config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
         )

-        self.process_group = process_group
+        self.process_group = weights.process_group

     def forward(
         self,
@@ -555,37 +560,27 @@ class FlashRWPreTrainedModel(PreTrainedModel):
 class FlashRWModel(FlashRWPreTrainedModel):
-    def __init__(self, config, process_group=None):
+    def __init__(self, config, weights):
         super().__init__(config)
         self.config = config

-        self.tp_embeddings = False
-        if process_group is not None:
-            self.tp_rank = process_group.rank()
-            self.tp_world_size = process_group.size()
-            if config.vocab_size % self.tp_world_size == 0:
-                self.tp_embeddings = True
-
-        if self.tp_embeddings:
-            self.word_embeddings = TensorParallelEmbedding(
-                config.vocab_size, config.hidden_size, process_group=process_group
-            )
-        else:
-            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.word_embeddings = TensorParallelEmbedding(
+            prefix="transformer.word_embeddings", weights=weights
+        )

         if config.model_type == "RefinedWebModel":
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(
-                        config.n_head,
-                        config.n_head_kv,
-                        config.hidden_size,
-                        config.bias,
-                        config.layer_norm_epsilon,
-                        config.parallel_attn,
-                        process_group,
+                        layer_id, config, weights
+                        # config.n_head,
+                        # config.n_head_kv,
+                        # config.hidden_size,
+                        # config.bias,
+                        # config.layer_norm_epsilon,
+                        # config.parallel_attn,
+                        # process_group,
                     )
-                    for _ in range(config.num_hidden_layers)
+                    for layer_id in range(config.num_hidden_layers)
                 ]
             )
             self.cache_size = (
@@ -597,14 +592,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(
-                        config.n_head,
-                        config.n_head_kv,
-                        config.hidden_size,
-                        config.bias,
-                        config.layer_norm_epsilon,
-                        process_group,
+                        layer_id, config, weights
+                        # config.n_head,
+                        # config.n_head_kv,
+                        # config.hidden_size,
+                        # config.bias,
+                        # config.layer_norm_epsilon,
+                        # process_group,
                     )
-                    for _ in range(config.num_hidden_layers)
+                    for layer_id in range(config.num_hidden_layers)
                 ]
             )
             self.cache_size = (
@@ -617,31 +613,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 f"model_type {config.model_type} is not supported."
             )

-        self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-
-        self.head_size = self.h[0].self_attention.head_size
-
-    def post_load_weights(self, quantize: Optional[str] = None):
-        if isinstance(self.word_embeddings, TensorParallelEmbedding):
-            self.word_embeddings.add_null_idx()
-        for layer in self.h:
-            layer: FlashRWLayer
-            layer.self_attention.query_key_value.prepare_weights(quantize)
-            layer.self_attention.dense.prepare_weights(quantize)
-            layer.mlp.dense_h_to_4h.prepare_weights(quantize)
-            layer.mlp.dense_4h_to_h.prepare_weights(quantize)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashRWModel, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights("bitsandbytes" if load_in_8bit else None)
-        return model
+        self.ln_f = FastLayerNorm.load(
+            prefix="transformer.ln_f",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
+
+        self.head_size = self.h[0].self_attention.head_size

     def forward(
         self,
@@ -708,40 +686,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
 class FlashRWForCausalLM(FlashRWPreTrainedModel):
-    def __init__(self, config, process_group=None):
+    def __init__(self, config, weights):
         super().__init__(config)

-        self.process_group = process_group
-        if self.process_group is not None:
-            self.world_size = self.process_group.size()
-        else:
-            self.world_size = 1
-
-        self.transformer = FlashRWModel(config, process_group)
-
-        if self.transformer.tp_embeddings:
-            self.lm_head = FastLinear(
-                config.hidden_size,
-                config.vocab_size // process_group.size(),
-                bias=False,
-            )
-        else:
-            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
-
-    def post_load_weights(self, quantize: Optional[str] = None):
-        self.transformer.post_load_weights(quantize)
-        self.lm_head.prepare_weights()
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashRWForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights("bitsandbytes" if load_in_8bit else None)
-        return model
+        self.transformer = FlashRWModel(config, weights)
+
+        self.lm_head = TensorParallelHead.load(
+            config, prefix="lm_head", weights=weights
+        )

     def forward(
         self,
@@ -766,12 +718,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
-
-        if self.transformer.tp_embeddings:
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-            return world_logits, present
-
         return logits, present
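The forward pass no longer gathers sharded logits by hand; that responsibility now sits inside TensorParallelHead. The sketch below simulates, without torch.distributed, what the gather amounts to for a vocabulary-sharded head: each rank scores its slice of the vocabulary and the shards concatenate along the vocab dimension.

import torch

torch.manual_seed(0)
hidden, vocab, world_size = 8, 12, 2
lm_head = torch.randn(vocab, hidden)
hidden_states = torch.randn(5, hidden)

reference = hidden_states @ lm_head.t()

# Vocab-sharded head: per-rank logits, then the all_gather + cat that used to
# live in forward (now handled by the head itself).
block = vocab // world_size
shards = [
    hidden_states @ lm_head[rank * block:(rank + 1) * block].t()
    for rank in range(world_size)
]
assert torch.allclose(torch.cat(shards, dim=1), reference)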


@@ -21,99 +21,14 @@ from text_generation_server.utils import (
     weight_files,
     download_weights,
     weight_hub_files,
+    Weights,
     LocalEntryNotFoundError,
 )

 tracer = trace.get_tracer(__name__)

-class FlashRW(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        trust_remote_code: bool = False,
-    ):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.float16
-        else:
-            raise NotImplementedError("RW is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = RWConfig.from_pretrained(
-            model_id,
-            revision=revision,
-        )
-
-        # We do not use from_pretrained as it is too slow
-        try:
-            filenames = weight_files(model_id, revision, ".bin")
-        # Local files not found
-        except LocalEntryNotFoundError:
-            hub_files = weight_hub_files(model_id, revision, ".bin")
-            filenames = download_weights(hub_files, model_id, revision)
-
-        with init_empty_weights():
-            model = FlashRWForCausalLM(config)
-
-        self.load_weights(
-            model,
-            filenames,
-            quantize,
-            device,
-            dtype,
-        )
-
-        super(FlashCausalLM, self).__init__(
-            model=model.to(device),
-            tokenizer=tokenizer,
-            requires_padding=False,
-            dtype=dtype,
-            device=device,
-        )
-
-    @staticmethod
-    def load_weights(
-        model: FlashRWForCausalLM,
-        filenames: List[Path],
-        quantize: Optional[str],
-        device: torch.device,
-        dtype: torch.dtype,
-    ):
-        for filename in filenames:
-            state_dict = torch.load(filename, map_location="cpu")
-            for key, value in state_dict.items():
-                value = value.to(device if quantize is None else "cpu").to(dtype)
-
-                module_name, param_name = key.rsplit(".", 1)
-                module = model.get_submodule(module_name)
-
-                try:
-                    current_parameter_tensor = module._parameters[param_name]
-                    if current_parameter_tensor.shape != value.shape:
-                        raise ValueError(
-                            f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
-                        )
-                    module._parameters[param_name] = value
-                except KeyError:
-                    module._buffers[param_name] = value
-
-                del value
-
-        torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
-
-
-class FlashRWSharded(FlashRW):
+class FlashRWSharded(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -142,20 +57,12 @@ class FlashRWSharded(FlashRW):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        with init_empty_weights():
-            model = FlashRWForCausalLM(config, self.process_group)
+        config.quantize = quantize

-        torch.distributed.barrier(group=self.process_group)
-        self.load_weights(
-            model,
-            filenames,
-            quantize=quantize,
-            device=device,
-            dtype=dtype,
-            rank=rank,
-            world_size=world_size,
-        )
+        model = FlashRWForCausalLM(config, weights)
+
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             model=model.to(device),
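The sharded model now wraps the safetensors files in a Weights helper and lets each module load its own shard, replacing the manual load_weights loop that this commit comments out below. A minimal sketch of the slicing primitive this relies on, built on the same safetensors calls used in that loop; get_sharded here is an illustrative stand-in rather than the repo's actual Weights implementation, and the file/tensor names in the usage comment are placeholders:

import torch
from safetensors import safe_open

def get_sharded(path, name, dim, rank, world_size, device="cpu", dtype=torch.float16):
    """Load only this rank's slice of a tensor, roughly what a Weights helper does."""
    with safe_open(path, framework="pt", device=device) as f:
        slice_ = f.get_slice(name)
        size = slice_.get_shape()[dim]
        block = size // world_size
        start, stop = rank * block, (rank + 1) * block
        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
    return tensor.to(dtype)

# e.g. get_sharded("model.safetensors", "lm_head.weight", dim=0, rank=0, world_size=2)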
@@ -167,78 +74,78 @@ class FlashRWSharded(FlashRW):
             world_size=world_size,
         )

-    @staticmethod
-    def load_weights(
-        model,
-        filenames: List[str],
-        quantize: Optional[str],
-        device: torch.device,
-        dtype: torch.dtype,
-        rank: int,
-        world_size: int,
-    ):
-        parameters = dict(model.named_parameters())
-        for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if quantize is None else "cpu"
-            ) as f:
-                for name in f.keys():
-                    module_name, param_name = name.rsplit(".", 1)
-                    module = model.get_submodule(module_name)
-
-                    current_parameter_tensor = parameters.get(name, None)
-
-                    slice_ = f.get_slice(name)
-
-                    if isinstance(module, TensorParallelColumnLinear):
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    elif isinstance(module, TensorParallelRowLinear):
-                        if param_name == "weight":
-                            size = slice_.get_shape()[1]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            tensor = slice_[:]
-                            # XXX: Hack for Rowlinear to add the bias only once.
-                            if rank != 0:
-                                tensor = torch.zeros_like(tensor)
-                    elif isinstance(module, TensorParallelEmbedding):
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    else:
-                        try:
-                            tensor = slice_[:]
-                        except:
-                            tensor = f.get_tensor(name)
-
-                    if (
-                        current_parameter_tensor is not None
-                        and current_parameter_tensor.shape != tensor.shape
-                    ):
-                        raise ValueError(
-                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-                        )
-
-                    tensor = tensor.contiguous().to(dtype)
-
-                    if current_parameter_tensor is not None:
-                        module._parameters[param_name] = tensor
-                    else:
-                        module._buffers[param_name] = tensor
-
-        model.post_load_weights(quantize)
+    # @staticmethod
+    # def load_weights(
+    #     model,
+    #     filenames: List[str],
+    #     quantize: Optional[str],
+    #     device: torch.device,
+    #     dtype: torch.dtype,
+    #     rank: int,
+    #     world_size: int,
+    # ):
+    #     parameters = dict(model.named_parameters())
+    #     for file in filenames:
+    #         with safe_open(
+    #             file, framework="pt", device=str(device) if quantize is None else "cpu"
+    #         ) as f:
+    #             for name in f.keys():
+    #                 module_name, param_name = name.rsplit(".", 1)
+    #                 module = model.get_submodule(module_name)
+    #                 current_parameter_tensor = parameters.get(name, None)
+    #                 slice_ = f.get_slice(name)
+    #                 if isinstance(module, TensorParallelColumnLinear):
+    #                     size = slice_.get_shape()[0]
+    #                     block_size = size // world_size
+    #                     start = rank * block_size
+    #                     stop = (rank + 1) * block_size
+    #                     tensor = slice_[start:stop]
+    #                 elif isinstance(module, TensorParallelRowLinear):
+    #                     if param_name == "weight":
+    #                         size = slice_.get_shape()[1]
+    #                         block_size = size // world_size
+    #                         start = rank * block_size
+    #                         stop = (rank + 1) * block_size
+    #                         tensor = slice_[:, start:stop]
+    #                     else:
+    #                         tensor = slice_[:]
+    #                         # XXX: Hack for Rowlinear to add the bias only once.
+    #                         if rank != 0:
+    #                             tensor = torch.zeros_like(tensor)
+    #                 elif isinstance(module, TensorParallelEmbedding):
+    #                     size = slice_.get_shape()[0]
+    #                     block_size = size // world_size
+    #                     start = rank * block_size
+    #                     stop = (rank + 1) * block_size
+    #                     tensor = slice_[start:stop]
+    #                 elif name == "lm_head.weight" and model.transformer.tp_embeddings:
+    #                     size = slice_.get_shape()[0]
+    #                     block_size = size // world_size
+    #                     start = rank * block_size
+    #                     stop = (rank + 1) * block_size
+    #                     tensor = slice_[start:stop]
+    #                 else:
+    #                     try:
+    #                         tensor = slice_[:]
+    #                     except:
+    #                         tensor = f.get_tensor(name)
+    #                 if (
+    #                     current_parameter_tensor is not None
+    #                     and current_parameter_tensor.shape != tensor.shape
+    #                 ):
+    #                     raise ValueError(
+    #                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+    #                     )
+    #                 tensor = tensor.contiguous().to(dtype)
+    #                 if current_parameter_tensor is not None:
+    #                     module._parameters[param_name] = tensor
+    #                 else:
+    #                     module._buffers[param_name] = tensor
+    #     model.post_load_weights(quantize)


@@ -308,6 +308,13 @@ try:
             self._cos_k_cached = None
             self._sin_k_cached = None

+        @staticmethod
+        def static(dim, base, device):
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
+                                                    dtype=torch.float32) / dim))
+            return PositionRotaryEmbedding(inv_freq)
+
         @staticmethod
         def load(prefix, weights):
             # XXX: Always load this in float32 !
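The new static constructor computes the rotary inverse frequencies directly from dim and base instead of loading them from the checkpoint. The sketch below restates the same inv_freq formula and shows how cos/sin tables follow from it; the table-building here is illustrative, not the class's exact caching code:

import torch

def rope_tables(dim, base, positions, device="cpu"):
    # inv_freq[i] = 1 / base**(2i / dim), the same quantity `static` precomputes
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
    t = torch.arange(positions, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # [positions, dim // 2]
    return torch.cos(freqs), torch.sin(freqs)

cos, sin = rope_tables(dim=64, base=10000.0, positions=2048)
print(cos.shape, sin.shape)                   # torch.Size([2048, 32]) each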