M********

Authored by Ubuntu on 2023-05-24 20:28:54 +00:00, committed by Nicolas Patry
parent 55045be42f
commit c471e46cf8
2 changed files with 8 additions and 9 deletions

File 1 of 2:

@@ -148,7 +148,7 @@ def get_model(
             )
     elif model_type == "gpt_neox":
-        if FLASH_ATTENTION and False:
+        if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
                 revision,
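The change above removes a hard-coded `and False`, so the sharded flash-attention NeoX path is actually taken whenever `FLASH_ATTENTION` is available. A minimal sketch of that dispatch pattern follows; the stub classes and the non-flash fallback are assumptions for illustration, not the repository's exact code.

# Minimal sketch of the dispatch; FlashNeoXSharded and GPTNeoxSharded are
# hypothetical stubs standing in for the real model classes.
FLASH_ATTENTION = True  # assumption: capability flag detected at import time


class FlashNeoXSharded:
    def __init__(self, model_id, revision=None):
        self.model_id, self.revision = model_id, revision


class GPTNeoxSharded:
    def __init__(self, model_id, revision=None):
        self.model_id, self.revision = model_id, revision


def get_model(model_id: str, revision: str, model_type: str):
    if model_type == "gpt_neox":
        # Before this commit the condition was `FLASH_ATTENTION and False`,
        # which made the flash-attention branch unreachable.
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision)
        return GPTNeoxSharded(model_id, revision)
    raise ValueError(f"Unsupported model_type: {model_type}")


model = get_model("EleutherAI/gpt-neox-20b", "main", "gpt_neox")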

File 2 of 2:

@@ -139,7 +139,9 @@ class T5DenseActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
@@ -182,7 +184,9 @@ class T5DenseGatedActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
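The two hunks above drop the cast back to `self.wo_cast[1]` after the output projection. The rationale in the XXX comment is that the T5 layer norm consuming this output already does its arithmetic in float32 and casts back to the weight dtype itself, so an extra float16 round-trip here only adds rounding error. Below is a sketch of that layer-norm behaviour, modelled on the RMS norm used by T5 in `transformers`; treat it as an illustration rather than this repository's exact class.

import torch
from torch import nn


class T5LayerNorm(nn.Module):
    """RMS norm as used by T5: no mean subtraction, no bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Variance is computed in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # The norm casts back to the weight dtype (e.g. float16) on its own,
        # which is why the explicit cast after `wo` above is redundant.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


norm = T5LayerNorm(8).to(torch.float16)
out = norm(torch.randn(2, 4, 8))  # float32 input comes back as float16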
@@ -350,6 +354,7 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+
         batch_size, seq_length = hidden_states.shape[:2]

         real_seq_length = seq_length
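For readers following the shape comments above: `real_seq_length` is the current query length plus any cached key/value length. The snippet below sketches that bookkeeping under the same shape conventions; it is an illustration of the pattern used by upstream T5 attention, not necessarily this file's exact code.

import torch

batch_size, n_heads, dim_per_head = 2, 8, 64
seq_length = 1                      # decoding one new token
past_len = 11                       # tokens already in the key/value cache

hidden_states = torch.randn(batch_size, seq_length, n_heads * dim_per_head)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
past_key_value = (
    torch.randn(batch_size, n_heads, past_len, dim_per_head),  # cached keys
    torch.randn(batch_size, n_heads, past_len, dim_per_head),  # cached values
)

batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
    # Attention is computed against the cached positions as well.
    real_seq_length += past_key_value[0].shape[2]

assert real_seq_length == past_len + seq_length  # 12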
@@ -841,10 +846,6 @@ class T5Stack(T5PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

-        from safetensors.torch import save_file
-        save_file({"inputs_embeds": inputs_embeds}, f"inputs_embeds_{self.rank}_layer.safetensors")
-
         batch_size, seq_length = input_shape

         # required mask seq length can be calculated via length of past
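The removed lines were debugging aids: `safetensors.torch.save_file` dumped the embedding output of each tensor-parallel rank to disk so the ranks could be compared offline. If that kind of dump is ever needed again, the pattern looks like the sketch below; the `rank` variable, tensor shape, and file name are illustrative assumptions.

import torch
from safetensors.torch import save_file

rank = 0  # e.g. torch.distributed.get_rank() in a sharded run
inputs_embeds = torch.randn(2, 16, 512)

# save_file takes a flat dict of named tensors; they must be contiguous.
save_file(
    {"inputs_embeds": inputs_embeds.contiguous()},
    f"inputs_embeds_{rank}_layer.safetensors",
)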
@@ -941,8 +942,6 @@ class T5Stack(T5PreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
             )
-            from safetensors.torch import save_file
-            save_file({"layer": layer_outputs[0]}, f"layer_outputs_{self.rank}_layer_{i}.safetensors")

             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
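The per-layer dumps produced by this second removed block could be compared across ranks with `safetensors.torch.load_file`. A small sketch of that comparison, assuming dumps from ranks 0 and 1 already exist on disk and follow the removed code's naming pattern:

import torch
from safetensors.torch import load_file

# Compare the layer outputs dumped by two ranks for layer i.
i = 0
ref = load_file(f"layer_outputs_0_layer_{i}.safetensors")["layer"]
other = load_file(f"layer_outputs_1_layer_{i}.safetensors")["layer"]

max_diff = (ref.float() - other.float()).abs().max().item()
print(f"layer {i}: max abs diff {max_diff:.3e}")
assert torch.allclose(ref.float(), other.float(), atol=1e-3), "ranks disagree"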