From 24579c45de0ca2e5c4efddb795b6366beb0c0d8c Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 22 Mar 2023 11:46:09 +0100 Subject: [PATCH] wip --- .../text_generation_server/models/__init__.py | 5 +- .../models/causal_lm.py | 1 - .../models/flash_neox.py | 516 +++++++++++++++ .../models/flash_neox_modeling.py | 589 ++++++++++++++++++ 4 files changed, 1108 insertions(+), 3 deletions(-) create mode 100644 server/text_generation_server/models/flash_neox.py create mode 100644 server/text_generation_server/models/flash_neox_modeling.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3e2f5c66..29d67c7b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -11,6 +11,7 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded +from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded __all__ = [ "Model", @@ -59,9 +60,9 @@ def get_model( if config.model_type == "gpt_neox": if sharded: - return GPTNeoxSharded(model_id, revision, quantize=quantize) + return FlashNeoXSharded(model_id, revision, quantize=quantize) else: - return CausalLM(model_id, revision, quantize=quantize) + return FlashNeoX(model_id, revision, quantize=quantize) if config.model_type == "t5": if sharded: diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 88ea6c75..c2ad0587 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -64,7 +64,6 @@ class CausalLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] - input_lengths = [] # Parse batch padding_right_offset = 0 diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py new file mode 100644 index 00000000..3adf8062 --- /dev/null +++ b/server/text_generation_server/models/flash_neox.py @@ -0,0 +1,516 @@ +import torch +import torch.distributed + +from accelerate import init_empty_weights +from dataclasses import dataclass +from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear +from opentelemetry import trace +from safetensors import safe_open +from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig +from typing import Optional, Tuple, List, Type, Union + +from text_generation_server.models import Model +from text_generation_server.models.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, + TensorParallelEmbedding, +) +from text_generation_server.models.types import ( + Batch, + PrefillTokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + NextTokenChooser, + StoppingCriteria, + Sampling, + initialize_torch_distributed, + weight_files, +) + +tracer = trace.get_tracer(__name__) + + +@dataclass +class FlashNeoXBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + + # Decoder values + input_ids: torch.Tensor + position_ids: torch.Tensor + # cumulative sequence lengths + cu_seqlens: torch.Tensor + max_seqlen: torch.Tensor + past_key_values: Optional[torch.Tensor] + + # All tokens + all_input_ids: 
List[torch.Tensor] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + def to_pb(self) -> generate_pb2.Batch: + return generate_pb2.Batch( + id=self.batch_id, requests=self.requests, size=len(self) + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ) -> "CausalLMBatch": + input_ids = [] + position_ids = [] + cu_seqlens = [0] + max_seqlen = 0 + + next_token_choosers = [] + stopping_criterias = [] + + # Parse batch + for r in pb.requests: + tokenized_input = ( + tokenizer(r.inputs, return_tensors="pt")["input_ids"] + .to(device) + .squeeze(0) + ) + input_ids.append(tokenized_input) + position_ids.append( + torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device) + ) + cu_seqlens.append(len(tokenized_input)) + max_seqlen = max(max_seqlen, len(tokenized_input)) + + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criterias.append( + StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) + ) + + all_input_ids = input_ids + input_ids = torch.concat(input_ids).unsqueeze(1) + position_ids = torch.concat(position_ids) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + max_seqlen = torch.tensor(max_seqlen, dtype=torch.int32, device=device) + + return cls( + batch_id=pb.id, + requests=pb.requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=None, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": + raise NotImplementedError + + + def __len__(self): + return len(self.requests) + + +class FlashNeoX(Model): + def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashNeoX is only available on GPU") + + if quantize: + raise NotImplementedError("FlashNeoX does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + self.model = ( + FlashGPTNeoXForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + ) + .eval() + .cuda() + ) + tokenizer.pad_token_id = ( + self.model.config.pad_token_id + if self.model.config.pad_token_id is not None + else self.model.config.eos_token_id + ) + + super(FlashNeoX, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @property + def batch_type(self) -> Type[FlashNeoXBatch]: + return FlashNeoXBatch + + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: torch.Tensor, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Model Forward + return self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: 
FlashNeoXBatch + ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: + print("pos", batch.position_ids) + print("cu", batch.cu_seqlens) + print("max", batch.max_seqlen) + + out, present = self.forward( + batch.input_ids.squeeze(1), + batch.position_ids, + batch.cu_seqlens, + batch.max_seqlen, + batch.past_key_values, + ) + + # List of indices to cache + next_batch_keep_indices = [] + + # New values for next forward + next_batch_input_ids = [] + next_batch_position_ids = [] + next_batch_cu_seqlens = [0] + next_batch_max_seqlen = 0 + next_batch_past_key_values = [] + next_batch_all_input_ids = [] + + # Results + generations: List[Generation] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Indexing metadata + start_index = batch.cu_seqlens[i] + end_index = batch.cu_seqlens[i + 1] + seq_length = end_index - start_index + + if batch.past_key_values is None: + # Prefill mode + # out is of shape [cumulative_sequence_lengths, vocab_size] + logits = out[start_index:end_index] + else: + # Decode mode + # out is of shape [batch_size, vocab_size] + logits = out[i].unsqueeze(0) + + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)]) + new_input_length = seq_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text = self.decode_token( + next_token_id_squeezed, + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if stop: + # Decode generated tokens + output_text = self.decode(all_input_ids) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + # Keep request in the batch + seq_present = present[:, start_index:end_index] + past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) + next_batch_past_key_values.append(past) + + generated_text = None + next_batch_keep_indices.append(i) + next_batch_input_ids.append(next_token_id) + next_batch_position_ids.append(new_input_length) + next_batch_cu_seqlens.append( + next_batch_cu_seqlens[i] + new_input_length + ) + next_batch_all_input_ids.append(all_input_ids) + next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) + + # Prefill + if stopping_criteria.current_tokens == 1: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + logprobs.gather( + 1, all_input_ids[1:].unsqueeze(1) + ).squeeze(1)[:-1].tolist() + prefill_token_ids = all_input_ids[:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + next_token_id_squeezed.item() in self.all_special_ids, + generated_text, + ) + + 
generations.append(generation) + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generations, None + + # If we finished at least one generation, we need to evict the indices of the generations that finished + # from the values of the next batch + if len(next_batch_keep_indices) != len(batch): + # Apply indices to requests, token_choosers and stopping_criterias that need to be cached + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Create final next batch tensors + device = out.device + next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0) + next_batch_position_ids = torch.tensor( + next_batch_position_ids, dtype=torch.int32, device=device + ) + next_batch_cu_seqlens = torch.tensor( + next_batch_cu_seqlens, dtype=torch.int32, device=device + ) + if len(next_batch_keep_indices) > 1: + next_batch_past_key_values = torch.concat(next_batch_past_key_values) + else: + next_batch_past_key_values = next_batch_past_key_values[0] + + next_batch = FlashNeoXBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + input_ids=next_batch_input_ids, + position_ids=next_batch_position_ids, + cu_seqlens=next_batch_cu_seqlens, + max_seqlen=next_batch_max_seqlen, + past_key_values=next_batch_past_key_values, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + ) + return generations, next_batch + + +class FlashNeoXSharded(FlashNeoX): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashNeoX is only available on GPU") + + if quantize: + raise NotImplementedError("FlashNeoX does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, tp_parallel=True + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = FlashGPTNeoXForCausalLM(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + super(FlashNeoX, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in 
f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, ColumnParallelLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, RowParallelLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: torch.Tensor, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.model.gpt_neox.tp_embeddings: + logits, present = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + # Logits are sharded, so we need to gather them + world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] + torch.distributed.all_gather(world_logits, logits, group=self.process_group) + world_logits = torch.cat(world_logits, dim=1) + + return world_logits, present + # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard + else: + return super(FlashNeoXSharded, self).forward( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py new file mode 100644 index 00000000..06fab145 --- /dev/null +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -0,0 +1,589 @@ +import torch + +from torch import nn + +from transformers.modeling_utils import PreTrainedModel +from transformers.models.gpt_neox import GPTNeoXConfig +from einops import rearrange +from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_qkvpacked_func, + flash_attn_unpadded_kvpacked_func, +) +from flash_attn.ops.fused_dense import ( + FusedDense, + ColumnParallelLinear, + RowParallelLinear, + fused_mlp_func, +) +from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_ +from flash_attn.ops.layer_norm import dropout_add_layer_norm + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + 
process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # Sanity check + if torch.any( + torch.logical_or(0 > input, input >= self.original_num_embeddings) + ): + raise IndexError( + f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}" + ) + + # `0` if input is in the correct interval, else `1` + input_mask = torch.logical_or(self.min_id > input, input >= self.max_id) + # translate for [0, self.max_id - self.min_id[ + input = input - self.min_id + # default all out of bounds values to `0` + input[input_mask] = 0 + out = super().forward(input) + out[input_mask] = 0.0 + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor): + assert self.scale is None + + self._update_cos_sin_cache(qkv, position_ids.max() + 1) + + cos = self._cos_cached[position_ids] + sin = self._sin_cached[position_ids] + + return apply_rotary_emb_qkv_(qkv, cos, sin, None, None) + + +class FlashNeoxAttention(torch.nn.Module): + def __init__( + self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None + ): + super().__init__() + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + rotary_ndims = int(self.head_size * rotary_pct) + self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.query_key_value = FusedDense(hidden_size, 3 * hidden_size) + self.dense = FusedDense(hidden_size, hidden_size) + else: + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + process_group=process_group, + sequence_parallel=False, + ) + self.dense = RowParallelLinear( + hidden_size, + hidden_size, + process_group=process_group, + sequence_parallel=False, + ) + + def forward( + self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill + ): + qkv = self.query_key_value(hidden_states) + qkv = rearrange( + qkv, "... (h three d) -> ... h three d", three=3, d=self.head_size + ).permute(0, 2, 1, 3) + qkv_rot = self.rotary_emb(qkv.unsqueeze(0), position_ids).squeeze(0) + + if prefill: + layer_past[...] 
= qkv_rot[:, 1:] + + # test flash_attn_unpadded_qkvpacked_split_func + attn_output = flash_attn_unpadded_qkvpacked_func( + qkv_rot, cu_seqlens, max_s, 0.0, self.softmax_scale, causal=True + ) + else: + query = qkv_rot[:, 0] + layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:] + + attn_output = flash_attn_unpadded_kvpacked_func( + query, + layer_past, + cu_seqlens_q=torch.arange(len(cu_seqlens), dtype=torch.int32).to( + query.device + ), + max_seqlen_q=torch.tensor(1, dtype=torch.int32).to(query.device), + cu_seqlens_k=cu_seqlens, + max_seqlen_k=max_s, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + causal=False, + ) + + return self.dense(rearrange(attn_output, "... h d -> ... (h d)")) + + +class FlashMLP(nn.Module): + def __init__(self, act, hidden_size, intermediate_size, process_group=None): + super().__init__() + if "gelu" in act: + act = "gelu_approx" + assert act in ["gelu_approx", "relu"] + self.act = act + + if process_group is None: + self.dense_h_to_4h = FusedDense(hidden_size, intermediate_size) + self.dense_4h_to_h = FusedDense(intermediate_size, hidden_size) + else: + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + intermediate_size, + process_group=process_group, + sequence_parallel=False, + ) + self.dense_4h_to_h = RowParallelLinear( + intermediate_size, + hidden_size, + process_group=process_group, + sequence_parallel=False, + ) + self.heuristic = "auto" + self.process_group = process_group + + def forward(self, x): + if self.heuristic == "auto": + if self.act == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + self.heuristic = ( + 0 + if cuda_ver >= (11, 8) + else (1 if x.dtype == torch.float16 else -1) + ) + else: + self.heuristic = 0 + + out = fused_mlp_func( + x, + self.dense_h_to_4h.weight, + self.dense_4h_to_h.weight, + self.dense_h_to_4h.bias, + self.dense_4h_to_h.bias, + activation=self.act, + save_pre_act=self.training, + checkpoint_lvl=0, + heuristic=self.heuristic, + process_group=self.process_group, + sequence_parallel=False, + ) + if self.process_group is not None: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class FlashNeoXLayer(nn.Module): + def __init__( + self, + num_heads, + act, + hidden_size, + intermediate_size, + rotary_pct, + rotary_emb_base, + layer_norm_eps, + use_parallel_residual, + process_group=None, + ): + super().__init__() + self.use_parallel_residual = use_parallel_residual + self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.attention = FlashNeoxAttention( + num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group + ) + self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group) + + def forward( + self, + hidden_states, + residual, + position_ids, + cu_seqlens, + max_s, + layer_past, + prefill, + ): + if self.use_parallel_residual: + ln1_hidden_states = dropout_add_layer_norm( + hidden_states, + residual, + self.input_layernorm.weight, + self.input_layernorm.bias, + 0.0, + self.input_layernorm.eps, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + ) + attn_output = self.attention( + ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill + ) + + ln2_hidden_states = dropout_add_layer_norm( + hidden_states, + residual, + self.post_attention_layernorm.weight, + self.post_attention_layernorm.bias, + 0.0, + self.post_attention_layernorm.eps, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + ) + mlp_output 
= self.mlp(ln2_hidden_states) + return mlp_output + attn_output + hidden_states, None + + else: + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.input_layernorm.weight, + self.input_layernorm.bias, + 0.0, + self.input_layernorm.eps, + rowscale=None, + prenorm=True, + residual_in_fp32=True, + ) + + hidden_states = self.attention( + hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill + ) + + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.post_attention_layernorm.weight, + self.post_attention_layernorm.bias, + 0.0, + self.post_attention_layernorm.eps, + rowscale=None, + prenorm=True, + residual_in_fp32=True, + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashGPTNeoXPreTrainedModel(PreTrainedModel): + config_class = GPTNeoXConfig + base_model_prefix = "gpt_neox" + supports_gradient_checkpointing = False + _no_split_modules = None + + +class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): + def __init__(self, config, process_group=None): + super().__init__(config) + self.config = config + + self.tp_embeddings = False + if process_group is not None: + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True + + if self.tp_embeddings: + self.embed_in = TensorParallelEmbedding( + config.vocab_size, config.hidden_size, process_group=process_group + ) + else: + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = nn.ModuleList( + [ + FlashNeoXLayer( + config.num_attention_heads, + config.hidden_act, + config.hidden_size, + config.intermediate_size, + config.rotary_pct, + config.rotary_emb_base, + config.layer_norm_eps, + config.use_parallel_residual, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].attention.head_size + self.num_heads = self.layers[0].attention.num_heads + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states = self.embed_in(input_ids) + + prefill = False + if past_key_values is None: + past_key_values = hidden_states.new_empty( + ( + len(self.layers), + len(hidden_states), + 2, + self.num_heads, + self.head_size, + ) + ) + prefill = True + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + position_ids, + cu_seqlens, + max_s, + past_key_values[i], + prefill, + ) + + hidden_states = dropout_add_layer_norm( + hidden_states, + residual, + self.final_layer_norm.weight, + self.final_layer_norm.bias, + 0.0, + self.final_layer_norm.eps, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + ) + + return hidden_states, past_key_values + + +class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if config.tp_parallel: + process_group = torch.distributed.distributed_c10d._get_default_group() + else: + process_group = None + + self.gpt_neox = FlashGPTNeoXModel(config, process_group) + + if self.gpt_neox.tp_embeddings: + self.embed_out = FusedDense( + config.hidden_size, + config.vocab_size // process_group.size(), + bias=False, + ) + else: + self.embed_out = FusedDense( + config.hidden_size, config.vocab_size, bias=False + ) + + 
def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states, present = self.gpt_neox( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) + return self.embed_out(hidden_states), present + + +if __name__ == "__main__": + from transformers import AutoTokenizer + from flash_attn.bert_padding import unpad_input + + model = ( + FlashGPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m") + .cuda() + .to(torch.half) + ) + + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", padding_side="left" + ) + tokenizer.pad_token = tokenizer.eos_token + + tokenized_inputs = tokenizer( + ["What is this?\n\nA:\n\nThe answer to the problem?", "hello!"], + padding=True, + return_tensors="pt", + ).to("cuda") + + input_ids, indices, cu_seqlens, max_seqlen = unpad_input( + tokenized_inputs["input_ids"].unsqueeze(-1), tokenized_inputs["attention_mask"] + ) + + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 + position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 0) + + unpad_position_ids = torch.gather(position_ids.view(-1).cuda(), 0, indices) + + gen_input_ids = input_ids.squeeze(1).cuda().clone() + gen_position_ids = unpad_position_ids.clone() + gen_indices = indices.clone() + gen_cu_seqlens = cu_seqlens.clone() + gen_max_seqlen = max_seqlen + + past_key_values = None + + results = [] + with torch.no_grad(): + out, present, _ = model( + gen_input_ids, + gen_position_ids, + gen_cu_seqlens, + gen_max_seqlen, + past_key_values=past_key_values, + ) + + futures = [] + new_gen_cu_seqlens = [0] + new_position_ids = [] + next_token_ids = [] + + for i in range(len(gen_cu_seqlens) - 1): + start_index = gen_cu_seqlens[i] + end_index = gen_cu_seqlens[i + 1] + + seq_logits = out[start_index:end_index] + next_token_id = torch.argmax(seq_logits[-1:], dim=1) + next_token_ids.append(next_token_id) + + sequence_length = end_index - start_index + new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1) + + seq_position_ids = gen_position_ids[start_index:end_index] + new_position_ids.append( + torch.concat([seq_position_ids, seq_position_ids[-1:] + 1]) + ) + + seq_present = present[:, start_index:end_index] + future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) + + futures.append(future) + + past_key_values = torch.concat(futures, dim=1) + new_position_ids = torch.concat(new_position_ids) + new_gen_cu_seqlens = torch.tensor( + new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32 + ) + next_token_ids = torch.concat(next_token_ids) + + gen_max_seqlen += 1 + + gen_input_ids = next_token_ids + gen_position_ids = new_position_ids + gen_cu_seqlens = new_gen_cu_seqlens + + print(tokenizer.batch_decode(gen_input_ids)) + + for _ in range(40): + out, present, _ = model( + gen_input_ids, + gen_position_ids, + gen_cu_seqlens, + gen_max_seqlen, + past_key_values=past_key_values, + ) + + futures = [] + new_gen_cu_seqlens = [0] + new_position_ids = [] + next_token_ids = [] + for i in range(len(gen_cu_seqlens) - 1): + start_index = gen_cu_seqlens[i] + end_index = gen_cu_seqlens[i + 1] + + seq_logits = out[i] + next_token_id = torch.argmax(seq_logits.view(1, -1)[-1:], dim=1) + next_token_ids.append(next_token_id) + + sequence_length = end_index - start_index + new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1) + + seq_position_ids = gen_position_ids[start_index:end_index] + new_position_ids.append( + torch.concat([seq_position_ids, 
seq_position_ids[-1:] + 1]) + ) + + seq_present = present[:, start_index:end_index] + future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) + + futures.append(future) + + past_key_values = torch.concat(futures, dim=1) + new_position_ids = torch.concat(new_position_ids) + new_gen_cu_seqlens = torch.tensor( + new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32 + ) + next_token_ids = torch.concat(next_token_ids) + + gen_max_seqlen += 1 + + gen_input_ids = next_token_ids + gen_position_ids = new_position_ids + gen_cu_seqlens = new_gen_cu_seqlens + + print(tokenizer.batch_decode(gen_input_ids))
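
For reference, the layout every tensor in this patch assumes is the "unpadded" batch format consumed by flash_attn_unpadded_qkvpacked_func in flash_neox_modeling.py: all sequences in a batch are concatenated along a single token dimension, and per-sequence boundaries are tracked through cumulative sequence lengths (cu_seqlens) plus a max_seqlen scalar. The sketch below only illustrates that bookkeeping outside of the server code; the toy token lists and the fake_logits tensor standing in for a real model forward pass are illustrative assumptions, not part of the patch.

import torch

# Already-tokenized requests of different lengths (toy data, not from the patch).
sequences = [[5, 7, 9], [3, 4], [8, 8, 8, 8]]

input_ids, position_ids, cu_seqlens, max_seqlen = [], [], [0], 0
for tokens in sequences:
    t = torch.tensor(tokens, dtype=torch.int64)
    input_ids.append(t)
    # Positions restart at 0 for every sequence.
    position_ids.append(torch.arange(0, len(t), dtype=torch.int32))
    # cu_seqlens[i] is the offset where sequence i starts in the flat token dimension.
    cu_seqlens.append(cu_seqlens[-1] + len(t))
    max_seqlen = max(max_seqlen, len(t))

input_ids = torch.cat(input_ids)        # shape: [total_tokens]
position_ids = torch.cat(position_ids)  # shape: [total_tokens]
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

# During prefill the model returns one logits row per input token, so per-sequence
# logits are recovered by slicing with cu_seqlens. fake_logits stands in for `out`.
vocab_size = 16
fake_logits = torch.randn(int(cu_seqlens[-1]), vocab_size)
for i in range(len(sequences)):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    next_token_id = int(fake_logits[start:end][-1].argmax())
    print(f"sequence {i}: next token {next_token_id}")

This mirrors the prefill branch of FlashNeoX.generate_token, which slices out[start_index:end_index]; in decode mode the model instead returns one row per sequence, which is why that method indexes out[i] once past_key_values is populated.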