diff --git a/Dockerfile b/Dockerfile
index 73892494..e79372a3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -39,7 +39,7 @@ RUN cargo build --release
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
 FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
 
-ARG PYTORCH_VERSION=2.2.0
+ARG PYTORCH_VERSION=2.1.1
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 3de45921..a0f0c9e8 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -415,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys
 
                 start_index = end_index
@@ -440,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values
 
                 # Update values
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5168a33d..b8d0be22 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1017,9 +1017,9 @@ class FlashCausalLM(Model):
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
                     if len(batch) > 1:
-                        prefill_tokens_indices[
-                            out_start_index : out_end_index - 1
-                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        )
                     else:
                         # Set prefill_tokens_indices to the correct slice
                         prefill_tokens_indices = batch.input_ids[
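
For reference, a minimal standalone sketch (not part of the diff) of the slice-assignment pattern that the causal_lm.py hunks reformat: each batch's past keys are copied into a zero-padded tensor, keeping only the trailing past_seq_len positions so that left padding carried over from previous batches is dropped. The tensor shapes, the two example batches, and their max_input_length values below are hypothetical.

import torch

# Hypothetical shapes, purely for illustration.
num_heads, head_dim = 2, 4
padded_seq_len = 5  # overall max_input_length - 1 across both example batches

# Past keys laid out as [batch, heads, seq, head_dim], i.e. keys_head_dim_last=True.
# Batch B still carries two positions of left padding from an earlier concatenation.
past_keys_a = torch.randn(3, num_heads, 5, head_dim)  # max_input_length = 6
past_keys_b = torch.randn(2, num_heads, 5, head_dim)  # max_input_length = 4

total_batch = past_keys_a.size(0) + past_keys_b.size(0)
padded_past_keys = past_keys_a.new_zeros(
    total_batch, num_heads, padded_seq_len, head_dim
)

start_index = 0
for past_keys, max_input_length in ((past_keys_a, 6), (past_keys_b, 4)):
    end_index = start_index + past_keys.size(0)
    # Only the last `past_seq_len` positions are real tokens; anything earlier is padding.
    past_seq_len = max_input_length - 1
    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
        past_keys[:, :, -past_seq_len:, :]
    )
    start_index = end_index

print(padded_past_keys.shape)  # torch.Size([5, 2, 5, 4])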