Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Fixing sealion support.
parent 10fce5bffd · commit 29a4baea59
@@ -348,15 +348,12 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            if weights.process_group.size() > 1:
-                raise NotImplementedError("qk_ln is not supported for number of shards > 1")
             bias = not config.no_bias
             hidden_size = config.d_model
             head_dim = hidden_size // self.n_heads
 
-            norm_class = LPLayerNorm
-            self.q_ln = norm_class(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
-            self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
+            self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
+            self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
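Context for this hunk: qk_ln means the projected queries and keys are layer-normalized before attention, and the sealion checkpoints presumably enable it, given the commit title. The guard that refused to run with more than one shard is gone; instead LPLayerNorm (a low-precision layer norm variant, see the later hunks) now loads its parameters shard by shard. The toy script below only illustrates what qk_ln computes conceptually, with made-up sizes and plain nn.LayerNorm standing in for LPLayerNorm; it is not the TGI code path.

import torch
import torch.nn as nn

# Toy, single-shard illustration of qk_ln: normalize projected q and k
# over the feature dimension before attention.
d_model, n_heads = 64, 4
x = torch.randn(2, 5, d_model)
wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
q, k, v = wqkv(x).chunk(3, dim=-1)

q_ln = nn.LayerNorm(d_model)                 # stands in for LPLayerNorm(d_model, ...)
k_ln = nn.LayerNorm(n_heads * (d_model // n_heads))  # stands in for LPLayerNorm(n_heads * head_dim, ...)
q, k = q_ln(q), k_ln(k)
print(q.shape, k.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 64])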
@@ -650,6 +647,7 @@ class MPTBlock(nn.Module):
 
 
 def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
         if tensor.device.type == "cuda":
             dtype = torch.get_autocast_gpu_dtype()
         elif tensor.device.type == "cpu":
@@ -657,6 +655,7 @@ def _cast_if_autocast_enabled(tensor):
         else:
             raise NotImplementedError()
         return tensor.to(dtype=dtype)
+    return tensor
 
 
 class LPLayerNorm(torch.nn.LayerNorm):
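Reading the two hunks above together, the helper after this change looks roughly like the sketch below. The body of the CPU branch is elided between the hunks; torch.get_autocast_cpu_dtype() is assumed here by symmetry with the CUDA branch.

import torch

def _cast_if_autocast_enabled(tensor):
    # Only downcast while autocast is active; otherwise return the tensor
    # unchanged (the two lines added in the hunks above).
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()  # assumed: not visible in the diff
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor

# Outside an autocast context this is now a no-op:
x = torch.randn(2, 2)
assert _cast_if_autocast_enabled(x).dtype == torch.float32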
@@ -680,8 +679,11 @@ class LPLayerNorm(torch.nn.LayerNorm):
             bias=bias,
         )
         if weights is not None:
-            self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
-            self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))
+            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
+            if bias:
+                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
+            self.normalized_shape = self.weight.shape
+
 
     def forward(self, x):
         module_device = x.device
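A note on the added self.normalized_shape = self.weight.shape: F.layer_norm requires weight and bias to have exactly normalized_shape, so once each tensor-parallel rank keeps only a dim=0 slice of the layer-norm parameters, the shape recorded at construction time (the full size) would no longer match; resetting it to the sharded weight's shape is presumably what keeps the forward pass consistent. The slicing below is only a stand-in for what weights.get_sharded(..., dim=0) is expected to hand each rank; shard count and sizes are made up.

import torch
import torch.nn.functional as F

# Assumed toy sharding: 2 ranks, full hidden size 4096, contiguous dim=0 slices.
full_hidden, world_size, rank = 4096, 2, 0
shard = full_hidden // world_size

full_weight = torch.ones(full_hidden)
full_bias = torch.zeros(full_hidden)

# Stand-in for weights.get_sharded(f"{prefix}.weight", dim=0) on this rank.
weight = full_weight[rank * shard : (rank + 1) * shard]
bias = full_bias[rank * shard : (rank + 1) * shard]

# normalized_shape has to match the *sharded* parameter shape, hence the
# patch's `self.normalized_shape = self.weight.shape`.
x = torch.randn(8, shard)
out = F.layer_norm(x, normalized_shape=weight.shape, weight=weight, bias=bias, eps=1e-5)
print(out.shape)  # torch.Size([8, 2048])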
@@ -837,7 +839,6 @@ class MPTModel(MPTPreTrainedModel):
         if self.config.init_config["verbose"] > 1:
             init_fn_name = self.config.init_config["name"]
             warnings.warn(f"Using {init_fn_name} initialization.")
-        self.embedding_fraction = config.embedding_fraction
 
     @torch.no_grad()
     def _attn_bias(
@@ -1027,11 +1028,6 @@ class MPTModel(MPTPreTrainedModel):
             )
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
-        if self.embedding_fraction != 1:
-            x = (
-                x * self.embedding_fraction
-                + x.detach() * (1 - self.embedding_fraction)
-            )
         (attn_bias, attention_mask) = self._attn_bias(
             device=x.device,
             dtype=torch.float32,
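The two hunks above drop the embedding_fraction handling. The deleted block mirrors the upstream MPT training trick: it leaves the forward value of x unchanged while scaling the gradient that reaches the embeddings, which is irrelevant for an inference server that never backpropagates. A small sketch of what the removed lines did:

import torch

f = 0.5  # embedding_fraction
x = torch.randn(4, 8, requires_grad=True)

# Removed scaling: forward value identical to x, gradient scaled by f.
y = x * f + x.detach() * (1 - f)

assert torch.allclose(y, x)   # no change at inference time
y.sum().backward()
print(x.grad.unique())        # tensor([0.5000]) -> only a fraction f of the gradient flows back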