rollback to torch 2.1

This commit is contained in:
OlivierDehaene 2024-02-16 16:40:16 +01:00
parent af23c432e8
commit a337182b43
3 changed files with 13 additions and 13 deletions

View File

@@ -39,7 +39,7 @@ RUN cargo build --release
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
-ARG PYTORCH_VERSION=2.2.0
+ARG PYTORCH_VERSION=2.1.1
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.1

View File

@@ -415,14 +415,14 @@ class CausalLMBatch(Batch):
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
-padded_past_keys[
-start_index:end_index, :, -past_seq_len:, :
-] = past_keys[:, :, -past_seq_len:, :]
+padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+past_keys[:, :, -past_seq_len:, :]
+)
else:
# BLOOM case
-padded_past_keys[
-start_index:end_index, :, :, -past_seq_len:
-] = past_keys[:, :, :, -past_seq_len:]
+padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+past_keys[:, :, :, -past_seq_len:]
+)
del past_keys
start_index = end_index
@@ -440,9 +440,9 @@ class CausalLMBatch(Batch):
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
-padded_past_values[
-start_index:end_index, :, -past_seq_len:, :
-] = past_values[:, :, -past_seq_len:, :]
+padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+past_values[:, :, -past_seq_len:, :]
+)
del past_values
# Update values

View File

@@ -1017,9 +1017,9 @@ class FlashCausalLM(Model):
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
if len(batch) > 1:
-prefill_tokens_indices[
-out_start_index : out_end_index - 1
-] = batch.input_ids[start_index + 1 : start_index + out_length]
+prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+batch.input_ids[start_index + 1 : start_index + out_length]
+)
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[