mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
feat(server): cleanup flash neox loading (#139)
This commit is contained in:
parent
d6a93fe992
commit
678b2f3900
@ -450,8 +450,6 @@ class FlashNeoX(Model):
|
|||||||
next_batch_input_ids = next_batch_input_ids[0].view(1)
|
next_batch_input_ids = next_batch_input_ids[0].view(1)
|
||||||
next_batch_past_key_values = next_batch_past_key_values[0]
|
next_batch_past_key_values = next_batch_past_key_values[0]
|
||||||
|
|
||||||
print(next_batch_input_ids.shape)
|
|
||||||
|
|
||||||
next_batch = FlashNeoXBatch(
|
next_batch = FlashNeoXBatch(
|
||||||
batch_id=batch.batch_id,
|
batch_id=batch.batch_id,
|
||||||
requests=next_batch_requests,
|
requests=next_batch_requests,
|
||||||
@ -507,6 +505,7 @@ class FlashNeoXSharded(FlashNeoX):
|
|||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
world_size=self.world_size,
|
world_size=self.world_size,
|
||||||
)
|
)
|
||||||
|
model.post_load_weights()
|
||||||
self.model = model.eval().to(dtype)
|
self.model = model.eval().to(dtype)
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashNeoX, self).__init__(
|
super(FlashNeoX, self).__init__(
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
@ -24,13 +26,11 @@ class FastLinear(nn.Linear):
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||||
self.swap_dims = True
|
|
||||||
|
def transpose_weight(self):
|
||||||
|
self.weight = nn.Parameter(self.weight.T)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
if self.swap_dims:
|
|
||||||
self.weight = nn.Parameter(self.weight.T)
|
|
||||||
self.swap_dims = False
|
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
return torch.addmm(self.bias, input, self.weight)
|
return torch.addmm(self.bias, input, self.weight)
|
||||||
return torch.matmul(input, self.weight)
|
return torch.matmul(input, self.weight)
|
||||||
@ -120,6 +120,10 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||||||
self.min_id = self.tp_rank * block_size
|
self.min_id = self.tp_rank * block_size
|
||||||
self.max_id = (self.tp_rank + 1) * block_size
|
self.max_id = (self.tp_rank + 1) * block_size
|
||||||
|
|
||||||
|
# Additional entry that will map to zero
|
||||||
|
# Used for masking
|
||||||
|
self.null_idx = block_size
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
block_size,
|
block_size,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
@ -133,15 +137,19 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_null_idx(self):
|
||||||
|
"""Additional 0 entry used for masking"""
|
||||||
|
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
# `0` if input is in the correct interval, else `1`
|
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
|
||||||
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
|
|
||||||
# translate for [0, self.max_id - self.min_id[
|
# translate for [0, self.max_id - self.min_id[
|
||||||
input = input - self.min_id
|
input = torch.where(
|
||||||
# default all out of bounds values to `0`
|
(self.min_id > input) | (input >= self.max_id),
|
||||||
input[input_mask] = 0
|
self.null_idx,
|
||||||
|
input - self.min_id,
|
||||||
|
)
|
||||||
out = super().forward(input)
|
out = super().forward(input)
|
||||||
out[input_mask] = 0.0
|
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -214,11 +222,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
hidden_size,
|
hidden_size,
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
)
|
)
|
||||||
self.swap_dims = True
|
|
||||||
|
|
||||||
# TODO: remove and swap dims when loading weights
|
def shuffle_qkv_dims(self):
|
||||||
def _swap_dims(self):
|
"""Swap dims to avoid an additional permute"""
|
||||||
"""Swap dims for the first inference to avoid an additional permute"""
|
|
||||||
self.query_key_value.weight = torch.nn.Parameter(
|
self.query_key_value.weight = torch.nn.Parameter(
|
||||||
self.query_key_value.weight.view(
|
self.query_key_value.weight.view(
|
||||||
self.num_heads, 3, self.head_size, self.hidden_size
|
self.num_heads, 3, self.head_size, self.hidden_size
|
||||||
@ -231,7 +237,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
.permute(1, 0, 2)
|
.permute(1, 0, 2)
|
||||||
.reshape(-1)
|
.reshape(-1)
|
||||||
)
|
)
|
||||||
self.swap_dims = False
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -244,9 +249,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
layer_past_present_indices,
|
layer_past_present_indices,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
):
|
):
|
||||||
if self.swap_dims:
|
|
||||||
self._swap_dims()
|
|
||||||
|
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
||||||
@ -329,7 +331,6 @@ class FlashMLP(nn.Module):
|
|||||||
hidden_size,
|
hidden_size,
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
)
|
)
|
||||||
self.heuristic = "auto"
|
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
@ -531,6 +532,25 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
self.head_size = self.layers[0].attention.head_size
|
self.head_size = self.layers[0].attention.head_size
|
||||||
self.num_heads = self.layers[0].attention.num_heads
|
self.num_heads = self.layers[0].attention.num_heads
|
||||||
|
|
||||||
|
def post_load_weights(self):
|
||||||
|
if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||||
|
self.embed_in.add_null_idx()
|
||||||
|
for layer in self.layers:
|
||||||
|
layer: FlashNeoXLayer
|
||||||
|
layer.attention.shuffle_qkv_dims()
|
||||||
|
layer.attention.query_key_value.transpose_weight()
|
||||||
|
layer.attention.dense.transpose_weight()
|
||||||
|
layer.mlp.dense_h_to_4h.transpose_weight()
|
||||||
|
layer.mlp.dense_4h_to_h.transpose_weight()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
model = super(FlashGPTNeoXModel, cls).from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
model.post_load_weights()
|
||||||
|
return model
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -627,6 +647,18 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||||||
config.hidden_size, config.vocab_size, bias=False
|
config.hidden_size, config.vocab_size, bias=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def post_load_weights(self):
|
||||||
|
self.gpt_neox.post_load_weights()
|
||||||
|
self.embed_out.transpose_weight()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
model.post_load_weights()
|
||||||
|
return model
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
Loading…
Reference in New Issue
Block a user