From bc4c6a406abaef039bbd92d47dd46a38fc20513f Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 1 Dec 2022 18:54:53 +0100
Subject: [PATCH] fix galactica batching

---
 server/text_generation/models/causal_lm.py | 44 +++++++++++----------
 server/text_generation/models/galactica.py |  3 ++
 2 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index dee5a19e..ca8ea575 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -156,31 +156,29 @@ class CausalLMBatch:
                 past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                 past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
 
-                _, num_heads, head_dim, padded_sequence_length = past_keys.shape
+                _, num_heads, padded_sequence_length, head_dim = past_values.shape
 
-                padded_past_keys_shape = (
+                padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    head_dim,
                     max_sequence_length - 1,
+                    head_dim,
                 )
 
-                # head_dim is last for BLOOM
-                if past_values.shape[-1] == head_dim:
-                    past_values_head_dim_last = True
-                    padded_past_values_shape = (
+                # seq_length is last for BLOOM
+                if past_keys.shape[-2] == head_dim:
+                    past_keys_head_dim_last = False
+                    padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
-                        max_sequence_length - 1,
                         head_dim,
+                        max_sequence_length - 1,
                     )
-                elif past_values.shape[-2] == head_dim:
-                    past_values_head_dim_last = False
-                    padded_past_values_shape = padded_past_keys_shape
+                elif past_keys.shape[-1] == head_dim:
+                    past_keys_head_dim_last = True
+                    padded_past_keys_shape = padded_past_values_shape
                 else:
-                    raise ValueError(
-                        f"past_values shape {past_values.shape} is not valid"
-                    )
+                    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
 
                 # This will run only once per layer
                 if j == len(past_key_values):
@@ -197,24 +195,24 @@ class CausalLMBatch:
                     past_key_values.append((padded_past_keys, padded_past_values))
 
                 # We slice the past keys and values to remove the padding from previous batches
-                past_key_values[j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-
-                if past_values_head_dim_last:
-                    past_key_values[j][1][
+                if past_keys_head_dim_last:
+                    past_key_values[j][0][
                         start_index:end_index,
                         :,
                         -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
-                    past_key_values[j][1][
+                    past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
                         -(batch.max_sequence_length - 1) :,
-                    ] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+
+                past_key_values[j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
                 start_index += batch.size
 
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index 81aac649..abc3c36c 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -204,6 +204,9 @@ class GalacticaSharded(Galactica):
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
+                    if name == "lm_head.weight":
+                        continue
+
                     module_name, param_name = name.rsplit(".", 1)
                     try:
                         module = model.get_submodule(module_name)
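
Note (not part of the patch): the fix hinges on BLOOM caching its attention
keys transposed relative to its values, while OPT-derived models such as
Galactica keep both keys and values in [batch, num_heads, seq_length,
head_dim] order. The values tensor therefore always has head_dim last, which
is why the patched code unpacks dimensions from past_values and only branches
on the layout of past_keys. The sketch below is a minimal, standalone
illustration of that detection; the shapes and the helper name
keys_head_dim_last are assumptions for demonstration, not code from the
repository.

import torch

batch, num_heads, seq_length, head_dim = 2, 4, 7, 16

# OPT/Galactica-style cache: keys and values share one layout.
opt_keys = torch.zeros(batch, num_heads, seq_length, head_dim)
# BLOOM-style cache: keys are stored transposed, with seq_length last.
bloom_keys = torch.zeros(batch * num_heads, head_dim, seq_length)

def keys_head_dim_last(past_keys: torch.Tensor, head_dim: int) -> bool:
    # Mirrors the branch introduced by the patch: values always end with
    # head_dim, so only the keys' layout is ambiguous. (If head_dim happened
    # to equal seq_length, the two layouts could not be told apart by shape
    # alone; the real code has the same limitation.)
    if past_keys.shape[-2] == head_dim:
        return False  # BLOOM: [..., head_dim, seq_length]
    elif past_keys.shape[-1] == head_dim:
        return True   # OPT/Galactica: [..., seq_length, head_dim]
    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")

assert keys_head_dim_last(opt_keys, head_dim)
assert not keys_head_dim_last(bloom_keys, head_dim)

Keying the check on past_keys rather than past_values matters because both
layouts end their values tensor with head_dim, so the pre-patch test on
past_values could not distinguish them.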