From 9179605e1eaa5532e81552e7e7fc92ec32652592 Mon Sep 17 00:00:00 2001
From: Mario928 <88029051+Mario928@users.noreply.github.com>
Date: Thu, 19 Oct 2023 15:24:26 +0530
Subject: [PATCH] Fix: Replace view() with reshape() in neox_modeling.py to resolve RuntimeError (#1155)

---
 .../models/custom_modeling/neox_modeling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py
index 24ba6796..dbcefbae 100644
--- a/server/text_generation_server/models/custom_modeling/neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py
@@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)
 
-        query = query.view(
+        query = query.reshape(
             batch_size * num_attention_heads, query_length, attn_head_size
         )
-        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+        key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
         attn_scores = torch.zeros(
             1,
             dtype=query.dtype,