This commit is contained in:
OlivierDehaene 2023-03-22 11:46:09 +01:00
parent b49dbf2d88
commit 24579c45de
4 changed files with 1108 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
+from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded

 __all__ = [
     "Model",
@@ -59,9 +60,9 @@ def get_model(
     if config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_id, revision, quantize=quantize)
+            return FlashNeoXSharded(model_id, revision, quantize=quantize)
         else:
-            return CausalLM(model_id, revision, quantize=quantize)
+            return FlashNeoX(model_id, revision, quantize=quantize)

     if config.model_type == "t5":
         if sharded:

View File

@@ -64,7 +64,6 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []

         # Parse batch
         padding_right_offset = 0

View File

@@ -0,0 +1,516 @@
import torch
import torch.distributed
from accelerate import init_empty_weights
from dataclasses import dataclass
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Union
from text_generation_server.models import Model
from text_generation_server.models.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
)
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
initialize_torch_distributed,
weight_files,
)
tracer = trace.get_tracer(__name__)
@dataclass
class FlashNeoXBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
# cumulative sequence lengths
cu_seqlens: torch.Tensor
max_seqlen: torch.Tensor
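    # Flattened KV cache written by the model: [num_layers, flattened_tokens, 2 (key/value), num_heads, head_size]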
past_key_values: Optional[torch.Tensor]
# All tokens
all_input_ids: List[torch.Tensor]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
id=self.batch_id, requests=self.requests, size=len(self)
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
input_ids = []
position_ids = []
cu_seqlens = [0]
max_seqlen = 0
next_token_choosers = []
stopping_criterias = []
# Parse batch
for r in pb.requests:
tokenized_input = (
tokenizer(r.inputs, return_tensors="pt")["input_ids"]
.to(device)
.squeeze(0)
)
input_ids.append(tokenized_input)
position_ids.append(
torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device)
)
            cu_seqlens.append(cu_seqlens[-1] + len(tokenized_input))
max_seqlen = max(max_seqlen, len(tokenized_input))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
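        # Keep the per-request token tensors; the model itself consumes the flattened 1D tensors built below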
all_input_ids = input_ids
input_ids = torch.concat(input_ids).unsqueeze(1)
position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
max_seqlen = torch.tensor(max_seqlen, dtype=torch.int32, device=device)
return cls(
batch_id=pb.id,
requests=pb.requests,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=None,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
@classmethod
@tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashNeoXBatch"]) -> "FlashNeoXBatch":
raise NotImplementedError
def __len__(self):
return len(self.requests)
class FlashNeoX(Model):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
if quantize:
raise NotImplementedError("FlashNeoX does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
)
self.model = (
FlashGPTNeoXForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
)
.eval()
.cuda()
)
tokenizer.pad_token_id = (
self.model.config.pad_token_id
if self.model.config.pad_token_id is not None
else self.model.config.eos_token_id
)
super(FlashNeoX, self).__init__(
tokenizer=tokenizer,
device=device,
)
@property
def batch_type(self) -> Type[FlashNeoXBatch]:
return FlashNeoXBatch
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
max_s: torch.Tensor,
        past_key_values: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_s=max_s,
past_key_values=past_key_values,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashNeoXBatch
) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
print("pos", batch.position_ids)
print("cu", batch.cu_seqlens)
print("max", batch.max_seqlen)
out, present = self.forward(
batch.input_ids.squeeze(1),
batch.position_ids,
batch.cu_seqlens,
batch.max_seqlen,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_ids = []
next_batch_position_ids = []
next_batch_cu_seqlens = [0]
next_batch_max_seqlen = 0
next_batch_past_key_values = []
next_batch_all_input_ids = []
# Results
generations: List[Generation] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = batch.cu_seqlens[i]
end_index = batch.cu_seqlens[i + 1]
seq_length = end_index - start_index
if batch.past_key_values is None:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
logits = out[start_index:end_index]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
logits = out[i].unsqueeze(0)
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
new_input_length = seq_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.decode_token(
next_token_id_squeezed,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if stop:
# Decode generated tokens
output_text = self.decode(all_input_ids)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
seq_present = present[:, start_index:end_index]
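                # Pad the token dimension with one extra slot so the next decode step can write this sequence's new key/value in place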
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
next_batch_past_key_values.append(past)
generated_text = None
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token_id)
next_batch_position_ids.append(new_input_length)
next_batch_cu_seqlens.append(
next_batch_cu_seqlens[i] + new_input_length
)
next_batch_all_input_ids.append(all_input_ids)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
# Prefill
if stopping_criteria.current_tokens == 1:
                # Drop the generated token so only prefill logprobs remain, and prepend NaN for the first prompt token (which has no logprob)
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:].unsqueeze(1)
).squeeze(1)[:-1].tolist()
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generations, None
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to requests, token_choosers and stopping_criterias that need to be cached
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Create final next batch tensors
device = out.device
next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32, device=device
)
next_batch_cu_seqlens = torch.tensor(
next_batch_cu_seqlens, dtype=torch.int32, device=device
)
if len(next_batch_keep_indices) > 1:
next_batch_past_key_values = torch.concat(next_batch_past_key_values)
else:
next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashNeoXBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
position_ids=next_batch_position_ids,
cu_seqlens=next_batch_cu_seqlens,
max_seqlen=next_batch_max_seqlen,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
)
return generations, next_batch
class FlashNeoXSharded(FlashNeoX):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
if quantize:
raise NotImplementedError("FlashNeoX does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashGPTNeoXForCausalLM(config)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
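        # Bypass FlashNeoX.__init__ (which would reload the weights) and initialize the Model base class directly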
super(FlashNeoX, self).__init__(
tokenizer=tokenizer,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
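                    # Column-parallel weights are sharded along the output dimension, row-parallel weights along the input dimension, embeddings along the vocabulary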
if isinstance(module, ColumnParallelLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, RowParallelLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
                            # XXX: Hack for RowParallelLinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
                        except Exception:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
max_s: torch.Tensor,
        past_key_values: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.model.gpt_neox.tp_embeddings:
logits, present = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_s=max_s,
past_key_values=past_key_values,
)
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
        # While the model itself is sharded, the embeddings might not be, as the vocabulary size might not be divisible by the number of shards
else:
return super(FlashNeoXSharded, self).forward(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
)
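
For reference, a minimal illustrative sketch (not part of the diff, with made-up prompt lengths) of the flattened batch layout that FlashNeoXBatch builds and that generate_token slices back apart: all prompts are concatenated into single 1D tensors and cu_seqlens stores the cumulative sequence boundaries expected by the flash attention kernels.

import torch

# Two made-up prompts of lengths 5 and 3, flattened into one stream of 8 tokens.
lengths = [5, 3]
input_ids = torch.arange(sum(lengths))                        # stand-in token ids
position_ids = torch.cat([torch.arange(l) for l in lengths])  # 0..4 then 0..2
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)       # cumulative boundaries
max_seqlen = max(lengths)

# Sequence i occupies the slice [cu_seqlens[i], cu_seqlens[i + 1]) of every flattened
# tensor, which is how per-request logits and KV cache slices are recovered above.
for i in range(len(lengths)):
    start, end = cu_seqlens[i], cu_seqlens[i + 1]
    print(i, input_ids[start:end].tolist(), position_ids[start:end].tolist())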

View File

@@ -0,0 +1,589 @@
import torch
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from einops import rearrange
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func,
flash_attn_unpadded_kvpacked_func,
)
from flash_attn.ops.fused_dense import (
FusedDense,
ColumnParallelLinear,
RowParallelLinear,
fused_mlp_func,
)
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
from flash_attn.ops.layer_norm import dropout_add_layer_norm
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Sanity check
if torch.any(
torch.logical_or(0 > input, input >= self.original_num_embeddings)
):
raise IndexError(
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
)
# `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
# translate for [0, self.max_id - self.min_id[
input = input - self.min_id
# default all out of bounds values to `0`
input[input_mask] = 0
out = super().forward(input)
out[input_mask] = 0.0
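        # Every shard contributes zeros for the ids it does not own, so the all-reduce below sums to the full embedding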
torch.distributed.all_reduce(out, group=self.process_group)
return out
class PositionRotaryEmbedding(RotaryEmbedding):
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor):
assert self.scale is None
self._update_cos_sin_cache(qkv, position_ids.max() + 1)
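        # Index the cached cos/sin tables with the absolute position of every flattened token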
cos = self._cos_cached[position_ids]
sin = self._sin_cached[position_ids]
return apply_rotary_emb_qkv_(qkv, cos, sin, None, None)
class FlashNeoxAttention(torch.nn.Module):
def __init__(
self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
):
super().__init__()
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FusedDense(hidden_size, 3 * hidden_size)
self.dense = FusedDense(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
process_group=process_group,
sequence_parallel=False,
)
self.dense = RowParallelLinear(
hidden_size,
hidden_size,
process_group=process_group,
sequence_parallel=False,
)
def forward(
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
):
qkv = self.query_key_value(hidden_states)
qkv = rearrange(
qkv, "... (h three d) -> ... h three d", three=3, d=self.head_size
).permute(0, 2, 1, 3)
qkv_rot = self.rotary_emb(qkv.unsqueeze(0), position_ids).squeeze(0)
if prefill:
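            # Prefill: cache the rotated key/value pairs of every prompt token and run packed QKV flash attention over the full sequences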
layer_past[...] = qkv_rot[:, 1:]
# test flash_attn_unpadded_qkvpacked_split_func
attn_output = flash_attn_unpadded_qkvpacked_func(
qkv_rot, cu_seqlens, max_s, 0.0, self.softmax_scale, causal=True
)
else:
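            # Decode: each sequence contributes a single query (hence cu_seqlens_q = 0, 1, ..., batch_size); its key/value is written into the last, padded slot of the cache before attending over all cached positions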
query = qkv_rot[:, 0]
layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:]
attn_output = flash_attn_unpadded_kvpacked_func(
query,
layer_past,
cu_seqlens_q=torch.arange(len(cu_seqlens), dtype=torch.int32).to(
query.device
),
max_seqlen_q=torch.tensor(1, dtype=torch.int32).to(query.device),
cu_seqlens_k=cu_seqlens,
max_seqlen_k=max_s,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
causal=False,
)
return self.dense(rearrange(attn_output, "... h d -> ... (h d)"))
class FlashMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
super().__init__()
if "gelu" in act:
act = "gelu_approx"
assert act in ["gelu_approx", "relu"]
self.act = act
if process_group is None:
self.dense_h_to_4h = FusedDense(hidden_size, intermediate_size)
self.dense_4h_to_h = FusedDense(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
intermediate_size,
process_group=process_group,
sequence_parallel=False,
)
self.dense_4h_to_h = RowParallelLinear(
intermediate_size,
hidden_size,
process_group=process_group,
sequence_parallel=False,
)
self.heuristic = "auto"
self.process_group = process_group
def forward(self, x):
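        # Resolve the fused MLP kernel heuristic once, based on the activation, the CUDA version and the input dtype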
if self.heuristic == "auto":
if self.act == "gelu_approx":
cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
self.heuristic = (
0
if cuda_ver >= (11, 8)
else (1 if x.dtype == torch.float16 else -1)
)
else:
self.heuristic = 0
out = fused_mlp_func(
x,
self.dense_h_to_4h.weight,
self.dense_4h_to_h.weight,
self.dense_h_to_4h.bias,
self.dense_4h_to_h.bias,
activation=self.act,
save_pre_act=self.training,
checkpoint_lvl=0,
heuristic=self.heuristic,
process_group=self.process_group,
sequence_parallel=False,
)
if self.process_group is not None:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class FlashNeoXLayer(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
):
super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.attention = FlashNeoxAttention(
num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
)
self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
def forward(
self,
hidden_states,
residual,
position_ids,
cu_seqlens,
max_s,
layer_past,
prefill,
):
if self.use_parallel_residual:
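            # GPT-NeoX parallel residual: attention and MLP both read the same input hidden states and their outputs are summed with the residual stream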
ln1_hidden_states = dropout_add_layer_norm(
hidden_states,
residual,
self.input_layernorm.weight,
self.input_layernorm.bias,
0.0,
self.input_layernorm.eps,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
)
attn_output = self.attention(
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
)
ln2_hidden_states = dropout_add_layer_norm(
hidden_states,
residual,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.bias,
0.0,
self.post_attention_layernorm.eps,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
)
mlp_output = self.mlp(ln2_hidden_states)
return mlp_output + attn_output + hidden_states, None
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states,
residual,
self.input_layernorm.weight,
self.input_layernorm.bias,
0.0,
self.input_layernorm.eps,
rowscale=None,
prenorm=True,
residual_in_fp32=True,
)
hidden_states = self.attention(
hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
)
hidden_states, residual = dropout_add_layer_norm(
hidden_states,
residual,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.bias,
0.0,
self.post_attention_layernorm.eps,
rowscale=None,
prenorm=True,
residual_in_fp32=True,
)
mlp_output = self.mlp(hidden_states)
return mlp_output, residual
class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
config_class = GPTNeoXConfig
base_model_prefix = "gpt_neox"
supports_gradient_checkpointing = False
_no_split_modules = None
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[
FlashNeoXLayer(
config.num_attention_heads,
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
process_group,
)
for _ in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
):
hidden_states = self.embed_in(input_ids)
prefill = False
if past_key_values is None:
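            # Prefill: allocate the key/value cache, one slot per flattened input token, filled in place by every layer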
past_key_values = hidden_states.new_empty(
(
len(self.layers),
len(hidden_states),
2,
self.num_heads,
self.head_size,
)
)
prefill = True
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
position_ids,
cu_seqlens,
max_s,
past_key_values[i],
prefill,
)
hidden_states = dropout_add_layer_norm(
hidden_states,
residual,
self.final_layer_norm.weight,
self.final_layer_norm.bias,
0.0,
self.final_layer_norm.eps,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
)
return hidden_states, past_key_values
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
        if getattr(config, "tp_parallel", False):
process_group = torch.distributed.distributed_c10d._get_default_group()
else:
process_group = None
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = FusedDense(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.embed_out = FusedDense(
config.hidden_size, config.vocab_size, bias=False
)
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
):
hidden_states, present = self.gpt_neox(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
)
return self.embed_out(hidden_states), present
if __name__ == "__main__":
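    # Manual smoke test: batched, unpadded generation with pythia-160m, one prefill pass followed by 40 decode steps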
from transformers import AutoTokenizer
from flash_attn.bert_padding import unpad_input
model = (
FlashGPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m")
.cuda()
.to(torch.half)
)
tokenizer = AutoTokenizer.from_pretrained(
"EleutherAI/pythia-160m", padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token
tokenized_inputs = tokenizer(
["What is this?\n\nA:\n\nThe answer to the problem?", "hello!"],
padding=True,
return_tensors="pt",
).to("cuda")
input_ids, indices, cu_seqlens, max_seqlen = unpad_input(
tokenized_inputs["input_ids"].unsqueeze(-1), tokenized_inputs["attention_mask"]
)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 0)
unpad_position_ids = torch.gather(position_ids.view(-1).cuda(), 0, indices)
gen_input_ids = input_ids.squeeze(1).cuda().clone()
gen_position_ids = unpad_position_ids.clone()
gen_indices = indices.clone()
gen_cu_seqlens = cu_seqlens.clone()
gen_max_seqlen = max_seqlen
past_key_values = None
results = []
with torch.no_grad():
        out, present = model(
gen_input_ids,
gen_position_ids,
gen_cu_seqlens,
gen_max_seqlen,
past_key_values=past_key_values,
)
futures = []
new_gen_cu_seqlens = [0]
new_position_ids = []
next_token_ids = []
for i in range(len(gen_cu_seqlens) - 1):
start_index = gen_cu_seqlens[i]
end_index = gen_cu_seqlens[i + 1]
seq_logits = out[start_index:end_index]
next_token_id = torch.argmax(seq_logits[-1:], dim=1)
next_token_ids.append(next_token_id)
sequence_length = end_index - start_index
new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1)
seq_position_ids = gen_position_ids[start_index:end_index]
new_position_ids.append(
torch.concat([seq_position_ids, seq_position_ids[-1:] + 1])
)
seq_present = present[:, start_index:end_index]
future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
futures.append(future)
past_key_values = torch.concat(futures, dim=1)
new_position_ids = torch.concat(new_position_ids)
new_gen_cu_seqlens = torch.tensor(
new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32
)
next_token_ids = torch.concat(next_token_ids)
gen_max_seqlen += 1
gen_input_ids = next_token_ids
gen_position_ids = new_position_ids
gen_cu_seqlens = new_gen_cu_seqlens
print(tokenizer.batch_decode(gen_input_ids))
for _ in range(40):
            out, present = model(
gen_input_ids,
gen_position_ids,
gen_cu_seqlens,
gen_max_seqlen,
past_key_values=past_key_values,
)
futures = []
new_gen_cu_seqlens = [0]
new_position_ids = []
next_token_ids = []
for i in range(len(gen_cu_seqlens) - 1):
start_index = gen_cu_seqlens[i]
end_index = gen_cu_seqlens[i + 1]
seq_logits = out[i]
next_token_id = torch.argmax(seq_logits.view(1, -1)[-1:], dim=1)
next_token_ids.append(next_token_id)
sequence_length = end_index - start_index
new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1)
seq_position_ids = gen_position_ids[start_index:end_index]
new_position_ids.append(
torch.concat([seq_position_ids, seq_position_ids[-1:] + 1])
)
seq_present = present[:, start_index:end_index]
future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
futures.append(future)
past_key_values = torch.concat(futures, dim=1)
new_position_ids = torch.concat(new_position_ids)
new_gen_cu_seqlens = torch.tensor(
new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32
)
next_token_ids = torch.concat(next_token_ids)
gen_max_seqlen += 1
gen_input_ids = next_token_ids
gen_position_ids = new_position_ids
gen_cu_seqlens = new_gen_cu_seqlens
print(tokenizer.batch_decode(gen_input_ids))
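
And a small, self-contained sketch (illustrative only, no process group or distributed launch) of the masking trick TensorParallelEmbedding relies on: each shard looks up only the ids it owns, zeroes the rest, and the all_reduce sum reconstructs the full embedding.

import torch

# Toy setup: a vocabulary of 8 ids split across 2 shards of 4 rows each.
full_weight = torch.randn(8, 2)
shards = [full_weight[0:4], full_weight[4:8]]
input_ids = torch.tensor([1, 6, 3])

partial_outputs = []
for rank, shard in enumerate(shards):
    min_id, max_id = rank * 4, (rank + 1) * 4
    mask = (input_ids < min_id) | (input_ids >= max_id)  # ids owned by another shard
    local_ids = input_ids - min_id
    local_ids[mask] = 0                                   # dummy lookup for foreign ids
    out = shard[local_ids]
    out[mask] = 0.0                                       # zero them so the sum is exact
    partial_outputs.append(out)

# torch.distributed.all_reduce would sum these across ranks; here we sum manually.
assert torch.equal(sum(partial_outputs), full_weight[input_ids])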