M********

Authored by Ubuntu on 2023-05-24 20:28:54 +00:00, committed by Nicolas Patry
parent 55045be42f
commit c471e46cf8
2 changed files with 8 additions and 9 deletions


@@ -148,7 +148,7 @@ def get_model(
             )
     elif model_type == "gpt_neox":
-        if FLASH_ATTENTION and False:
+        if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
                 revision,
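
The hunk above drops the "and False" guard that had been hard-disabling the Flash Attention path for gpt_neox, so the sharded Flash Attention model is selected whenever the capability flag is set. A minimal sketch of that dispatch pattern, with the flag and model classes reduced to hypothetical stand-ins:

# Hypothetical stand-ins for the flag and classes referenced in the diff.
FLASH_ATTENTION = True  # in the real server this is set from kernel availability

class FlashNeoXSharded:
    backend = "flash-attention"
    def __init__(self, model_id, revision):
        self.model_id, self.revision = model_id, revision

class GPTNeoxSharded:
    backend = "eager"
    def __init__(self, model_id, revision):
        self.model_id, self.revision = model_id, revision

def get_model(model_id, revision, model_type):
    if model_type == "gpt_neox":
        # Before this commit, "if FLASH_ATTENTION and False:" made this branch
        # unreachable, forcing the non-flash fallback below.
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision)
        return GPTNeoxSharded(model_id, revision)
    raise ValueError(f"Unsupported model_type: {model_type}")

print(get_model("EleutherAI/gpt-neox-20b", "main", "gpt_neox").backend)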


@@ -139,7 +139,9 @@ class T5DenseActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
@@ -182,7 +184,9 @@ class T5DenseGatedActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
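
The two T5 hunks above stop casting the feed-forward output back to wo_cast[1] (typically float16) right after self.wo: the layer norm that follows already performs the recast, and the extra float16 round-trip slightly changes the results. A small self-contained sketch of the effect being avoided, with illustrative shapes and dtypes rather than the model's own:

import torch

torch.manual_seed(0)

# Stand-in for a projection output kept in float32 for numerical stability.
hidden_states = torch.randn(4, 64, dtype=torch.float32)
layer_norm = torch.nn.LayerNorm(64)  # runs in float32

# Path this commit keeps: stay in float32 until the layer norm recasts.
kept = layer_norm(hidden_states)

# Path this commit comments out: round-trip through float16 first.
removed = layer_norm(hidden_states.to(torch.float16).to(torch.float32))

# The float16 round-trip loses precision, so the outputs differ slightly.
print((kept - removed).abs().max())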
@@ -350,6 +354,7 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+        batch_size, seq_length = hidden_states.shape[:2]
         real_seq_length = seq_length
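
The comment block in this hunk spells out the tensor-shape conventions T5Attention.forward works with. A short sketch with made-up sizes, just to make those shapes concrete:

import torch

batch_size, seq_length, dim = 2, 5, 512
n_heads, dim_per_head = 8, 64

# Input is (batch_size, seq_length, dim)
hidden_states = torch.randn(batch_size, seq_length, dim)

# Mask is (batch_size, key_length) for the non-causal case ...
key_length = seq_length
mask = torch.ones(batch_size, key_length)
# ... or (batch_size, key_length, key_length) otherwise.
causal_mask = torch.ones(batch_size, key_length, key_length)

# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head):
# cached keys for everything already decoded.
past_key = torch.randn(batch_size, n_heads, seq_length - 1, dim_per_head)

bsz, slen = hidden_states.shape[:2]
print(bsz, slen, mask.shape, causal_mask.shape, past_key.shape)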
@@ -841,10 +846,6 @@ class T5Stack(T5PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
-        from safetensors.torch import save_file
-        save_file({"inputs_embeds": inputs_embeds}, f"inputs_embeds_{self.rank}_layer.safetensors")
         batch_size, seq_length = input_shape
         # required mask seq length can be calculated via length of past
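
This hunk, like the one below, removes debug instrumentation that dumped intermediate tensors to disk with safetensors so they could be compared across runs or tensor-parallel ranks. A sketch of that pattern, with a hypothetical rank standing in for self.rank:

import torch
from safetensors.torch import save_file, load_file

rank = 0                                 # would be self.rank on a sharded model
inputs_embeds = torch.randn(2, 8, 512)   # stand-in for the embedding output

# The removed pattern: write an intermediate tensor to a per-rank file ...
path = f"inputs_embeds_{rank}_layer.safetensors"
save_file({"inputs_embeds": inputs_embeds}, path)

# ... so it can be reloaded later and diffed against a reference run.
reference = load_file(path)["inputs_embeds"]
print(torch.allclose(reference, inputs_embeds))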
@@ -941,8 +942,6 @@ class T5Stack(T5PreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
             )
-            from safetensors.torch import save_file
-            save_file({"layer": layer_outputs[0]}, f"layer_outputs_{self.rank}_layer_{i}.safetensors")
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
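
The trailing comment describes the tuple each T5 block returns to T5Stack. A rough, runnable sketch of how such a tuple is consumed; the exact contents depend on use_cache and output_attentions, so treat the indices and shapes here as assumptions rather than the model's definitive layout:

import torch

# Dummy tensors standing in for one block's outputs (shapes illustrative).
hidden = torch.randn(2, 8, 512)                                     # hidden-states
present_kv = (torch.randn(2, 8, 8, 64), torch.randn(2, 8, 8, 64))   # key-value-states
self_attn_bias = torch.randn(1, 8, 8, 8)                            # self-attention position bias

layer_outputs = (hidden, present_kv, self_attn_bias)

hidden_states, present_key_value_state = layer_outputs[:2]
position_bias = layer_outputs[2]  # reused by the following blocks
print(hidden_states.shape, position_bias.shape)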