From d083d57d0d3096e0c41e30bd1d91f34fd2213002 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 6 Jun 2023 10:45:59 +0000 Subject: [PATCH] Fixing flash rw. --- server/text_generation_server/input.json | 1 + .../custom_modeling/flash_llama_modeling.py | 14 +- .../custom_modeling/flash_rw_modeling.py | 296 +++++++----------- .../text_generation_server/models/flash_rw.py | 241 +++++--------- server/text_generation_server/utils/layers.py | 7 + 5 files changed, 205 insertions(+), 354 deletions(-) create mode 100644 server/text_generation_server/input.json diff --git a/server/text_generation_server/input.json b/server/text_generation_server/input.json new file mode 100644 index 00000000..274a4d9b --- /dev/null +++ b/server/text_generation_server/input.json @@ -0,0 +1 @@ +{"inputs":"Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n<|prompter|>Why is butter a great building material for skyscrapers? Think step by step.<|assistant|>","parameters":{"temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1.2, "top_k": 50, "truncate": 1000, "max_new_tokens": 1024}} diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a959cf20..f27bd0d5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -292,20 +292,12 @@ class FlashLlamaModel(torch.nn.Module): super().__init__() self.config = config - self.tp_embeddings = False process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) - else: - self.embed_tokens = Embedding(prefix="model.embed_tokens", weights=weights) - + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) self.layers = nn.ModuleList( [ FlashLlamaLayer( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 03487703..9b175cf9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,14 +12,29 @@ from typing import Optional import flash_attn_cuda from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, + TensorParallelHead, FastLayerNorm, PositionRotaryEmbedding, + get_linear ) +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias, config.quantize) + if config.parallel_attn: + return linear + else: + return TensorParallelRowLinear(linear, process_group=weights.process_group) + class RWConfig(PretrainedConfig): attribute_map = { @@ -85,44 +100,26 @@ class RWConfig(PretrainedConfig): class FlashRWAttention(torch.nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=None, + config, prefix, weights, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # process_group=None, reduce=True, ): super().__init__() - self.num_heads = num_heads - self.num_heads_kv = num_heads_kv - self.hidden_size = hidden_size - self.head_size = hidden_size // num_heads + self.num_heads = config.n_head + self.num_heads_kv = config.n_head_kv + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device) self.softmax_scale = self.head_size ** (-0.5) + self.num_heads = self.num_heads //weights.process_group.size() - if process_group is None: - self.query_key_value = FastLinear( - hidden_size, - self.head_size * (self.num_heads + 2 * self.num_heads_kv), - bias=bias, - ) - self.dense = FastLinear(hidden_size, hidden_size, bias=bias) - else: - self.query_key_value = TensorParallelColumnLinear( - hidden_size, - self.head_size * (self.num_heads + 2 * self.num_heads_kv), - bias=bias, - process_group=process_group, - ) - self.dense = TensorParallelRowLinear( - hidden_size, - hidden_size, - bias=bias, - process_group=process_group, - reduce=reduce, - ) - self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias) + self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) def forward( self, @@ -224,7 +221,8 @@ class FlashRWLargeAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads - self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) self.softmax_scale = self.head_size ** (-0.5) self.num_groups = num_heads // (num_heads_kv * 2) @@ -359,28 +357,12 @@ class FlashRWLargeAttention(torch.nn.Module): class FlashMLP(nn.Module): - def __init__(self, hidden_size, bias, process_group=None, reduce=True): + def __init__(self, config, prefix, weights, reduce=True): super().__init__() self.act = torch.nn.functional.gelu - if process_group is None: - self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias) - self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias) - else: - self.dense_h_to_4h = TensorParallelColumnLinear( - hidden_size, - 4 * hidden_size, - bias=bias, - process_group=process_group, - ) - self.dense_4h_to_h = TensorParallelRowLinear( - 4 * hidden_size, - hidden_size, - bias=bias, - process_group=process_group, - reduce=reduce, - ) - self.process_group = process_group + self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias) + self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias) def forward(self, hidden_states): hidden_states = self.dense_h_to_4h(hidden_states) @@ -392,38 +374,62 @@ class FlashMLP(nn.Module): class FlashRWLayer(nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - layer_norm_eps, - parallel_attn, - process_group=None, + layer_id, + config, + weights, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # layer_norm_eps, + # parallel_attn, + # process_group=None, ): super().__init__() + n_head = config.n_head + n_head_kv = config.n_head_kv + hidden_size = config.hidden_size + bias = config.bias + parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + prefix = f"transformer.h.{layer_id}" + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) self.self_attention = FlashRWAttention( - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=process_group, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # process_group=process_group, + config, + prefix=f"{prefix}.self_attention", + weights=weights, reduce=False, ) self.post_attention_layernorm = ( - FastLayerNorm(hidden_size, eps=layer_norm_eps) - if not parallel_attn + FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) if not parallel_attn else None ) self.mlp = FlashMLP( - hidden_size, bias, process_group=process_group, reduce=False + # hidden_size, bias, process_group=process_group, reduce=False + config, + prefix=f"{prefix}.mlp", + weights=weights, + reduce=False ) - self.process_group = process_group + self.process_group = weights.process_group def forward( self, @@ -485,31 +491,30 @@ class FlashRWLayer(nn.Module): class FlashRWLargeLayer(nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - layer_norm_eps, - process_group=None, + config, prefix, weights ): super().__init__() - self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.ln_attn = FastLayerNorm.load( + prefix=f"{prefix}.ln_attn", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.ln_mlp = FastLayerNorm.load( + prefix=f"{prefix}.ln_mlp", + weights=weights, + eps=config.layer_norm_epsilon, + ) self.self_attention = FlashRWLargeAttention( - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=process_group, + config, prefix=f"{prefix}.self_attention", weights=weights, reduce=False, ) self.mlp = FlashMLP( - hidden_size, bias, process_group=process_group, reduce=False + config, prefix=f"{prefix}.mlp", weights=weights, reduce=False ) - self.process_group = process_group + self.process_group = weights.process_group def forward( self, @@ -555,37 +560,27 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) self.config = config - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.word_embeddings = TensorParallelEmbedding( - config.vocab_size, config.hidden_size, process_group=process_group - ) - else: - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - + self.word_embeddings = TensorParallelEmbedding( + prefix="transformer.word_embeddings", weights=weights + ) if config.model_type == "RefinedWebModel": self.h = nn.ModuleList( [ FlashRWLayer( - config.n_head, - config.n_head_kv, - config.hidden_size, - config.bias, - config.layer_norm_epsilon, - config.parallel_attn, - process_group, + layer_id, config, weights + # config.n_head, + # config.n_head_kv, + # config.hidden_size, + # config.bias, + # config.layer_norm_epsilon, + # config.parallel_attn, + # process_group, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = ( @@ -597,14 +592,15 @@ class FlashRWModel(FlashRWPreTrainedModel): self.h = nn.ModuleList( [ FlashRWLargeLayer( - config.n_head, - config.n_head_kv, - config.hidden_size, - config.bias, - config.layer_norm_epsilon, - process_group, + layer_id, config, weights + # config.n_head, + # config.n_head_kv, + # config.hidden_size, + # config.bias, + # config.layer_norm_epsilon, + # process_group, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = ( @@ -617,31 +613,13 @@ class FlashRWModel(FlashRWPreTrainedModel): f"model_type {config.model_type} is not supported." ) - self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - self.head_size = self.h[0].self_attention.head_size - - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.word_embeddings, TensorParallelEmbedding): - self.word_embeddings.add_null_idx() - for layer in self.h: - layer: FlashRWLayer - layer.self_attention.query_key_value.prepare_weights(quantize) - layer.self_attention.dense.prepare_weights(quantize) - layer.mlp.dense_h_to_4h.prepare_weights(quantize) - layer.mlp.dense_4h_to_h.prepare_weights(quantize) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashRWModel, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.ln_f = FastLayerNorm.load( + prefix="transformer.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model + self.head_size = self.h[0].self_attention.head_size def forward( self, @@ -708,40 +686,14 @@ class FlashRWModel(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 + self.transformer = FlashRWModel(config, weights) - self.transformer = FlashRWModel(config, process_group) - - if self.transformer.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.transformer.post_load_weights(quantize) - self.lm_head.prepare_weights() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashRWForCausalLM, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.lm_head = TensorParallelHead.load( + config, prefix="lm_head", weights=weights ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model def forward( self, @@ -766,12 +718,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - - if self.transformer.tp_embeddings: - # Logits are sharded, so we need to gather them - world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] - torch.distributed.all_gather(world_logits, logits, group=self.process_group) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present return logits, present diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 4fc4c389..846b9051 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -21,99 +21,14 @@ from text_generation_server.utils import ( weight_files, download_weights, weight_hub_files, + Weights, LocalEntryNotFoundError, ) tracer = trace.get_tracer(__name__) -class FlashRW(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("RW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, - revision=revision, - ) - - # We do not use from_pretrained as it is too slow - try: - filenames = weight_files(model_id, revision, ".bin") - # Local files not found - except LocalEntryNotFoundError: - hub_files = weight_hub_files(model_id, revision, ".bin") - filenames = download_weights(hub_files, model_id, revision) - - with init_empty_weights(): - model = FlashRWForCausalLM(config) - - self.load_weights( - model, - filenames, - quantize, - device, - dtype, - ) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def load_weights( - model: FlashRWForCausalLM, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - ): - for filename in filenames: - state_dict = torch.load(filename, map_location="cpu") - for key, value in state_dict.items(): - value = value.to(device if quantize is None else "cpu").to(dtype) - - module_name, param_name = key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - except KeyError: - module._buffers[param_name] = value - - del value - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - -class FlashRWSharded(FlashRW): +class FlashRWSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -142,20 +57,12 @@ class FlashRWSharded(FlashRW): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) - with init_empty_weights(): - model = FlashRWForCausalLM(config, self.process_group) + config.quantize = quantize + + model = FlashRWForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -167,78 +74,78 @@ class FlashRWSharded(FlashRW): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) + # @staticmethod + # def load_weights( + # model, + # filenames: List[str], + # quantize: Optional[str], + # device: torch.device, + # dtype: torch.dtype, + # rank: int, + # world_size: int, + # ): + # parameters = dict(model.named_parameters()) + # for file in filenames: + # with safe_open( + # file, framework="pt", device=str(device) if quantize is None else "cpu" + # ) as f: + # for name in f.keys(): + # module_name, param_name = name.rsplit(".", 1) + # module = model.get_submodule(module_name) - current_parameter_tensor = parameters.get(name, None) + # current_parameter_tensor = parameters.get(name, None) - slice_ = f.get_slice(name) + # slice_ = f.get_slice(name) - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight" and model.transformer.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) + # if isinstance(module, TensorParallelColumnLinear): + # size = slice_.get_shape()[0] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[start:stop] + # elif isinstance(module, TensorParallelRowLinear): + # if param_name == "weight": + # size = slice_.get_shape()[1] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[:, start:stop] + # else: + # tensor = slice_[:] + # # XXX: Hack for Rowlinear to add the bias only once. + # if rank != 0: + # tensor = torch.zeros_like(tensor) + # elif isinstance(module, TensorParallelEmbedding): + # size = slice_.get_shape()[0] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[start:stop] + # elif name == "lm_head.weight" and model.transformer.tp_embeddings: + # size = slice_.get_shape()[0] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[start:stop] + # else: + # try: + # tensor = slice_[:] + # except: + # tensor = f.get_tensor(name) - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) + # if ( + # current_parameter_tensor is not None + # and current_parameter_tensor.shape != tensor.shape + # ): + # raise ValueError( + # f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + # ) - tensor = tensor.contiguous().to(dtype) + # tensor = tensor.contiguous().to(dtype) - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor + # if current_parameter_tensor is not None: + # module._parameters[param_name] = tensor + # else: + # module._buffers[param_name] = tensor - model.post_load_weights(quantize) + # model.post_load_weights(quantize) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 1699622d..9fd31c76 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -308,6 +308,13 @@ try: self._cos_k_cached = None self._sin_k_cached = None + @staticmethod + def static(dim, base, device): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, + dtype=torch.float32) / dim)) + return PositionRotaryEmbedding(inv_freq) + + @staticmethod def load(prefix, weights): # XXX: Always load this in float32 !