Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
This commit is contained in: parent 55045be42f, commit c471e46cf8
@@ -148,7 +148,7 @@ def get_model(
             )
     elif model_type == "gpt_neox":
-        if FLASH_ATTENTION and False:
+        if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
                 revision,
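For context, the hunk above re-enables the flash-attention path for gpt_neox by dropping the "and False" that had hard-disabled it during debugging. The following is a minimal sketch of that dispatch pattern, not the repository's actual get_model; the fallback class name is invented for illustration, and FLASH_ATTENTION is assumed to be a module-level flag that is True only when the flash-attention kernels imported successfully.

FLASH_ATTENTION = True  # assumed capability flag resolved at import time

def get_model_sketch(model_type, model_id, revision):
    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            # Path re-enabled by this commit (the "and False" is gone).
            return ("FlashNeoXSharded", model_id, revision)
        # Hypothetical non-flash fallback, named only for this sketch.
        return ("NeoXFallback", model_id, revision)
    raise ValueError(f"Unsupported model_type: {model_type}")

Appending "and False" to a condition is a common way to temporarily force the fallback branch without deleting code; removing it restores the original dispatch behaviour.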
@@ -139,7 +139,9 @@ class T5DenseActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
@@ -182,7 +184,9 @@ class T5DenseGatedActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
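Both T5 feed-forward variants get the same change: activations are upcast to wo_cast[0] before the wo projection, and the downcast to wo_cast[1] is now left to the following layer norm instead of being done twice. Below is a minimal, self-contained sketch of that casting pattern, assuming wo_cast = (float32, float16) and that wo is kept in float32 while the rest of the model runs in float16; it is not the repository's actual module.

import torch
from torch import nn

class WoCastSketch(nn.Module):
    def __init__(self, d_ff=32, d_model=8):
        super().__init__()
        # wo is assumed to be kept in float32 for numerical stability
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.wo_cast = (torch.float32, torch.float16)

    def forward(self, hidden_states):
        # incoming activations are float16 elsewhere in the model
        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        # As in the diff: the cast back to wo_cast[1] is dropped here; the
        # following layer norm recasts anyway, and downcasting at this point
        # changed the results.
        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
        return hidden_states

x = torch.randn(2, 4, 32, dtype=torch.float16)  # stand-in for act(wi(x))
print(WoCastSketch()(x).dtype)  # torch.float32 until the next layer norm downcasts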
@@ -350,6 +354,7 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
         batch_size, seq_length = hidden_states.shape[:2]

         real_seq_length = seq_length
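The shape comments above can be made concrete with a small example; the dimensions below are arbitrary and only illustrate the tensors the comments name (hidden_states, the non-causal mask, and a cached key from past_key_value).

import torch

batch_size, seq_length, n_heads, dim_per_head = 2, 5, 8, 16
hidden_states = torch.randn(batch_size, seq_length, n_heads * dim_per_head)
mask = torch.ones(batch_size, seq_length)  # non-causal variant of the mask
past_key = torch.randn(batch_size, n_heads, seq_length - 1, dim_per_head)  # past_key_value[0]

bs, sl = hidden_states.shape[:2]
real_seq_length = sl  # extended by the cached key length during incremental decoding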
@@ -841,10 +846,6 @@ class T5Stack(T5PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

-
-        from safetensors.torch import save_file
-        save_file({"inputs_embeds": inputs_embeds}, f"inputs_embeds_{self.rank}_layer.safetensors")
-

         batch_size, seq_length = input_shape

         # required mask seq length can be calculated via length of past
@@ -941,8 +942,6 @@ class T5Stack(T5PreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
             )
-            from safetensors.torch import save_file
-            save_file({"layer": layer_outputs[0]}, f"layer_outputs_{self.rank}_layer_{i}.safetensors")

             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
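The two hunks above remove temporary debugging code that dumped intermediate tensors to per-rank safetensors files so runs could be compared offline. Below is a hedged, standalone illustration of that dump-and-reload pattern using the safetensors API; the tensor name, rank, and layer index are made up for the example.

import torch
from safetensors.torch import save_file, load_file

rank, i = 0, 3  # made-up tensor-parallel rank and layer index
layer_output = torch.randn(2, 5, 512)

# Dump one intermediate tensor to a per-rank, per-layer file for offline diffing.
save_file({"layer": layer_output.contiguous()},
          f"layer_outputs_{rank}_layer_{i}.safetensors")

# Reload it later (or on another machine) to compare against a reference run.
reloaded = load_file(f"layer_outputs_{rank}_layer_{i}.safetensors")["layer"]
assert torch.equal(reloaded, layer_output)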