Mirror of https://github.com/huggingface/text-generation-inference.git

Add Sealion MPT Support

commit df04b28bfc
parent 3ccb3bb0b5
@@ -28,7 +28,6 @@ EPS = 1e-5
 
 
 def load_col(config, prefix, weights, bias):
-    assert bias == False, NotImplementedError
     assert config.quantize != "gptq", NotImplementedError
     slice_ = weights._get_slice(f"{prefix}.weight")
     rank = weights.process_group.rank()
@@ -45,7 +44,26 @@ def load_col(config, prefix, weights, bias):
     if weight.dtype != torch.int32:
         weight = weight.to(dtype=weights.dtype)
     weight = weight.to(device=weights.device)
-    bias = None
+
+    if bias:
+        bias_slice_ = weights._get_slice(f"{prefix}.bias")
+        bias_rank = weights.process_group.rank()
+        bias_size = weights.process_group.size()
+
+        bias_h = bias_slice_.get_shape()
+        bias_h = bias_h[0]
+        bias_block_size = bias_h // bias_size
+
+        bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
+        bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
+        bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
+
+        bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
+        if bias.dtype != torch.int32:
+            bias = bias.to(dtype=weights.dtype)
+        bias = bias.to(device=weights.device)
+    else:
+        bias = None
     linear = get_linear(weight, bias, config.quantize)
     return TensorParallelColumnLinear(linear)
 
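The hunk above lets load_col shard a fused QKV bias the same way it already shards the fused QKV weight: the bias of length 3*h is read as three contiguous q/k/v segments, each tensor-parallel rank keeps its own block of every segment, and the three blocks are re-concatenated. A minimal sketch of that indexing with plain tensors standing in for TGI's weight slices (the sizes, rank and world size below are illustrative assumptions):

import torch

# Illustrative sizes: hidden size h and a fused QKV bias of length 3 * h.
h, world_size, rank = 8, 2, 1
bias = torch.arange(3 * h)            # stands in for weights._get_slice(f"{prefix}.bias")
block = h // world_size

# Each rank takes its own block from the q, k and v segments, then re-fuses them.
q_part = bias[rank * block : (rank + 1) * block]
k_part = bias[h + rank * block : h + (rank + 1) * block]
v_part = bias[2 * h + rank * block : 2 * h + (rank + 1) * block]
shard = torch.cat([q_part, k_part, v_part], dim=0)   # length 3 * h // world_size
print(shard)                                          # rows this rank owns from q, k and v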
@@ -330,7 +348,15 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            raise NotImplementedError("qk_ln is not supported")
+            if weights.process_group.size() > 1:
+                raise NotImplementedError("qk_ln is not supported for number of shards > 1")
+            bias = not config.no_bias
+            hidden_size = config.d_model
+            head_dim = hidden_size // self.n_heads
+
+            norm_class = LPLayerNorm
+            self.q_ln = norm_class(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
+            self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
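With this hunk, checkpoints that enable attn_config["qk_ln"] load per-projection LayerNorms (q_ln and k_ln through LPLayerNorm) instead of failing outright; only the sharded case (process_group.size() > 1) still raises. In the reference MPT implementation these norms are applied to the query and key right after the fused QKV projection is split; the runnable toy below paraphrases that usage (it is not part of this diff, and the shapes are made up):

import torch
from torch import nn

# Toy shapes (made up): batch 1, seq 4, d_model 8, 2 heads.
d_model, n_heads = 8, 2
Wqkv = nn.Linear(d_model, 3 * d_model)                   # stands in for self.Wqkv
q_ln = nn.LayerNorm(d_model)                             # stands in for self.q_ln
k_ln = nn.LayerNorm(n_heads * (d_model // n_heads))      # stands in for self.k_ln

x = torch.randn(1, 4, d_model)
query, key, value = Wqkv(x).chunk(3, dim=2)
dtype = query.dtype
query = q_ln(query).to(dtype)   # normalize the query projection
key = k_ln(key).to(dtype)       # normalize the key projection
print(query.shape, key.shape)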
@@ -581,12 +607,20 @@ class MPTBlock(nn.Module):
                 f"""Not implemented attn {config.attn_config["attn_type"]}"""
             )
         resid_pdrop = config.resid_pdrop
-        self.norm_1 = nn.LayerNorm.load_no_bias(
-            prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
-        )
-        self.norm_2 = nn.LayerNorm.load_no_bias(
-            prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
-        )
+        if config.no_bias:
+            self.norm_1 = nn.LayerNorm.load_no_bias(
+                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
+            )
+            self.norm_2 = nn.LayerNorm.load_no_bias(
+                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
+            )
+        else:
+            self.norm_1 = nn.LayerNorm.load(
+                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
+            )
+            self.norm_2 = nn.LayerNorm.load(
+                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
+            )
         self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
         self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
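MPTBlock now selects its LayerNorm loader from the checkpoint's no_bias flag: the original MPT checkpoints ship bias-free norms and keep using load_no_bias, while checkpoints that retain LayerNorm biases (such as the SEA-LION ones this PR targets) go through the regular load path. The sketch below only approximates what such a pair of loaders amounts to; it is an assumption for illustration, not the code from TGI's layers module:

import torch
from torch import nn

def load_layer_norm(prefix, weights, eps, with_bias):
    # Approximate behaviour of nn.LayerNorm.load / load_no_bias (assumed, not verbatim TGI code).
    weight = weights.get_tensor(f"{prefix}.weight")
    ln = nn.LayerNorm(weight.shape[0], eps=eps)
    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias")) if with_bias else None
    return ln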
@@ -616,15 +650,13 @@ class MPTBlock(nn.Module):
 
 
 def _cast_if_autocast_enabled(tensor):
-    if torch.is_autocast_enabled():
-        if tensor.device.type == "cuda":
-            dtype = torch.get_autocast_gpu_dtype()
-        elif tensor.device.type == "cpu":
-            dtype = torch.get_autocast_cpu_dtype()
-        else:
-            raise NotImplementedError()
-        return tensor.to(dtype=dtype)
-    return tensor
+    if tensor.device.type == "cuda":
+        dtype = torch.get_autocast_gpu_dtype()
+    elif tensor.device.type == "cpu":
+        dtype = torch.get_autocast_cpu_dtype()
+    else:
+        raise NotImplementedError()
+    return tensor.to(dtype=dtype)
 
 
 class LPLayerNorm(torch.nn.LayerNorm):
@@ -635,6 +667,9 @@ class LPLayerNorm(torch.nn.LayerNorm):
         elementwise_affine=True,
         device=None,
         dtype=None,
+        bias: Optional[bool] = True,
+        prefix=None,
+        weights=None,
     ):
         super().__init__(
             normalized_shape=normalized_shape,
@@ -642,7 +677,11 @@ class LPLayerNorm(torch.nn.LayerNorm):
             elementwise_affine=elementwise_affine,
             device=device,
             dtype=dtype,
+            bias=bias,
         )
+        if weights is not None:
+            self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
+            self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))
 
     def forward(self, x):
         module_device = x.device
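LPLayerNorm can now be built straight from checkpoint tensors: when a weights handle is passed, the freshly initialized parameters are overwritten with "{prefix}.weight" and "{prefix}.bias". A minimal usage sketch with a dictionary-backed stand-in for TGI's Weights object (the FakeWeights class, prefix and sizes are made up for illustration):

import torch
from torch import nn

class FakeWeights:
    """Stand-in for TGI's Weights handle: maps tensor names to tensors."""
    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]

prefix = "transformer.blocks.0.attn.q_ln"
weights = FakeWeights({
    f"{prefix}.weight": torch.ones(16),
    f"{prefix}.bias": torch.zeros(16),
})

# Mirrors what the patched __init__ does once weights is not None:
ln = nn.LayerNorm(16)
ln.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
ln.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))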
@@ -755,20 +794,23 @@ class MPTModel(MPTPreTrainedModel):
         )
 
         self.wte = TensorParallelEmbedding("transformer.wte", weights)
 
         if not self.alibi:
-            # self.wpe = torch.nn.Embedding(
-            #     config.max_seq_len, config.d_model, device=config.init_device
-            # )
-            raise RuntimeError("no alibi no supported")
+            self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
         self.blocks = nn.ModuleList(
             [
                 MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
                 for i in range(config.n_layers)
             ]
         )
-        self.norm_f = nn.LayerNorm.load_no_bias(
-            prefix="transformer.norm_f", weights=weights, eps=EPS
-        )
+        if config.no_bias:
+            self.norm_f = nn.LayerNorm.load_no_bias(
+                prefix="transformer.norm_f", weights=weights, eps=EPS
+            )
+        else:
+            self.norm_f = nn.LayerNorm.load(
+                prefix="transformer.norm_f", weights=weights, eps=EPS
+            )
         self.is_causal = not self.prefix_lm
         self._attn_bias_initialized = False
         self.attn_bias = None
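The not self.alibi branch used to hard-fail; it now loads transformer.wpe as a TensorParallelEmbedding, so MPT-style checkpoints that use learned positional embeddings rather than ALiBi (as the SEA-LION checkpoints targeted by this PR appear to) can be served. In the forward pass the position embedding is simply added to the token embedding; a runnable toy version follows (sizes made up, and the real code also clamps positions using the attention mask and past length):

import torch
from torch import nn

# Toy stand-ins for wte / wpe (vocab 100, max_seq_len 32, d_model 8 are made-up sizes).
wte = nn.Embedding(100, 8)
wpe = nn.Embedding(32, 8)

input_ids = torch.tensor([[3, 17, 42, 7]])
pos = torch.arange(input_ids.shape[1]).unsqueeze(0)    # positions [0, 1, 2, 3]
x = wte(input_ids) + wpe(pos)                          # token + learned position embeddings
print(x.shape)                                         # torch.Size([1, 4, 8])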
@@ -787,13 +829,15 @@ class MPTModel(MPTPreTrainedModel):
                     if config.verbose:
                         warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                     module.register_parameter("bias", None)
-        if config.verbose and config.verbose > 2:
-            print(self)
+        if hasattr(self.config, "verbose"):
+            if config.verbose and config.verbose > 2:
+                print(self)
         if "verbose" not in self.config.init_config:
             self.config.init_config["verbose"] = self.config.verbose
         if self.config.init_config["verbose"] > 1:
             init_fn_name = self.config.init_config["name"]
             warnings.warn(f"Using {init_fn_name} initialization.")
+        self.embedding_fraction = config.embedding_fraction
 
     @torch.no_grad()
     def _attn_bias(
@@ -983,6 +1027,11 @@ class MPTModel(MPTPreTrainedModel):
             )
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
+        if self.embedding_fraction != 1:
+            x = (
+                x * self.embedding_fraction
+                + x.detach() * (1 - self.embedding_fraction)
+            )
         (attn_bias, attention_mask) = self._attn_bias(
             device=x.device,
             dtype=torch.float32,
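embedding_fraction is MPT's way of shrinking the gradient that reaches the embedding layer without changing the forward activations: x * f + x.detach() * (1 - f) equals x numerically, but only the fraction f of the gradient flows back through x. A small runnable check (toy tensor and fraction chosen for illustration):

import torch

f = 0.1
x = torch.randn(4, 8, requires_grad=True)

y = x * f + x.detach() * (1 - f)
assert torch.allclose(y, x)      # forward value is unchanged

y.sum().backward()
print(x.grad.unique())           # tensor([0.1000]): gradients are scaled by f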