mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
Fix: Replace view() with reshape() in neox_modeling.py to resolve RuntimeError (#1155)
This commit is contained in:
parent
7402a355dc
commit
9179605e1e
@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||
key_length = key.size(-2)
|
||||
|
||||
query = query.view(
|
||||
query = query.reshape(
|
||||
batch_size * num_attention_heads, query_length, attn_head_size
|
||||
)
|
||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
attn_scores = torch.zeros(
|
||||
1,
|
||||
dtype=query.dtype,
|
||||
|
Loading…
Reference in New Issue
Block a user