mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix galactica batching
This commit is contained in:
parent
a4782da22b
commit
bc4c6a406a
@ -156,31 +156,29 @@ class CausalLMBatch:
|
|||||||
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
|
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
|
||||||
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
|
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
|
||||||
|
|
||||||
_, num_heads, head_dim, padded_sequence_length = past_keys.shape
|
_, num_heads, padded_sequence_length, head_dim = past_values.shape
|
||||||
|
|
||||||
padded_past_keys_shape = (
|
|
||||||
total_batch_size,
|
|
||||||
num_heads,
|
|
||||||
head_dim,
|
|
||||||
max_sequence_length - 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# head_dim is last for BLOOM
|
|
||||||
if past_values.shape[-1] == head_dim:
|
|
||||||
past_values_head_dim_last = True
|
|
||||||
padded_past_values_shape = (
|
padded_past_values_shape = (
|
||||||
total_batch_size,
|
total_batch_size,
|
||||||
num_heads,
|
num_heads,
|
||||||
max_sequence_length - 1,
|
max_sequence_length - 1,
|
||||||
head_dim,
|
head_dim,
|
||||||
)
|
)
|
||||||
elif past_values.shape[-2] == head_dim:
|
|
||||||
past_values_head_dim_last = False
|
# seq_length is last for BLOOM
|
||||||
padded_past_values_shape = padded_past_keys_shape
|
if past_keys.shape[-2] == head_dim:
|
||||||
else:
|
past_keys_head_dim_last = False
|
||||||
raise ValueError(
|
padded_past_keys_shape = (
|
||||||
f"past_values shape {past_values.shape} is not valid"
|
total_batch_size,
|
||||||
|
num_heads,
|
||||||
|
head_dim,
|
||||||
|
max_sequence_length - 1,
|
||||||
)
|
)
|
||||||
|
elif past_keys.shape[-1] == head_dim:
|
||||||
|
past_keys_head_dim_last = True
|
||||||
|
padded_past_keys_shape = padded_past_values_shape
|
||||||
|
else:
|
||||||
|
raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
|
||||||
|
|
||||||
# This will run only once per layer
|
# This will run only once per layer
|
||||||
if j == len(past_key_values):
|
if j == len(past_key_values):
|
||||||
@ -197,24 +195,24 @@ class CausalLMBatch:
|
|||||||
past_key_values.append((padded_past_keys, padded_past_values))
|
past_key_values.append((padded_past_keys, padded_past_values))
|
||||||
|
|
||||||
# We slice the past keys and values to remove the padding from previous batches
|
# We slice the past keys and values to remove the padding from previous batches
|
||||||
|
if past_keys_head_dim_last:
|
||||||
past_key_values[j][0][
|
past_key_values[j][0][
|
||||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
start_index:end_index,
|
||||||
|
:,
|
||||||
|
-(batch.max_sequence_length - 1) :,
|
||||||
|
:,
|
||||||
|
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||||
|
else:
|
||||||
|
past_key_values[j][0][
|
||||||
|
start_index:end_index,
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
-(batch.max_sequence_length - 1) :,
|
||||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||||
|
|
||||||
if past_values_head_dim_last:
|
|
||||||
past_key_values[j][1][
|
past_key_values[j][1][
|
||||||
start_index:end_index,
|
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
||||||
:,
|
|
||||||
-(batch.max_sequence_length - 1) :,
|
|
||||||
:,
|
|
||||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||||
else:
|
|
||||||
past_key_values[j][1][
|
|
||||||
start_index:end_index,
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
-(batch.max_sequence_length - 1) :,
|
|
||||||
] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
|
|
||||||
|
|
||||||
start_index += batch.size
|
start_index += batch.size
|
||||||
|
|
||||||
|
@ -204,6 +204,9 @@ class GalacticaSharded(Galactica):
|
|||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
|
if name == "lm_head.weight":
|
||||||
|
continue
|
||||||
|
|
||||||
module_name, param_name = name.rsplit(".", 1)
|
module_name, param_name = name.rsplit(".", 1)
|
||||||
try:
|
try:
|
||||||
module = model.get_submodule(module_name)
|
module = model.get_submodule(module_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user