mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
fix(server): fix reshaping of bloom past_key_values in concatenate() (#252)
Introduced in #214. Fixes #249.
This commit is contained in:
parent
db2b4e0754
commit
b4cf832c40
@@ -335,7 +335,7 @@ class CausalLMBatch(Batch):
|
|||||||
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
|
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
|
||||||
for layer in batch.past_key_values
|
for layer in batch.past_key_values
|
||||||
]
|
]
|
||||||
elif batch.past_key_values[0][0].shape == 3:
|
elif len(batch.past_key_values[0][0].shape) == 3:
|
||||||
for layer in batch.past_key_values:
|
for layer in batch.past_key_values:
|
||||||
for k, t in enumerate(layer):
|
for k, t in enumerate(layer):
|
||||||
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
||||||
|
Loading…
Reference in New Issue
Block a user