mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
faster
This commit is contained in:
parent
24579c45de
commit
ead19abb0e
@ -3,7 +3,6 @@ import torch.distributed
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from dataclasses import dataclass
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
||||
from opentelemetry import trace
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
|
||||
@ -13,6 +12,8 @@ from text_generation_server.models import Model
|
||||
from text_generation_server.models.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear
|
||||
)
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
@ -42,7 +43,7 @@ class FlashNeoXBatch(Batch):
|
||||
position_ids: torch.Tensor
|
||||
# cumulative sequence lengths
|
||||
cu_seqlens: torch.Tensor
|
||||
max_seqlen: torch.Tensor
|
||||
max_seqlen: int
|
||||
past_key_values: Optional[torch.Tensor]
|
||||
|
||||
# All tokens
|
||||
@ -95,7 +96,6 @@ class FlashNeoXBatch(Batch):
|
||||
input_ids = torch.concat(input_ids).unsqueeze(1)
|
||||
position_ids = torch.concat(position_ids)
|
||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
|
||||
max_seqlen = torch.tensor(max_seqlen, dtype=torch.int32, device=device)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
@ -168,7 +168,7 @@ class FlashNeoX(Model):
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_s: torch.Tensor,
|
||||
max_s: int,
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Model Forward
|
||||
@ -184,10 +184,6 @@ class FlashNeoX(Model):
|
||||
def generate_token(
|
||||
self, batch: FlashNeoXBatch
|
||||
) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
|
||||
print("pos", batch.position_ids)
|
||||
print("cu", batch.cu_seqlens)
|
||||
print("max", batch.max_seqlen)
|
||||
|
||||
out, present = self.forward(
|
||||
batch.input_ids.squeeze(1),
|
||||
batch.position_ids,
|
||||
@ -228,7 +224,7 @@ class FlashNeoX(Model):
|
||||
# Indexing metadata
|
||||
start_index = batch.cu_seqlens[i]
|
||||
end_index = batch.cu_seqlens[i + 1]
|
||||
seq_length = end_index - start_index
|
||||
seq_length = (end_index - start_index).item()
|
||||
|
||||
if batch.past_key_values is None:
|
||||
# Prefill mode
|
||||
@ -263,7 +259,7 @@ class FlashNeoX(Model):
|
||||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(all_input_ids)
|
||||
output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :])
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
@ -282,7 +278,7 @@ class FlashNeoX(Model):
|
||||
generated_text = None
|
||||
next_batch_keep_indices.append(i)
|
||||
next_batch_input_ids.append(next_token_id)
|
||||
next_batch_position_ids.append(new_input_length)
|
||||
next_batch_position_ids.append(seq_length)
|
||||
next_batch_cu_seqlens.append(
|
||||
next_batch_cu_seqlens[i] + new_input_length
|
||||
)
|
||||
@ -435,13 +431,13 @@ class FlashNeoXSharded(FlashNeoX):
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, ColumnParallelLinear):
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, RowParallelLinear):
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
@ -491,7 +487,7 @@ class FlashNeoXSharded(FlashNeoX):
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_s: torch.Tensor,
|
||||
max_s: int,
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.model.gpt_neox.tp_embeddings:
|
||||
|
@ -1,22 +1,90 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.gpt_neox import GPTNeoXConfig
|
||||
from einops import rearrange
|
||||
|
||||
import rotary_emb
|
||||
import flash_attn_cuda
|
||||
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_qkvpacked_func,
|
||||
flash_attn_unpadded_kvpacked_func,
|
||||
)
|
||||
from flash_attn.ops.fused_dense import (
|
||||
FusedDense,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
fused_mlp_func,
|
||||
)
|
||||
# from flash_attn.ops.fused_dense import (
|
||||
# FusedDense,
|
||||
# ColumnParallelLinear,
|
||||
# RowParallelLinear,
|
||||
# fused_mlp_func,
|
||||
# )
|
||||
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
|
||||
|
||||
# from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
process_group: torch.distributed.ProcessGroup,
|
||||
bias=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_world_size = process_group.size()
|
||||
assert out_features % self.tp_world_size == 0
|
||||
out_features = out_features // self.tp_world_size
|
||||
|
||||
super().__init__(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def linear(input, weight, bias):
|
||||
return F.linear(input, weight, bias)
|
||||
|
||||
def forward(self, input):
|
||||
return self.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
class TensorParallelRowLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
process_group: torch.distributed.ProcessGroup,
|
||||
bias=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_world_size = process_group.size()
|
||||
assert in_features % self.tp_world_size == 0
|
||||
in_features = in_features // self.tp_world_size
|
||||
|
||||
super().__init__(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def linear(input, weight, bias):
|
||||
return F.linear(input, weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = self.linear(input, self.weight, self.bias)
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TensorParallelEmbedding(nn.Embedding):
|
||||
@ -32,7 +100,7 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||
sparse=False,
|
||||
_weight=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
dtype=None
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
@ -40,33 +108,22 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||
|
||||
self.original_num_embeddings = num_embeddings
|
||||
|
||||
# TODO @thomasw21 fix and remove that constraint
|
||||
assert num_embeddings % self.tp_world_size == 0
|
||||
block_size = num_embeddings // self.tp_world_size
|
||||
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
|
||||
self.min_id = self.tp_rank * block_size
|
||||
self.max_id = (self.tp_rank + 1) * block_size
|
||||
|
||||
super().__init__(
|
||||
block_size,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
max_norm=max_norm,
|
||||
norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse,
|
||||
_weight=_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device,
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# Sanity check
|
||||
if torch.any(
|
||||
torch.logical_or(0 > input, input >= self.original_num_embeddings)
|
||||
):
|
||||
if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)):
|
||||
raise IndexError(
|
||||
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
|
||||
)
|
||||
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}")
|
||||
|
||||
# `0` if input is in the correct interval, else `1`
|
||||
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
|
||||
@ -81,15 +138,50 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||
|
||||
|
||||
class PositionRotaryEmbedding(RotaryEmbedding):
|
||||
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor):
|
||||
assert self.scale is None
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
else:
|
||||
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
||||
- seqlen // 2) / self.scale_base)
|
||||
scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||
|
||||
self._update_cos_sin_cache(qkv, position_ids.max() + 1)
|
||||
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int):
|
||||
self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s)
|
||||
|
||||
cos = self._cos_cached[position_ids]
|
||||
sin = self._sin_cached[position_ids]
|
||||
q1, q2, k1, k2, cos, sin = _prepare_rotary(qkv, self._cos_cached, self._sin_cached, position_ids)
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
return qkv
|
||||
|
||||
return apply_rotary_emb_qkv_(qkv, cos, sin, None, None)
|
||||
|
||||
@torch.jit.script
|
||||
def _prepare_rotary(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
|
||||
cos = torch.index_select(cos, 0, position_ids)
|
||||
sin = torch.index_select(sin, 0, position_ids)
|
||||
|
||||
rotary_dim = cos.shape[-1]
|
||||
q1 = qkv[:, 0, :, :rotary_dim]
|
||||
q2 = qkv[:, 0, :, rotary_dim:2*rotary_dim]
|
||||
k1 = qkv[:, 1, :, :rotary_dim]
|
||||
k2 = qkv[:, 1, :, rotary_dim: 2*rotary_dim]
|
||||
|
||||
return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
|
||||
class FlashNeoxAttention(torch.nn.Module):
|
||||
@ -106,115 +198,100 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FusedDense(hidden_size, 3 * hidden_size)
|
||||
self.dense = FusedDense(hidden_size, hidden_size)
|
||||
self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
|
||||
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
process_group=process_group,
|
||||
sequence_parallel=False,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
sequence_parallel=False,
|
||||
)
|
||||
self.swap_dims = False
|
||||
|
||||
def _swap_dims(self):
|
||||
self.query_key_value.weight = torch.nn.Parameter(
|
||||
self.query_key_value.weight.view(self.num_heads, 3, self.head_size, self.hidden_size)
|
||||
.permute(1, 0, 2, 3).reshape(-1, self.hidden_size)
|
||||
)
|
||||
self.query_key_value.bias = torch.nn.Parameter(
|
||||
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
|
||||
.permute(1, 0, 2).reshape(-1)
|
||||
)
|
||||
self.swap_dims = True
|
||||
|
||||
def forward(
|
||||
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
||||
):
|
||||
if not self.swap_dims:
|
||||
self._swap_dims()
|
||||
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = rearrange(
|
||||
qkv, "... (h three d) -> ... h three d", three=3, d=self.head_size
|
||||
).permute(0, 2, 1, 3)
|
||||
qkv_rot = self.rotary_emb(qkv.unsqueeze(0), position_ids).squeeze(0)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
qkv_rot = self.rotary_emb(qkv, position_ids, max_s)
|
||||
|
||||
if prefill:
|
||||
layer_past[...] = qkv_rot[:, 1:]
|
||||
|
||||
# test flash_attn_unpadded_qkvpacked_split_func
|
||||
attn_output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv_rot, cu_seqlens, max_s, 0.0, self.softmax_scale, causal=True
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
flash_attn_cuda.fwd(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlens, cu_seqlens, max_s, max_s, 0.0,
|
||||
self.softmax_scale,
|
||||
False, True, False, 0, None
|
||||
)
|
||||
else:
|
||||
query = qkv_rot[:, 0]
|
||||
layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:]
|
||||
|
||||
attn_output = flash_attn_unpadded_kvpacked_func(
|
||||
query,
|
||||
layer_past,
|
||||
cu_seqlens_q=torch.arange(len(cu_seqlens), dtype=torch.int32).to(
|
||||
attn_output = torch.empty_like(query)
|
||||
flash_attn_cuda.fwd(
|
||||
query, layer_past[:, 0], layer_past[:, 1], attn_output,
|
||||
torch.arange(len(cu_seqlens), dtype=torch.int32).to(
|
||||
query.device
|
||||
),
|
||||
max_seqlen_q=torch.tensor(1, dtype=torch.int32).to(query.device),
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_k=max_s,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=False,
|
||||
), cu_seqlens, torch.tensor(1, dtype=torch.int32).to(query.device), max_s, 0.0,
|
||||
self.softmax_scale,
|
||||
False, False, False, 0, None
|
||||
)
|
||||
|
||||
return self.dense(rearrange(attn_output, "... h d -> ... (h d)"))
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||
super().__init__()
|
||||
if "gelu" in act:
|
||||
act = "gelu_approx"
|
||||
assert act in ["gelu_approx", "relu"]
|
||||
self.act = act
|
||||
assert "gelu" in act
|
||||
# if "gelu" in act:
|
||||
# act = "gelu_approx"
|
||||
# assert act in ["gelu_approx", "relu"]
|
||||
self.act = lambda x: F.gelu(x, approximate="tanh")
|
||||
|
||||
if process_group is None:
|
||||
self.dense_h_to_4h = FusedDense(hidden_size, intermediate_size)
|
||||
self.dense_4h_to_h = FusedDense(intermediate_size, hidden_size)
|
||||
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
|
||||
self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size)
|
||||
else:
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group=process_group,
|
||||
sequence_parallel=False,
|
||||
)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
self.dense_4h_to_h = TensorParallelRowLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
sequence_parallel=False,
|
||||
)
|
||||
self.heuristic = "auto"
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(self, x):
|
||||
if self.heuristic == "auto":
|
||||
if self.act == "gelu_approx":
|
||||
cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
|
||||
self.heuristic = (
|
||||
0
|
||||
if cuda_ver >= (11, 8)
|
||||
else (1 if x.dtype == torch.float16 else -1)
|
||||
)
|
||||
else:
|
||||
self.heuristic = 0
|
||||
|
||||
out = fused_mlp_func(
|
||||
x,
|
||||
self.dense_h_to_4h.weight,
|
||||
self.dense_4h_to_h.weight,
|
||||
self.dense_h_to_4h.bias,
|
||||
self.dense_4h_to_h.bias,
|
||||
activation=self.act,
|
||||
save_pre_act=self.training,
|
||||
checkpoint_lvl=0,
|
||||
heuristic=self.heuristic,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=False,
|
||||
)
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dense_4h_to_h(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashNeoXLayer(nn.Module):
|
||||
@ -250,36 +327,15 @@ class FlashNeoXLayer(nn.Module):
|
||||
prefill,
|
||||
):
|
||||
if self.use_parallel_residual:
|
||||
ln1_hidden_states = dropout_add_layer_norm(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.input_layernorm.weight,
|
||||
self.input_layernorm.bias,
|
||||
0.0,
|
||||
self.input_layernorm.eps,
|
||||
rowscale=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
)
|
||||
attn_output = self.attention(
|
||||
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
||||
self.input_layernorm(hidden_states), position_ids, cu_seqlens, max_s, layer_past, prefill
|
||||
)
|
||||
|
||||
ln2_hidden_states = dropout_add_layer_norm(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.post_attention_layernorm.weight,
|
||||
self.post_attention_layernorm.bias,
|
||||
0.0,
|
||||
self.post_attention_layernorm.eps,
|
||||
rowscale=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
)
|
||||
mlp_output = self.mlp(ln2_hidden_states)
|
||||
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
|
||||
return mlp_output + attn_output + hidden_states, None
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
hidden_states, residual = dropout_add_layer_norm(
|
||||
hidden_states,
|
||||
residual,
|
||||
@ -399,17 +455,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
prefill,
|
||||
)
|
||||
|
||||
hidden_states = dropout_add_layer_norm(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.final_layer_norm.weight,
|
||||
self.final_layer_norm.bias,
|
||||
0.0,
|
||||
self.final_layer_norm.eps,
|
||||
rowscale=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
)
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
@ -426,13 +472,13 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
|
||||
|
||||
if self.gpt_neox.tp_embeddings:
|
||||
self.embed_out = FusedDense(
|
||||
self.embed_out = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.embed_out = FusedDense(
|
||||
self.embed_out = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=False
|
||||
)
|
||||
|
||||
@ -448,142 +494,3 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
||||
)
|
||||
return self.embed_out(hidden_states), present
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoTokenizer
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
|
||||
model = (
|
||||
FlashGPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m")
|
||||
.cuda()
|
||||
.to(torch.half)
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"EleutherAI/pythia-160m", padding_side="left"
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
tokenized_inputs = tokenizer(
|
||||
["What is this?\n\nA:\n\nThe answer to the problem?", "hello!"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to("cuda")
|
||||
|
||||
input_ids, indices, cu_seqlens, max_seqlen = unpad_input(
|
||||
tokenized_inputs["input_ids"].unsqueeze(-1), tokenized_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 0)
|
||||
|
||||
unpad_position_ids = torch.gather(position_ids.view(-1).cuda(), 0, indices)
|
||||
|
||||
gen_input_ids = input_ids.squeeze(1).cuda().clone()
|
||||
gen_position_ids = unpad_position_ids.clone()
|
||||
gen_indices = indices.clone()
|
||||
gen_cu_seqlens = cu_seqlens.clone()
|
||||
gen_max_seqlen = max_seqlen
|
||||
|
||||
past_key_values = None
|
||||
|
||||
results = []
|
||||
with torch.no_grad():
|
||||
out, present, _ = model(
|
||||
gen_input_ids,
|
||||
gen_position_ids,
|
||||
gen_cu_seqlens,
|
||||
gen_max_seqlen,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
futures = []
|
||||
new_gen_cu_seqlens = [0]
|
||||
new_position_ids = []
|
||||
next_token_ids = []
|
||||
|
||||
for i in range(len(gen_cu_seqlens) - 1):
|
||||
start_index = gen_cu_seqlens[i]
|
||||
end_index = gen_cu_seqlens[i + 1]
|
||||
|
||||
seq_logits = out[start_index:end_index]
|
||||
next_token_id = torch.argmax(seq_logits[-1:], dim=1)
|
||||
next_token_ids.append(next_token_id)
|
||||
|
||||
sequence_length = end_index - start_index
|
||||
new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1)
|
||||
|
||||
seq_position_ids = gen_position_ids[start_index:end_index]
|
||||
new_position_ids.append(
|
||||
torch.concat([seq_position_ids, seq_position_ids[-1:] + 1])
|
||||
)
|
||||
|
||||
seq_present = present[:, start_index:end_index]
|
||||
future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
|
||||
|
||||
futures.append(future)
|
||||
|
||||
past_key_values = torch.concat(futures, dim=1)
|
||||
new_position_ids = torch.concat(new_position_ids)
|
||||
new_gen_cu_seqlens = torch.tensor(
|
||||
new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32
|
||||
)
|
||||
next_token_ids = torch.concat(next_token_ids)
|
||||
|
||||
gen_max_seqlen += 1
|
||||
|
||||
gen_input_ids = next_token_ids
|
||||
gen_position_ids = new_position_ids
|
||||
gen_cu_seqlens = new_gen_cu_seqlens
|
||||
|
||||
print(tokenizer.batch_decode(gen_input_ids))
|
||||
|
||||
for _ in range(40):
|
||||
out, present, _ = model(
|
||||
gen_input_ids,
|
||||
gen_position_ids,
|
||||
gen_cu_seqlens,
|
||||
gen_max_seqlen,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
futures = []
|
||||
new_gen_cu_seqlens = [0]
|
||||
new_position_ids = []
|
||||
next_token_ids = []
|
||||
for i in range(len(gen_cu_seqlens) - 1):
|
||||
start_index = gen_cu_seqlens[i]
|
||||
end_index = gen_cu_seqlens[i + 1]
|
||||
|
||||
seq_logits = out[i]
|
||||
next_token_id = torch.argmax(seq_logits.view(1, -1)[-1:], dim=1)
|
||||
next_token_ids.append(next_token_id)
|
||||
|
||||
sequence_length = end_index - start_index
|
||||
new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1)
|
||||
|
||||
seq_position_ids = gen_position_ids[start_index:end_index]
|
||||
new_position_ids.append(
|
||||
torch.concat([seq_position_ids, seq_position_ids[-1:] + 1])
|
||||
)
|
||||
|
||||
seq_present = present[:, start_index:end_index]
|
||||
future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
|
||||
|
||||
futures.append(future)
|
||||
|
||||
past_key_values = torch.concat(futures, dim=1)
|
||||
new_position_ids = torch.concat(new_position_ids)
|
||||
new_gen_cu_seqlens = torch.tensor(
|
||||
new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32
|
||||
)
|
||||
next_token_ids = torch.concat(next_token_ids)
|
||||
|
||||
gen_max_seqlen += 1
|
||||
|
||||
gen_input_ids = next_token_ids
|
||||
gen_position_ids = new_position_ids
|
||||
gen_cu_seqlens = new_gen_cu_seqlens
|
||||
|
||||
print(tokenizer.batch_decode(gen_input_ids))
|
||||
|
Loading…
Reference in New Issue
Block a user