make neox go brrr

OlivierDehaene 2023-03-23 17:47:15 +01:00
parent a4df5bc64a
commit d199c71a32
7 changed files with 203 additions and 86 deletions

View File

@@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \
CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \
curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
@@ -53,10 +53,13 @@ RUN cd ~ && \
WORKDIR /usr/src
# Install torch
RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
COPY server/Makefile server/Makefile
# Install specific version of torch
RUN cd server && make install-torch
# Install specific version of flash attention
RUN cd server && make install-flash-attention
# Install specific version of transformers
RUN cd server && BUILD_EXTENSIONS="True" make install-transformers

View File

@@ -1,4 +1,5 @@
transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1
gen-server:
# Compile protos
@@ -12,13 +13,19 @@ install-transformers:
# Install specific version of transformers with custom cuda kernels
pip uninstall transformers -y || true
rm -rf transformers || true
rm -rf transformers-$(transformers_commit) || true
curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
unzip $(transformers_commit).zip
rm $(transformers_commit).zip
mv transformers-$(transformers_commit) transformers
git clone https://github.com/OlivierDehaene/transformers.git
cd transformers && git checkout $(transformers_commit)
cd transformers && python setup.py install
install-flash-attention:
# Install specific version of flash attention
pip install packaging
pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
rm -rf flash-attention || true
git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention && git checkout $(flash_att_commit)
cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
install-torch:
# Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir

View File

@@ -11,7 +11,12 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
try:
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
FLASH_NEOX = torch.cuda.is_available()
except ImportError:
FLASH_NEOX = False
__all__ = [
"Model",
@@ -27,6 +32,10 @@ __all__ = [
"get_model",
]
if FLASH_NEOX:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
@@ -39,7 +48,7 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
@@ -60,9 +69,11 @@ def get_model(
if config.model_type == "gpt_neox":
if sharded:
return FlashNeoXSharded(model_id, revision, quantize=quantize)
neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
return neox_cls(model_id, revision, quantize=quantize)
else:
return FlashNeoX(model_id, revision, quantize=quantize)
neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
return neox_cls(model_id, revision, quantize=quantize)
if config.model_type == "t5":
if sharded:

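Note on the dispatch above: the flash-attention classes are imported inside a try/except so that installs without the compiled kernels (or without a CUDA device) silently fall back to GPTNeoxSharded / CausalLM. A minimal standalone sketch of that pattern, using the real flash_attn_cuda extension name but otherwise hypothetical helpers:

import torch

try:
    import flash_attn_cuda  # compiled extension; raises ImportError if the kernels are not installed
    FLASH_AVAILABLE = torch.cuda.is_available()
except ImportError:
    FLASH_AVAILABLE = False

def pick_neox_cls(sharded: bool) -> str:
    # Mirrors the FLASH_NEOX dispatch above: prefer the flash path, otherwise fall back.
    if sharded:
        return "FlashNeoXSharded" if FLASH_AVAILABLE else "GPTNeoxSharded"
    return "FlashNeoX" if FLASH_AVAILABLE else "CausalLM"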
View File

@@ -79,30 +79,41 @@ class FlashNeoXBatch(Batch):
next_token_choosers = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
# Parse batch
for r in pb.requests:
tokenized_input = (
tokenizer(r.inputs, return_tensors="pt")["input_ids"]
.to(device)
.squeeze(0)
)
input_ids.append(tokenized_input)
all_input_ids.append(tokenized_input.tolist())
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
# Position ids
position_ids.append(
torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device)
torch.arange(0, input_length, dtype=torch.int32)
)
input_lengths.append(len(tokenized_input))
cu_seqlens.append(len(tokenized_input))
max_seqlen = max(max_seqlen, len(tokenized_input))
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
# Update
cumulative_length += input_length
input_ids = torch.concat(input_ids).unsqueeze(1)
position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
return cls(
batch_id=pb.id,
@@ -121,7 +132,62 @@ class FlashNeoXBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
raise NotImplementedError
# Batch attributes
requests = []
input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Batch tensors
input_ids = []
position_ids = []
cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
max_seqlen = 0
past_key_values = []
# Cumulative length
cumulative_length = torch.tensor(0)
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
input_ids.append(batch.input_ids)
position_ids.append(batch.position_ids)
past_key_values.append(batch.past_key_values)
max_seqlen = max(max_seqlen, batch.max_seqlen)
# Update
cumulative_length += batch.cu_seqlens[-1]
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
# Concat on dim=1 as first dim represents the model layers
past_key_values = torch.concat(past_key_values, dim=1)
cu_seqlens = torch.concat(cu_seqlens)
return FlashNeoXBatch(
batch_id=batches[0].batch_id,
requests=requests,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
def __len__(self):
return len(self.requests)
@@ -191,16 +257,19 @@ class FlashNeoX(Model):
def generate_token(
self, batch: FlashNeoXBatch
) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
# Better to send to device here to avoid device issues in concatenate
position_ids = batch.position_ids.to(self.device, non_blocking=True)
cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True)
input_ids = batch.input_ids.squeeze(1).to(self.device)
out, present = self.forward(
batch.input_ids.squeeze(1),
batch.position_ids,
batch.cu_seqlens,
input_ids,
position_ids,
cu_seqlens,
batch.max_seqlen,
batch.past_key_values,
)
device = out.device
# List of indices to cache
next_batch_keep_indices = []
@@ -253,7 +322,8 @@ class FlashNeoX(Model):
next_token_id, logprobs = next_token_chooser(
all_input_ids, logits
)
next_token_id = next_token_id.to("cpu")
# Copy to cpu to avoid other copies when indexing and calling .item()
next_token_id = next_token_id.to("cpu", non_blocking=True)
logprobs = logprobs.to("cpu")
next_token_id_squeezed = next_token_id.squeeze()
@@ -261,7 +331,6 @@ class FlashNeoX(Model):
# Append next token to all tokens
all_input_ids.append(next_token_id_item)
# all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
new_input_length = input_length + 1
# Generated token
@@ -292,16 +361,20 @@ class FlashNeoX(Model):
)
else:
# Keep request in the batch
next_batch_keep_indices.append(i)
generated_text = None
# Get sequence present
seq_present = present[:, start_index:end_index]
# Pad it for next iter attention
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
next_batch_past_key_values.append(past)
generated_text = None
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token_id)
next_batch_position_ids.append(input_length)
# Cumulative sum
next_batch_cu_seqlens.append(
next_batch_cu_seqlens[i] + new_input_length
next_batch_cu_seqlens[-1] + new_input_length
)
next_batch_input_lengths.append(new_input_length)
next_batch_all_input_ids.append(all_input_ids)
@@ -360,16 +433,16 @@ class FlashNeoX(Model):
# Create final next batch tensors
next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32, device=device
next_batch_position_ids, dtype=torch.int32
)
next_batch_cu_seqlens = torch.tensor(
next_batch_cu_seqlens, dtype=torch.int32, device=device
next_batch_cu_seqlens, dtype=torch.int32
)
if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
next_batch_past_key_values = torch.concat(next_batch_past_key_values)
next_batch_input_ids = torch.concat(next_batch_input_ids)
next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
else:
next_batch_input_ids = next_batch_input_ids[0].to(device)
next_batch_input_ids = next_batch_input_ids[0]
next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashNeoXBatch(

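For reference, the cu_seqlens bookkeeping above follows the variable-length ("varlen") convention used by flash attention: every sequence in the batch is packed along a single token dimension, and cu_seqlens stores cumulative offsets so that sequence i occupies cu_seqlens[i]:cu_seqlens[i+1]. Concatenating two batches then only requires shifting the second batch's offsets by the first batch's total length, which is what concatenate does with cumulative_length. A small sketch with illustrative helper names:

import torch

def build_cu_seqlens(input_lengths):
    # Cumulative sequence lengths for a packed batch: [0, l0, l0 + l1, ...]
    cu_seqlens = [0]
    for length in input_lengths:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return torch.tensor(cu_seqlens, dtype=torch.int32)

def concat_cu_seqlens(a, b):
    # Merge two packed batches: shift the second batch's offsets by the first batch's total length.
    return torch.cat([a, b[1:] + a[-1]])

batch_a = build_cu_seqlens([3, 5])            # tensor([0, 3, 8], dtype=torch.int32)
batch_b = build_cu_seqlens([2])               # tensor([0, 2], dtype=torch.int32)
merged = concat_cu_seqlens(batch_a, batch_b)  # tensor([0, 3, 8, 10], dtype=torch.int32)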
View File

@@ -4,16 +4,16 @@ import torch.distributed
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm
import fused_dense_lib as fused_dense_cuda
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
from flash_attn.layers.rotary import RotaryEmbedding
class TensorParallelColumnLinear(nn.Linear):
@@ -102,7 +102,6 @@ class TensorParallelEmbedding(nn.Embedding):
self.original_num_embeddings = num_embeddings
# TODO @thomasw21 fix and remove that constraint
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
@@ -157,24 +156,14 @@ class PositionRotaryEmbedding(RotaryEmbedding):
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(
seqlen, dtype=self.scale.dtype, device=self.scale.device
)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
"""
Return cos and sin for the asked position ids
"""
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
@@ -223,7 +212,9 @@ class FlashNeoxAttention(torch.nn.Module):
)
self.swap_dims = True
# TODO: remove and swap dims when loading weights
def _swap_dims(self):
"""Swap dims for the first inference to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
@@ -256,10 +247,14 @@ class FlashNeoxAttention(torch.nn.Module):
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv_rot[:, 1:]
# output
attn_output = torch.empty_like(qkv[:, 0])
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
@@ -277,11 +272,15 @@ class FlashNeoxAttention(torch.nn.Module):
0,
None,
)
# Decode
else:
query = qkv_rot[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
@@ -306,11 +305,11 @@ class FlashNeoxAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
super().__init__()
if "gelu" in act:
act = "gelu_approx"
assert act in ["gelu_approx", "relu"]
self.is_gelu = act == "gelu_approx"
# self.act = lambda x: F.gelu(x, approximate="tanh")
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
if process_group is None:
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
@@ -330,20 +329,10 @@ class FlashMLP(nn.Module):
self.process_group = process_group
def forward(self, hidden_states):
hidden_states, *rest = fused_dense_cuda.linear_act_forward(
hidden_states,
self.dense_h_to_4h.weight,
self.dense_h_to_4h.bias,
self.is_gelu,
False,
0,
)
return self.dense_4h_to_h(hidden_states)
#
# hidden_states = self.dense_h_to_4h(hidden_states)
# hidden_states = self.act(hidden_states)
# hidden_states = self.dense_4h_to_h(hidden_states)
# return hidden_states
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class FlashNeoXLayer(nn.Module):
@@ -381,6 +370,7 @@ class FlashNeoXLayer(nn.Module):
cu_seqlens_q,
):
if self.use_parallel_residual:
# faster input layer norm
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
@@ -410,6 +400,7 @@ class FlashNeoXLayer(nn.Module):
cu_seqlens_q,
)
# faster post attention layer norm
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
@@ -431,6 +422,7 @@ class FlashNeoXLayer(nn.Module):
mlp_output = self.mlp(ln2_hidden_states)
return mlp_output + attn_output + hidden_states, None
else:
# faster input layer norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
@@ -460,6 +452,7 @@ class FlashNeoXLayer(nn.Module):
cu_seqlens_q,
)
# faster post attention layer norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
@@ -544,7 +537,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
):
hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.layers),
@@ -556,12 +551,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
)
layer_past_present_indices = None
cu_seqlens_q = None
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange(
len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
@@ -580,7 +579,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens_q,
)
hidden_states = self.final_layer_norm(hidden_states)
# Faster final layer norm
hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.final_layer_norm.weight,
self.final_layer_norm.bias,
None,
None,
None,
None,
0.0,
self.final_layer_norm.eps,
1.0,
0,
None,
False,
False,
)
return hidden_states, past_key_values

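The prefill/decode split above rests on one indexing trick: during prefill the packed key/value states for the whole prompt are copied into a pre-allocated past tensor, while during decode each sequence contributes exactly one new token, written at position cu_seqlens[1:] - 1, i.e. the last reserved slot of each sequence (the extra slot comes from the F.pad call in generate_token). A rough sketch of just that indexing, with made-up sizes and layers/heads omitted:

import torch

# Toy cache for two sequences of 3 and 5 prompt tokens, each padded with one free slot for decode.
padded_lengths = [4, 6]
layer_past = torch.zeros(sum(padded_lengths), 2, 4)        # (packed tokens, {key, value}, head_size)
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)   # cumulative offsets over the padded lengths

# Decode step: every sequence writes its new key/value pair into its last reserved slot.
present_indices = (cu_seqlens[1:] - 1).long()              # tensor([3, 9])
new_kv = torch.randn(2, 2, 4)                              # one (key, value) pair per sequence
layer_past[present_indices] = new_kv                       # same pattern as layer_past[layer_past_present_indices] = qkv_rot[:, 1:]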
View File

@@ -24,7 +24,7 @@ class Sampling:
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
probs = torch.nn.functional.softmax(logits, -1)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens

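The one-line change above is load-bearing: softmax without an explicit dim relies on a deprecated implicit-dimension heuristic, and with batched logits of shape (batch, vocab) the normalization must run over the last dimension before multinomial draws one token per row. A quick standalone check:

import torch

logits = torch.randn(2, 5)                        # (batch, vocab)
probs = torch.softmax(logits, dim=-1)             # normalize over the vocab dimension
assert torch.allclose(probs.sum(-1), torch.ones(2))

generator = torch.Generator().manual_seed(0)
next_tokens = torch.multinomial(probs, num_samples=1, generator=generator)
print(next_tokens.shape)                          # torch.Size([2, 1]) -> one sampled id per row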
View File

@@ -17,6 +17,7 @@ import os
import torch
from transformers import LogitsProcessor
from typing import List, Union
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
@@ -36,22 +37,29 @@ class WatermarkLogitsProcessor(LogitsProcessor):
self.rng = torch.Generator(device=device)
self.hash_key = hash_key
def _seed_rng(self, input_ids: torch.LongTensor) -> None:
assert (
input_ids.shape[-1] >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1].item()
def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
if isinstance(input_ids, list):
assert (len(input_ids) >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1]
else:
input_ids = input_ids[0]
assert len(input_ids) == 1
assert (
input_ids.shape[-1] >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1].item()
self.rng.manual_seed(self.hash_key * prev_token)
def _get_greenlist_ids(
self, input_ids: torch.LongTensor, max_value: int
) -> list[int]:
self, input_ids: Union[List[int], torch.LongTensor], max_value: int, device: torch.device
) -> List[int]:
# seed the rng using the previous tokens/prefix
self._seed_rng(input_ids)
greenlist_size = int(max_value * self.gamma)
vocab_permutation = torch.randperm(
max_value, device=input_ids.device, generator=self.rng
max_value, device=device, generator=self.rng
)
greenlist_ids = vocab_permutation[:greenlist_size]
return greenlist_ids
@@ -73,10 +81,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
return scores
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
) -> torch.FloatTensor:
assert len(input_ids) == 1
greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device)
green_tokens_mask = self._calc_greenlist_mask(
scores=scores, greenlist_token_ids=greenlist_ids
)
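For context, this processor implements the greenlist watermarking scheme in the style of Kirchenbauer et al.: the RNG is seeded from the previous token, a gamma fraction of the vocabulary is drawn as the greenlist, and delta is added to those logits; the change here only broadens the accepted input types (plain Python lists as well as tensors) and takes the device from scores instead of input_ids. A condensed sketch of the core computation, with gamma/delta matching the module defaults above and an illustrative hash key:

import torch

def watermark_scores(prev_token: int, scores: torch.Tensor, gamma=0.5, delta=2.0, hash_key=15485863):
    # Seed the generator from the previous token, following the _seed_rng rule above.
    rng = torch.Generator(device=scores.device)
    rng.manual_seed(hash_key * prev_token)
    vocab_size = scores.shape[-1]
    # Draw a gamma-sized "greenlist" of token ids and boost their logits by delta.
    greenlist = torch.randperm(vocab_size, device=scores.device, generator=rng)[: int(vocab_size * gamma)]
    scores = scores.clone()
    scores[..., greenlist] += delta
    return scores

scores = torch.zeros(1, 10)
print(watermark_scores(prev_token=42, scores=scores))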