Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Fixing sealion support.
This commit is contained in:
parent 10fce5bffd
commit 29a4baea59
@@ -348,15 +348,12 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            if weights.process_group.size() > 1:
-                raise NotImplementedError("qk_ln is not supported for number of shards > 1")
             bias = not config.no_bias
             hidden_size = config.d_model
             head_dim = hidden_size // self.n_heads
 
-            norm_class = LPLayerNorm
-            self.q_ln = norm_class(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
-            self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
+            self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
+            self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
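For orientation, here is a minimal self-contained sketch of the qk_ln path this hunk unblocks for sharded deployments. Plain nn.LayerNorm stands in for LPLayerNorm and the sizes are made up; it only shows how the fused Wqkv output is split into q, k, v, with q and k layer-normalized before attention.

import torch
import torch.nn as nn

# Illustrative sizes only, not SEA-LION's.
d_model, n_heads = 16, 4
head_dim = d_model // n_heads

wqkv = nn.Linear(d_model, 3 * d_model, bias=False)  # fused Wqkv projection
q_ln = nn.LayerNorm(d_model)                        # stand-in for self.q_ln
k_ln = nn.LayerNorm(n_heads * head_dim)             # stand-in for self.k_ln

x = torch.randn(2, 5, d_model)       # (batch, seq, d_model)
q, k, v = wqkv(x).chunk(3, dim=-1)   # split the fused projection into q, k, v
q, k = q_ln(q), k_ln(k)              # qk_ln: normalize q and k before attention
print(q.shape, k.shape, v.shape)     # each: torch.Size([2, 5, 16])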
@@ -650,13 +647,15 @@ class MPTBlock(nn.Module):
 
 
 def _cast_if_autocast_enabled(tensor):
-    if tensor.device.type == "cuda":
-        dtype = torch.get_autocast_gpu_dtype()
-    elif tensor.device.type == "cpu":
-        dtype = torch.get_autocast_cpu_dtype()
-    else:
-        raise NotImplementedError()
-    return tensor.to(dtype=dtype)
+    if torch.is_autocast_enabled():
+        if tensor.device.type == "cuda":
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == "cpu":
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
 
 
 class LPLayerNorm(torch.nn.LayerNorm):
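The old body read an autocast dtype and cast the tensor unconditionally; the patched body only casts when an autocast context is actually active. A self-contained sketch of the new behavior (a local copy of the helper for illustration, not an import from the model file):

import torch

def cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
    # Cast only when an autocast context is active; otherwise return as-is.
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor

x = torch.randn(4)
print(cast_if_autocast_enabled(x).dtype)  # torch.float32: no autocast, no cast

if torch.cuda.is_available():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        print(cast_if_autocast_enabled(x.cuda()).dtype)  # torch.float16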
@@ -680,8 +679,11 @@ class LPLayerNorm(torch.nn.LayerNorm):
             bias=bias,
         )
         if weights is not None:
-            self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
-            self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))
+            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
+            if bias:
+                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
+            self.normalized_shape = self.weight.shape
+
 
     def forward(self, x):
         module_device = x.device
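A minimal sketch of why normalized_shape is re-derived here, using plain nn.LayerNorm and manual slicing as a stand-in for weights.get_sharded: each tensor-parallel rank keeps only its slice of the weight and bias, so a module constructed with the full hidden size has to be told the sharded shape.

import torch
import torch.nn as nn

full_size, world_size, rank = 8, 2, 1
block = full_size // world_size

full_weight = torch.randn(full_size)   # pretend checkpoint tensors
full_bias = torch.randn(full_size)

ln = nn.LayerNorm(full_size)
# Stand-ins for weights.get_sharded(f"{prefix}.weight", dim=0) on this rank.
ln.weight = nn.Parameter(full_weight[rank * block : (rank + 1) * block])
ln.bias = nn.Parameter(full_bias[rank * block : (rank + 1) * block])
# Without this, forward would still expect the original full size.
ln.normalized_shape = ln.weight.shape

x = torch.randn(2, block)              # this rank's slice of the activation
print(ln(x).shape)                     # torch.Size([2, 4])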
@@ -837,7 +839,6 @@ class MPTModel(MPTPreTrainedModel):
         if self.config.init_config["verbose"] > 1:
             init_fn_name = self.config.init_config["name"]
             warnings.warn(f"Using {init_fn_name} initialization.")
-        self.embedding_fraction = config.embedding_fraction
 
     @torch.no_grad()
     def _attn_bias(
@@ -1027,11 +1028,6 @@ class MPTModel(MPTPreTrainedModel):
             )
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
-        if self.embedding_fraction != 1:
-            x = (
-                x * self.embedding_fraction
-                + x.detach() * (1 - self.embedding_fraction)
-            )
         (attn_bias, attention_mask) = self._attn_bias(
             device=x.device,
             dtype=torch.float32,
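For reference, the removed lines implement MPT's embedding_fraction trick: the forward value of x is unchanged, but only a fraction of the gradient flows back into the embeddings. A tiny standalone demonstration of that identity:

import torch

frac = 0.1
x = torch.ones(3, requires_grad=True)
y = x * frac + x.detach() * (1 - frac)  # same value as x
y.sum().backward()
print(y.detach())  # tensor([1., 1., 1.]) -- forward value unchanged
print(x.grad)      # tensor([0.1000, 0.1000, 0.1000]) -- gradient scaled by frac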