mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix galactica batching
This commit is contained in:
parent
a4782da22b
commit
bc4c6a406a
@ -156,31 +156,29 @@ class CausalLMBatch:
|
||||
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
|
||||
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
|
||||
|
||||
_, num_heads, head_dim, padded_sequence_length = past_keys.shape
|
||||
_, num_heads, padded_sequence_length, head_dim = past_values.shape
|
||||
|
||||
padded_past_keys_shape = (
|
||||
padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# head_dim is last for BLOOM
|
||||
if past_values.shape[-1] == head_dim:
|
||||
past_values_head_dim_last = True
|
||||
padded_past_values_shape = (
|
||||
# seq_length is last for BLOOM
|
||||
if past_keys.shape[-2] == head_dim:
|
||||
past_keys_head_dim_last = False
|
||||
padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_sequence_length - 1,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
)
|
||||
elif past_values.shape[-2] == head_dim:
|
||||
past_values_head_dim_last = False
|
||||
padded_past_values_shape = padded_past_keys_shape
|
||||
elif past_keys.shape[-1] == head_dim:
|
||||
past_keys_head_dim_last = True
|
||||
padded_past_keys_shape = padded_past_values_shape
|
||||
else:
|
||||
raise ValueError(
|
||||
f"past_values shape {past_values.shape} is not valid"
|
||||
)
|
||||
raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
|
||||
|
||||
# This will run only once per layer
|
||||
if j == len(past_key_values):
|
||||
@ -197,24 +195,24 @@ class CausalLMBatch:
|
||||
past_key_values.append((padded_past_keys, padded_past_values))
|
||||
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
past_key_values[j][0][
|
||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
|
||||
if past_values_head_dim_last:
|
||||
past_key_values[j][1][
|
||||
if past_keys_head_dim_last:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
:,
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
else:
|
||||
past_key_values[j][1][
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
|
||||
past_key_values[j][1][
|
||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
|
||||
start_index += batch.size
|
||||
|
||||
|
@ -204,6 +204,9 @@ class GalacticaSharded(Galactica):
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
try:
|
||||
module = model.get_submodule(module_name)
|
||||
|
Loading…
Reference in New Issue
Block a user