rollback to torch 2.1

OlivierDehaene 2024-02-16 16:40:16 +01:00
parent af23c432e8
commit a337182b43
3 changed files with 13 additions and 13 deletions

Dockerfile

@@ -39,7 +39,7 @@ RUN cargo build --release
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
 FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
-ARG PYTORCH_VERSION=2.2.0
+ARG PYTORCH_VERSION=2.1.1
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
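
The change above pins the image build back to PyTorch 2.1.1 while keeping CUDA 12.1. As a minimal sanity check, not part of this commit, one could confirm inside the built image that the pinned versions were actually picked up (assumes python and torch are available in the image):

    import torch

    # Confirm the rollback target from the Dockerfile ARGs took effect.
    assert torch.__version__.startswith("2.1"), torch.__version__
    # CUDA_VERSION=12.1; torch reports the toolkit it was built against.
    assert torch.version.cuda == "12.1", torch.version.cuda
    print(torch.__version__, torch.version.cuda)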

server/text_generation_server/models/causal_lm.py

@@ -415,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys

                 start_index = end_index
@@ -440,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values

         # Update values
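
Both hunks above are formatting-only: the slice assignment is collapsed onto one line and the right-hand side wrapped in parentheses, with identical semantics. The underlying logic left-pads each batch's past keys and values into a shared buffer when batches are concatenated. A standalone sketch of that pattern, with hypothetical shapes and the [batch, heads, seq, head_dim] layout of the keys_head_dim_last branch:

    import torch

    heads, head_dim, max_len = 2, 4, 6
    # Two batches with different past sequence lengths (hypothetical values).
    past_a = torch.randn(1, heads, 3, head_dim)
    past_b = torch.randn(2, heads, 5, head_dim)

    padded = past_a.new_zeros(3, heads, max_len, head_dim)
    start = 0
    for past in (past_a, past_b):
        end = start + past.shape[0]
        seq = past.shape[2]
        # Same pattern as the diff: copy right-aligned, so shorter
        # batches end up left-padded with zeros.
        padded[start:end, :, -seq:, :] = past[:, :, -seq:, :]
        start = end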

server/text_generation_server/models/flash_causal_lm.py

@@ -1017,9 +1017,9 @@ class FlashCausalLM(Model):
             # Copy batch.input_ids to prefill_token_indices
             if prefill_logprobs:
                 if len(batch) > 1:
-                    prefill_tokens_indices[
-                        out_start_index : out_end_index - 1
-                    ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                    prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                        batch.input_ids[start_index + 1 : start_index + out_length]
+                    )
                 else:
                     # Set prefill_tokens_indices to the correct slice
                     prefill_tokens_indices = batch.input_ids[
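
For context on the hunk above, which is again formatting-only: when prefill logprobs are requested, the logits at position i score the token at position i + 1, so the target ids are the input ids shifted left by one, and the assignment copies that shifted slice into the flattened index buffer. A toy sketch of the shift, with hypothetical values for a single sequence:

    import torch

    input_ids = torch.tensor([10, 11, 12, 13, 14])  # one prefill sequence
    out_length = input_ids.numel()
    prefill_tokens_indices = torch.zeros(out_length - 1, dtype=torch.long)

    # Position i predicts token i + 1, mirroring the shifted slice in the diff.
    prefill_tokens_indices[:] = input_ids[1:out_length]
    print(prefill_tokens_indices)  # tensor([11, 12, 13, 14])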