mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
Fix: Replace view() with reshape() in neox_modeling.py to resolve RuntimeError (#1155)
This commit is contained in:
parent
7402a355dc
commit
9179605e1e
@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||||
key_length = key.size(-2)
|
key_length = key.size(-2)
|
||||||
|
|
||||||
query = query.view(
|
query = query.reshape(
|
||||||
batch_size * num_attention_heads, query_length, attn_head_size
|
batch_size * num_attention_heads, query_length, attn_head_size
|
||||||
)
|
)
|
||||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||||
attn_scores = torch.zeros(
|
attn_scores = torch.zeros(
|
||||||
1,
|
1,
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
|
Loading…
Reference in New Issue
Block a user