mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
faster
This commit is contained in:
parent
ead19abb0e
commit
cdc70f4c23
@ -13,7 +13,7 @@ from text_generation_server.models.flash_neox_modeling import (
|
|||||||
FlashGPTNeoXForCausalLM,
|
FlashGPTNeoXForCausalLM,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear
|
TensorParallelColumnLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.types import (
|
from text_generation_server.models.types import (
|
||||||
Batch,
|
Batch,
|
||||||
@ -115,7 +115,6 @@ class FlashNeoXBatch(Batch):
|
|||||||
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
|
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
@ -259,7 +258,9 @@ class FlashNeoX(Model):
|
|||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :])
|
output_text = self.decode(
|
||||||
|
all_input_ids[-stopping_criteria.current_tokens :]
|
||||||
|
)
|
||||||
# Get seed
|
# Get seed
|
||||||
if isinstance(next_token_chooser.choice, Sampling):
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
seed = next_token_chooser.choice.seed
|
seed = next_token_chooser.choice.seed
|
||||||
|
@ -9,43 +9,35 @@ from transformers.models.gpt_neox import GPTNeoXConfig
|
|||||||
|
|
||||||
import rotary_emb
|
import rotary_emb
|
||||||
import flash_attn_cuda
|
import flash_attn_cuda
|
||||||
|
import dropout_layer_norm
|
||||||
|
|
||||||
|
import fused_dense_lib as fused_dense_cuda
|
||||||
|
|
||||||
from flash_attn.flash_attn_interface import (
|
|
||||||
flash_attn_unpadded_qkvpacked_func,
|
|
||||||
flash_attn_unpadded_kvpacked_func,
|
|
||||||
)
|
|
||||||
# from flash_attn.ops.fused_dense import (
|
|
||||||
# FusedDense,
|
|
||||||
# ColumnParallelLinear,
|
|
||||||
# RowParallelLinear,
|
|
||||||
# fused_mlp_func,
|
|
||||||
# )
|
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
|
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
|
||||||
|
|
||||||
|
|
||||||
# from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
|
||||||
|
|
||||||
|
|
||||||
class TensorParallelColumnLinear(nn.Linear):
|
class TensorParallelColumnLinear(nn.Linear):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
process_group: torch.distributed.ProcessGroup,
|
process_group: torch.distributed.ProcessGroup,
|
||||||
bias=True,
|
bias=True,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
assert out_features % self.tp_world_size == 0
|
assert out_features % self.tp_world_size == 0
|
||||||
out_features = out_features // self.tp_world_size
|
out_features = out_features // self.tp_world_size
|
||||||
|
|
||||||
super().__init__(in_features=in_features,
|
super().__init__(
|
||||||
out_features=out_features,
|
in_features=in_features,
|
||||||
bias=bias,
|
out_features=out_features,
|
||||||
device=device,
|
bias=bias,
|
||||||
dtype=dtype)
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def linear(input, weight, bias):
|
def linear(input, weight, bias):
|
||||||
@ -57,24 +49,26 @@ class TensorParallelColumnLinear(nn.Linear):
|
|||||||
|
|
||||||
class TensorParallelRowLinear(nn.Linear):
|
class TensorParallelRowLinear(nn.Linear):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
process_group: torch.distributed.ProcessGroup,
|
process_group: torch.distributed.ProcessGroup,
|
||||||
bias=True,
|
bias=True,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
assert in_features % self.tp_world_size == 0
|
assert in_features % self.tp_world_size == 0
|
||||||
in_features = in_features // self.tp_world_size
|
in_features = in_features // self.tp_world_size
|
||||||
|
|
||||||
super().__init__(in_features=in_features,
|
super().__init__(
|
||||||
out_features=out_features,
|
in_features=in_features,
|
||||||
bias=bias,
|
out_features=out_features,
|
||||||
device=device,
|
bias=bias,
|
||||||
dtype=dtype)
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def linear(input, weight, bias):
|
def linear(input, weight, bias):
|
||||||
@ -89,18 +83,18 @@ class TensorParallelRowLinear(nn.Linear):
|
|||||||
|
|
||||||
class TensorParallelEmbedding(nn.Embedding):
|
class TensorParallelEmbedding(nn.Embedding):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
process_group: torch.distributed.ProcessGroup,
|
process_group: torch.distributed.ProcessGroup,
|
||||||
padding_idx=None,
|
padding_idx=None,
|
||||||
max_norm=None,
|
max_norm=None,
|
||||||
norm_type=2.0,
|
norm_type=2.0,
|
||||||
scale_grad_by_freq=False,
|
scale_grad_by_freq=False,
|
||||||
sparse=False,
|
sparse=False,
|
||||||
_weight=None,
|
_weight=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None
|
dtype=None,
|
||||||
):
|
):
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
@ -115,15 +109,27 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||||||
self.min_id = self.tp_rank * block_size
|
self.min_id = self.tp_rank * block_size
|
||||||
self.max_id = (self.tp_rank + 1) * block_size
|
self.max_id = (self.tp_rank + 1) * block_size
|
||||||
|
|
||||||
super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type,
|
super().__init__(
|
||||||
scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device,
|
block_size,
|
||||||
dtype=dtype)
|
embedding_dim,
|
||||||
|
padding_idx=padding_idx,
|
||||||
|
max_norm=max_norm,
|
||||||
|
norm_type=norm_type,
|
||||||
|
scale_grad_by_freq=scale_grad_by_freq,
|
||||||
|
sparse=sparse,
|
||||||
|
_weight=_weight,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
# Sanity check
|
# Sanity check
|
||||||
if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)):
|
if torch.any(
|
||||||
|
torch.logical_or(0 > input, input >= self.original_num_embeddings)
|
||||||
|
):
|
||||||
raise IndexError(
|
raise IndexError(
|
||||||
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}")
|
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
|
||||||
|
)
|
||||||
|
|
||||||
# `0` if input is in the correct interval, else `1`
|
# `0` if input is in the correct interval, else `1`
|
||||||
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
|
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
|
||||||
@ -141,8 +147,11 @@ class PositionRotaryEmbedding(RotaryEmbedding):
|
|||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
|
if (
|
||||||
or self._cos_cached.dtype != dtype):
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
# Don't do einsum, it converts fp32 to fp16
|
# Don't do einsum, it converts fp32 to fp16
|
||||||
@ -152,8 +161,12 @@ class PositionRotaryEmbedding(RotaryEmbedding):
|
|||||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
else:
|
else:
|
||||||
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
power = (
|
||||||
- seqlen // 2) / self.scale_base)
|
torch.arange(
|
||||||
|
seqlen, dtype=self.scale.dtype, device=self.scale.device
|
||||||
|
)
|
||||||
|
- seqlen // 2
|
||||||
|
) / self.scale_base
|
||||||
scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
|
scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
|
||||||
# We want the multiplication by scale to happen in fp32
|
# We want the multiplication by scale to happen in fp32
|
||||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||||
@ -164,29 +177,33 @@ class PositionRotaryEmbedding(RotaryEmbedding):
|
|||||||
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int):
|
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int):
|
||||||
self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s)
|
self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s)
|
||||||
|
|
||||||
q1, q2, k1, k2, cos, sin = _prepare_rotary(qkv, self._cos_cached, self._sin_cached, position_ids)
|
q1, q2, k1, k2, cos, sin = _prepare_rotary(
|
||||||
|
qkv, self._cos_cached, self._sin_cached, position_ids
|
||||||
|
)
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||||
return qkv
|
return qkv
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def _prepare_rotary(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
|
def _prepare_rotary(
|
||||||
|
qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
|
||||||
|
):
|
||||||
cos = torch.index_select(cos, 0, position_ids)
|
cos = torch.index_select(cos, 0, position_ids)
|
||||||
sin = torch.index_select(sin, 0, position_ids)
|
sin = torch.index_select(sin, 0, position_ids)
|
||||||
|
|
||||||
rotary_dim = cos.shape[-1]
|
rotary_dim = cos.shape[-1]
|
||||||
q1 = qkv[:, 0, :, :rotary_dim]
|
q1 = qkv[:, 0, :, :rotary_dim]
|
||||||
q2 = qkv[:, 0, :, rotary_dim:2*rotary_dim]
|
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
|
||||||
k1 = qkv[:, 1, :, :rotary_dim]
|
k1 = qkv[:, 1, :, :rotary_dim]
|
||||||
k2 = qkv[:, 1, :, rotary_dim: 2*rotary_dim]
|
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
|
||||||
|
|
||||||
return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
|
return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class FlashNeoxAttention(torch.nn.Module):
|
class FlashNeoxAttention(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
|
self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
@ -216,17 +233,21 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
def _swap_dims(self):
|
def _swap_dims(self):
|
||||||
self.query_key_value.weight = torch.nn.Parameter(
|
self.query_key_value.weight = torch.nn.Parameter(
|
||||||
self.query_key_value.weight.view(self.num_heads, 3, self.head_size, self.hidden_size)
|
self.query_key_value.weight.view(
|
||||||
.permute(1, 0, 2, 3).reshape(-1, self.hidden_size)
|
self.num_heads, 3, self.head_size, self.hidden_size
|
||||||
|
)
|
||||||
|
.permute(1, 0, 2, 3)
|
||||||
|
.reshape(-1, self.hidden_size)
|
||||||
)
|
)
|
||||||
self.query_key_value.bias = torch.nn.Parameter(
|
self.query_key_value.bias = torch.nn.Parameter(
|
||||||
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
|
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
|
||||||
.permute(1, 0, 2).reshape(-1)
|
.permute(1, 0, 2)
|
||||||
|
.reshape(-1)
|
||||||
)
|
)
|
||||||
self.swap_dims = True
|
self.swap_dims = True
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
||||||
):
|
):
|
||||||
if not self.swap_dims:
|
if not self.swap_dims:
|
||||||
self._swap_dims()
|
self._swap_dims()
|
||||||
@ -240,9 +261,21 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(qkv[:, 0])
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlens, cu_seqlens, max_s, max_s, 0.0,
|
qkv[:, 0],
|
||||||
|
qkv[:, 1],
|
||||||
|
qkv[:, 2],
|
||||||
|
attn_output,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
False, True, False, 0, None
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query = qkv_rot[:, 0]
|
query = qkv_rot[:, 0]
|
||||||
@ -250,12 +283,21 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
query, layer_past[:, 0], layer_past[:, 1], attn_output,
|
query,
|
||||||
torch.arange(len(cu_seqlens), dtype=torch.int32).to(
|
layer_past[:, 0],
|
||||||
query.device
|
layer_past[:, 1],
|
||||||
), cu_seqlens, torch.tensor(1, dtype=torch.int32).to(query.device), max_s, 0.0,
|
attn_output,
|
||||||
|
torch.arange(len(cu_seqlens), dtype=torch.int32).to(query.device),
|
||||||
|
cu_seqlens,
|
||||||
|
torch.tensor(1, dtype=torch.int32).to(query.device),
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
False, False, False, 0, None
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -264,11 +306,11 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
class FlashMLP(nn.Module):
|
class FlashMLP(nn.Module):
|
||||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert "gelu" in act
|
if "gelu" in act:
|
||||||
# if "gelu" in act:
|
act = "gelu_approx"
|
||||||
# act = "gelu_approx"
|
assert act in ["gelu_approx", "relu"]
|
||||||
# assert act in ["gelu_approx", "relu"]
|
self.is_gelu = act == "gelu_approx"
|
||||||
self.act = lambda x: F.gelu(x, approximate="tanh")
|
# self.act = lambda x: F.gelu(x, approximate="tanh")
|
||||||
|
|
||||||
if process_group is None:
|
if process_group is None:
|
||||||
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
|
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
|
||||||
@ -288,24 +330,34 @@ class FlashMLP(nn.Module):
|
|||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.dense_h_to_4h(hidden_states)
|
hidden_states, *rest = fused_dense_cuda.linear_act_forward(
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states,
|
||||||
hidden_states = self.dense_4h_to_h(hidden_states)
|
self.dense_h_to_4h.weight,
|
||||||
return hidden_states
|
self.dense_h_to_4h.bias,
|
||||||
|
self.is_gelu,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
return self.dense_4h_to_h(hidden_states)
|
||||||
|
#
|
||||||
|
# hidden_states = self.dense_h_to_4h(hidden_states)
|
||||||
|
# hidden_states = self.act(hidden_states)
|
||||||
|
# hidden_states = self.dense_4h_to_h(hidden_states)
|
||||||
|
# return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlashNeoXLayer(nn.Module):
|
class FlashNeoXLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_heads,
|
num_heads,
|
||||||
act,
|
act,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
rotary_pct,
|
rotary_pct,
|
||||||
rotary_emb_base,
|
rotary_emb_base,
|
||||||
layer_norm_eps,
|
layer_norm_eps,
|
||||||
use_parallel_residual,
|
use_parallel_residual,
|
||||||
process_group=None,
|
process_group=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_parallel_residual = use_parallel_residual
|
self.use_parallel_residual = use_parallel_residual
|
||||||
@ -317,51 +369,97 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
|
self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
layer_past,
|
layer_past,
|
||||||
prefill,
|
prefill,
|
||||||
):
|
):
|
||||||
if self.use_parallel_residual:
|
if self.use_parallel_residual:
|
||||||
attn_output = self.attention(
|
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
self.input_layernorm(hidden_states), position_ids, cu_seqlens, max_s, layer_past, prefill
|
hidden_states,
|
||||||
|
None,
|
||||||
|
self.input_layernorm.weight,
|
||||||
|
self.input_layernorm.bias,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
self.input_layernorm.eps,
|
||||||
|
1.0,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
|
attn_output = self.attention(
|
||||||
return mlp_output + attn_output + hidden_states, None
|
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
||||||
|
)
|
||||||
|
|
||||||
|
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
|
hidden_states,
|
||||||
|
None,
|
||||||
|
self.post_attention_layernorm.weight,
|
||||||
|
self.post_attention_layernorm.bias,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
self.post_attention_layernorm.eps,
|
||||||
|
1.0,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(ln2_hidden_states)
|
||||||
|
return mlp_output + attn_output + hidden_states, None
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
hidden_states, residual = dropout_add_layer_norm(
|
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
self.input_layernorm.weight,
|
self.input_layernorm.weight,
|
||||||
self.input_layernorm.bias,
|
self.input_layernorm.bias,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
0.0,
|
0.0,
|
||||||
self.input_layernorm.eps,
|
self.input_layernorm.eps,
|
||||||
rowscale=None,
|
1.0,
|
||||||
prenorm=True,
|
0,
|
||||||
residual_in_fp32=True,
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.attention(
|
hidden_states = self.attention(
|
||||||
hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = dropout_add_layer_norm(
|
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
self.post_attention_layernorm.weight,
|
self.post_attention_layernorm.weight,
|
||||||
self.post_attention_layernorm.bias,
|
self.post_attention_layernorm.bias,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
0.0,
|
0.0,
|
||||||
self.post_attention_layernorm.eps,
|
self.post_attention_layernorm.eps,
|
||||||
rowscale=None,
|
1.0,
|
||||||
prenorm=True,
|
0,
|
||||||
residual_in_fp32=True,
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(hidden_states)
|
mlp_output = self.mlp(hidden_states)
|
||||||
@ -421,12 +519,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
self.num_heads = self.layers[0].attention.num_heads
|
self.num_heads = self.layers[0].attention.num_heads
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
):
|
):
|
||||||
hidden_states = self.embed_in(input_ids)
|
hidden_states = self.embed_in(input_ids)
|
||||||
|
|
||||||
@ -483,12 +581,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
):
|
):
|
||||||
hidden_states, present = self.gpt_neox(
|
hidden_states, present = self.gpt_neox(
|
||||||
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
||||||
|
Loading…
Reference in New Issue
Block a user