diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 24ba6796..dbcefbae 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) - query = query.view( + query = query.reshape( batch_size * num_attention_heads, query_length, attn_head_size ) - key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) attn_scores = torch.zeros( 1, dtype=query.dtype,