From d199c71a32674f0d0b7d0bf0c2eab8898381a0f3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 23 Mar 2023 17:47:15 +0100 Subject: [PATCH] make neox go brrr --- Dockerfile | 9 +- server/Makefile | 17 ++- .../text_generation_server/models/__init__.py | 19 ++- .../models/flash_neox.py | 117 ++++++++++++++---- .../models/flash_neox_modeling.py | 96 ++++++++------ server/text_generation_server/utils/tokens.py | 2 +- .../text_generation_server/utils/watermark.py | 29 +++-- 7 files changed, 203 insertions(+), 86 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5fbf8985..592f1f72 100644 --- a/Dockerfile +++ b/Dockerfile @@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \ CONDA_DEFAULT_ENV=text-generation \ PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin -RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/* RUN cd ~ && \ curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ @@ -53,10 +53,13 @@ RUN cd ~ && \ WORKDIR /usr/src +# Install torch +RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir + COPY server/Makefile server/Makefile -# Install specific version of torch -RUN cd server && make install-torch +# Install specific version of flash attention +RUN cd server && make install-flash-attention # Install specific version of transformers RUN cd server && BUILD_EXTENSIONS="True" make install-transformers diff --git a/server/Makefile b/server/Makefile index e8b0364e..69ef9bc5 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,5 @@ transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef +flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1 gen-server: # Compile protos @@ -12,13 +13,19 @@ install-transformers: # Install specific version of transformers with custom cuda kernels pip uninstall transformers -y || true rm -rf transformers || true - rm -rf transformers-$(transformers_commit) || true - curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip - unzip $(transformers_commit).zip - rm $(transformers_commit).zip - mv transformers-$(transformers_commit) transformers + git clone https://github.com/OlivierDehaene/transformers.git + cd transformers && git checkout $(transformers_commit) cd transformers && python setup.py install +install-flash-attention: + # Install specific version of flash attention + pip install packaging + pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true + rm -rf flash-attention || true + git clone https://github.com/HazyResearch/flash-attention.git + cd flash-attention && git checkout $(flash_att_commit) + cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install + install-torch: # Install specific version of torch pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 29d67c7b..d4edb8f7 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -11,7 +11,12 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded -from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + +try: + from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + FLASH_NEOX = torch.cuda.is_available() +except ImportError: + FLASH_NEOX = False __all__ = [ "Model", @@ -27,6 +32,10 @@ __all__ = [ "get_model", ] +if FLASH_NEOX: + __all__.append(FlashNeoX) + __all__.append(FlashNeoXSharded) + # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. torch.backends.cuda.matmul.allow_tf32 = True @@ -39,7 +48,7 @@ torch.set_grad_enabled(False) def get_model( - model_id: str, revision: Optional[str], sharded: bool, quantize: bool + model_id: str, revision: Optional[str], sharded: bool, quantize: bool ) -> Model: if "facebook/galactica" in model_id: if sharded: @@ -60,9 +69,11 @@ def get_model( if config.model_type == "gpt_neox": if sharded: - return FlashNeoXSharded(model_id, revision, quantize=quantize) + neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded + return neox_cls(model_id, revision, quantize=quantize) else: - return FlashNeoX(model_id, revision, quantize=quantize) + neox_cls = FlashNeoX if FLASH_NEOX else CausalLM + return neox_cls(model_id, revision, quantize=quantize) if config.model_type == "t5": if sharded: diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 97b9f0b5..2d3c6d8e 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -79,30 +79,41 @@ class FlashNeoXBatch(Batch): next_token_choosers = [] stopping_criterias = [] + # Cumulative length + cumulative_length = 0 + # Parse batch for r in pb.requests: tokenized_input = ( tokenizer(r.inputs, return_tensors="pt")["input_ids"] - .to(device) .squeeze(0) ) input_ids.append(tokenized_input) all_input_ids.append(tokenized_input.tolist()) + + input_length = len(tokenized_input) + max_seqlen = max(max_seqlen, input_length) + input_lengths.append(input_length) + + # Position ids position_ids.append( - torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device) + torch.arange(0, input_length, dtype=torch.int32) ) - input_lengths.append(len(tokenized_input)) - cu_seqlens.append(len(tokenized_input)) - max_seqlen = max(max_seqlen, len(tokenized_input)) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(cumulative_length + input_length) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + # Update + cumulative_length += input_length + input_ids = torch.concat(input_ids).unsqueeze(1) position_ids = torch.concat(position_ids) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) return cls( batch_id=pb.id, @@ -121,7 +132,62 @@ class FlashNeoXBatch(Batch): @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": - raise NotImplementedError + # Batch attributes + requests = [] + input_lengths = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + + # Batch tensors + input_ids = [] + position_ids = [] + cu_seqlens = [torch.tensor([0], dtype=torch.int32)] + max_seqlen = 0 + past_key_values = [] + + # Cumulative length + cumulative_length = torch.tensor(0) + + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length) + + input_ids.append(batch.input_ids) + position_ids.append(batch.position_ids) + past_key_values.append(batch.past_key_values) + + max_seqlen = max(max_seqlen, batch.max_seqlen) + + # Update + cumulative_length += batch.cu_seqlens[-1] + + input_ids = torch.concat(input_ids) + position_ids = torch.concat(position_ids) + # Concat on dim=1 as first dim represents the model layers + past_key_values = torch.concat(past_key_values, dim=1) + cu_seqlens = torch.concat(cu_seqlens) + + return FlashNeoXBatch( + batch_id=batches[0].batch_id, + requests=requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + def __len__(self): return len(self.requests) @@ -191,16 +257,19 @@ class FlashNeoX(Model): def generate_token( self, batch: FlashNeoXBatch ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: + # Better to send to device here to avoid device issues in concatenate + position_ids = batch.position_ids.to(self.device, non_blocking=True) + cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True) + input_ids = batch.input_ids.squeeze(1).to(self.device) + out, present = self.forward( - batch.input_ids.squeeze(1), - batch.position_ids, - batch.cu_seqlens, + input_ids, + position_ids, + cu_seqlens, batch.max_seqlen, batch.past_key_values, ) - device = out.device - # List of indices to cache next_batch_keep_indices = [] @@ -253,7 +322,8 @@ class FlashNeoX(Model): next_token_id, logprobs = next_token_chooser( all_input_ids, logits ) - next_token_id = next_token_id.to("cpu") + # Copy to cpu to avoid other copies when indexing and calling .item() + next_token_id = next_token_id.to("cpu", non_blocking=True) logprobs = logprobs.to("cpu") next_token_id_squeezed = next_token_id.squeeze() @@ -261,7 +331,6 @@ class FlashNeoX(Model): # Append next token to all tokens all_input_ids.append(next_token_id_item) - # all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)]) new_input_length = input_length + 1 # Generated token @@ -292,16 +361,20 @@ class FlashNeoX(Model): ) else: # Keep request in the batch + next_batch_keep_indices.append(i) + generated_text = None + + # Get sequence present seq_present = present[:, start_index:end_index] + # Pad it for next iter attention past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) next_batch_past_key_values.append(past) - generated_text = None - next_batch_keep_indices.append(i) next_batch_input_ids.append(next_token_id) next_batch_position_ids.append(input_length) + # Cumulative sum next_batch_cu_seqlens.append( - next_batch_cu_seqlens[i] + new_input_length + next_batch_cu_seqlens[-1] + new_input_length ) next_batch_input_lengths.append(new_input_length) next_batch_all_input_ids.append(all_input_ids) @@ -360,16 +433,16 @@ class FlashNeoX(Model): # Create final next batch tensors next_batch_position_ids = torch.tensor( - next_batch_position_ids, dtype=torch.int32, device=device + next_batch_position_ids, dtype=torch.int32 ) next_batch_cu_seqlens = torch.tensor( - next_batch_cu_seqlens, dtype=torch.int32, device=device + next_batch_cu_seqlens, dtype=torch.int32 ) if len(next_batch_keep_indices) > 1: - next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0) - next_batch_past_key_values = torch.concat(next_batch_past_key_values) + next_batch_input_ids = torch.concat(next_batch_input_ids) + next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) else: - next_batch_input_ids = next_batch_input_ids[0].to(device) + next_batch_input_ids = next_batch_input_ids[0] next_batch_past_key_values = next_batch_past_key_values[0] next_batch = FlashNeoXBatch( diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py index 1f8c6c4d..d67ca3c0 100644 --- a/server/text_generation_server/models/flash_neox_modeling.py +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -4,16 +4,16 @@ import torch.distributed import torch.nn.functional as F from torch import nn +from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig +# Flash attention imports import rotary_emb import flash_attn_cuda import dropout_layer_norm -import fused_dense_lib as fused_dense_cuda - -from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_ +from flash_attn.layers.rotary import RotaryEmbedding class TensorParallelColumnLinear(nn.Linear): @@ -102,7 +102,6 @@ class TensorParallelEmbedding(nn.Embedding): self.original_num_embeddings = num_embeddings - # TODO @thomasw21 fix and remove that constraint assert num_embeddings % self.tp_world_size == 0 block_size = num_embeddings // self.tp_world_size # inputs in `[min_id, max_id[` are handled by `self` to get embeddings @@ -157,24 +156,14 @@ class PositionRotaryEmbedding(RotaryEmbedding): # Don't do einsum, it converts fp32 to fp16 # freqs = torch.einsum("i,j->ij", t, self.inv_freq) freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange( - seqlen, dtype=self.scale.dtype, device=self.scale.device - ) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** power.unsqueeze(1) - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + self._update_cos_sin_cache(dtype, position_ids.device, max_s) cos = torch.index_select(self._cos_cached, 0, position_ids) @@ -223,7 +212,9 @@ class FlashNeoxAttention(torch.nn.Module): ) self.swap_dims = True + # TODO: remove and swap dims when loading weights def _swap_dims(self): + """Swap dims for the first inference to avoid an additional permute""" self.query_key_value.weight = torch.nn.Parameter( self.query_key_value.weight.view( self.num_heads, 3, self.head_size, self.hidden_size @@ -256,10 +247,14 @@ class FlashNeoxAttention(torch.nn.Module): qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv_rot = self.rotary_emb(qkv, cos, sin) + # Prefill if layer_past_present_indices is None: + # Copy to layer past layer_past[...] = qkv_rot[:, 1:] + # output attn_output = torch.empty_like(qkv[:, 0]) + # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], @@ -277,11 +272,15 @@ class FlashNeoxAttention(torch.nn.Module): 0, None, ) + # Decode else: query = qkv_rot[:, 0] + # Add present to the layer_past tensor at the correct indices layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + # output attn_output = torch.empty_like(query) + # flash attention flash_attn_cuda.fwd( query, layer_past[:, 0], @@ -306,11 +305,11 @@ class FlashNeoxAttention(torch.nn.Module): class FlashMLP(nn.Module): def __init__(self, act, hidden_size, intermediate_size, process_group=None): super().__init__() - if "gelu" in act: - act = "gelu_approx" - assert act in ["gelu_approx", "relu"] - self.is_gelu = act == "gelu_approx" - # self.act = lambda x: F.gelu(x, approximate="tanh") + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) if process_group is None: self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size) @@ -330,20 +329,10 @@ class FlashMLP(nn.Module): self.process_group = process_group def forward(self, hidden_states): - hidden_states, *rest = fused_dense_cuda.linear_act_forward( - hidden_states, - self.dense_h_to_4h.weight, - self.dense_h_to_4h.bias, - self.is_gelu, - False, - 0, - ) - return self.dense_4h_to_h(hidden_states) - # - # hidden_states = self.dense_h_to_4h(hidden_states) - # hidden_states = self.act(hidden_states) - # hidden_states = self.dense_4h_to_h(hidden_states) - # return hidden_states + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states class FlashNeoXLayer(nn.Module): @@ -381,6 +370,7 @@ class FlashNeoXLayer(nn.Module): cu_seqlens_q, ): if self.use_parallel_residual: + # faster input layer norm ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, None, @@ -410,6 +400,7 @@ class FlashNeoXLayer(nn.Module): cu_seqlens_q, ) + # faster post attention layer norm ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, None, @@ -431,6 +422,7 @@ class FlashNeoXLayer(nn.Module): mlp_output = self.mlp(ln2_hidden_states) return mlp_output + attn_output + hidden_states, None else: + # faster input layer norm hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, @@ -460,6 +452,7 @@ class FlashNeoXLayer(nn.Module): cu_seqlens_q, ) + # faster post attention layer norm hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, @@ -544,7 +537,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ): hidden_states = self.embed_in(input_ids) + # Prefill if past_key_values is None: + # Create past tensor past_key_values = hidden_states.new_empty( ( len(self.layers), @@ -556,12 +551,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ) layer_past_present_indices = None cu_seqlens_q = None + # Decode else: + # Create indices from cumulative sequence lengths layer_past_present_indices = cu_seqlens[1:] - 1 cu_seqlens_q = torch.arange( len(cu_seqlens), dtype=torch.int32, device=hidden_states.device ) + # Get rotary cos and sin for this forward + # Avoid to index in each layer cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( position_ids, max_s, hidden_states.dtype ) @@ -580,7 +579,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): cu_seqlens_q, ) - hidden_states = self.final_layer_norm(hidden_states) + # Faster final layer norm + hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.final_layer_norm.weight, + self.final_layer_norm.bias, + None, + None, + None, + None, + 0.0, + self.final_layer_norm.eps, + 1.0, + 0, + None, + False, + False, + ) return hidden_states, past_key_values diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index c7594644..597fbe7c 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -24,7 +24,7 @@ class Sampling: self.seed = seed def __call__(self, logits): - probs = torch.nn.functional.softmax(logits) + probs = torch.nn.functional.softmax(logits, -1) next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator) return next_tokens diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py index 8e90a59c..86c777a5 100644 --- a/server/text_generation_server/utils/watermark.py +++ b/server/text_generation_server/utils/watermark.py @@ -17,6 +17,7 @@ import os import torch from transformers import LogitsProcessor +from typing import List, Union GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) DELTA = os.getenv("WATERMARK_DELTA", 2.0) @@ -36,22 +37,29 @@ class WatermarkLogitsProcessor(LogitsProcessor): self.rng = torch.Generator(device=device) self.hash_key = hash_key - def _seed_rng(self, input_ids: torch.LongTensor) -> None: - assert ( - input_ids.shape[-1] >= 1 - ), "requires at least a 1 token prefix sequence to seed rng" - prev_token = input_ids[-1].item() + def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): + if isinstance(input_ids, list): + assert (len(input_ids) >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" + prev_token = input_ids[-1] + else: + input_ids = input_ids[0] + assert len(input_ids) == 1 + assert ( + input_ids.shape[-1] >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" + prev_token = input_ids[-1].item() self.rng.manual_seed(self.hash_key * prev_token) def _get_greenlist_ids( - self, input_ids: torch.LongTensor, max_value: int - ) -> list[int]: + self, input_ids: Union[List[int], torch.LongTensor], max_value: int, device: torch.device + ) -> List[int]: # seed the rng using the previous tokens/prefix self._seed_rng(input_ids) greenlist_size = int(max_value * self.gamma) vocab_permutation = torch.randperm( - max_value, device=input_ids.device, generator=self.rng + max_value, device=device, generator=self.rng ) greenlist_ids = vocab_permutation[:greenlist_size] return greenlist_ids @@ -73,10 +81,9 @@ class WatermarkLogitsProcessor(LogitsProcessor): return scores def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor + self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor ) -> torch.FloatTensor: - assert len(input_ids) == 1 - greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1]) + greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device) green_tokens_mask = self._calc_greenlist_mask( scores=scores, greenlist_token_ids=greenlist_ids )