diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 1f862c9e..f0427c20 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -148,7 +148,7 @@ def get_model(
             )
 
     elif model_type == "gpt_neox":
-        if FLASH_ATTENTION and False:
+        if FLASH_ATTENTION:
             return FlashNeoXSharded(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index afc04311..a4e6249b 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -139,7 +139,9 @@ class T5DenseActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
@@ -182,7 +184,9 @@ class T5DenseGatedActDense(nn.Module):
         hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
-        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
@@ -350,6 +354,7 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+
         batch_size, seq_length = hidden_states.shape[:2]
 
         real_seq_length = seq_length
@@ -841,10 +846,6 @@ class T5Stack(T5PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
 
-        from safetensors.torch import save_file
-        save_file({"inputs_embeds": inputs_embeds}, f"inputs_embeds_{self.rank}_layer.safetensors")
-
-
        batch_size, seq_length = input_shape
 
         # required mask seq length can be calculated via length of past
@@ -941,8 +942,6 @@ class T5Stack(T5PreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
             )
-            from safetensors.torch import save_file
-            save_file({"layer": layer_outputs[0]}, f"layer_outputs_{self.rank}_layer_{i}.safetensors")
 
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
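
Note on the T5 dtype change above (not part of the patch): as the XXX comment says, the layer norm that follows the wo projection already handles the recast internally, so rounding the wo output through float16 first only discards precision and changes the results. The sketch below illustrates that effect; rms_norm is a hypothetical stand-in modelled on a T5-style RMS layer norm, not code from this repository, and the tensor shapes are arbitrary.

import torch

torch.manual_seed(0)


def rms_norm(hidden, weight, eps=1e-6):
    # Stand-in for a T5-style RMS layer norm: the variance is computed in
    # float32 and the result is cast back to the weight dtype at the end,
    # so an earlier float16 round-trip only loses precision.
    variance = hidden.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype)


wo_output = torch.randn(4, 1024, dtype=torch.float32)  # pretend wo() output kept in float32
weight = torch.ones(1024, dtype=torch.float16)

kept = rms_norm(wo_output, weight)                      # patched behaviour: no recast before the norm
recast = rms_norm(wo_output.to(torch.float16), weight)  # old behaviour: cast back to float16 first

print((kept - recast).abs().max())  # non-zero: the extra cast perturbs the output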