mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
rollback to torch 2.1
This commit is contained in:
parent
af23c432e8
commit
a337182b43
@ -39,7 +39,7 @@ RUN cargo build --release
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
|
||||
|
||||
ARG PYTORCH_VERSION=2.2.0
|
||||
ARG PYTORCH_VERSION=2.1.1
|
||||
ARG PYTHON_VERSION=3.10
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG CUDA_VERSION=12.1
|
||||
|
@ -415,14 +415,14 @@ class CausalLMBatch(Batch):
|
||||
# We slice the keys to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
if batch.keys_head_dim_last:
|
||||
padded_past_keys[
|
||||
start_index:end_index, :, -past_seq_len:, :
|
||||
] = past_keys[:, :, -past_seq_len:, :]
|
||||
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
|
||||
past_keys[:, :, -past_seq_len:, :]
|
||||
)
|
||||
else:
|
||||
# BLOOM case
|
||||
padded_past_keys[
|
||||
start_index:end_index, :, :, -past_seq_len:
|
||||
] = past_keys[:, :, :, -past_seq_len:]
|
||||
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
|
||||
past_keys[:, :, :, -past_seq_len:]
|
||||
)
|
||||
del past_keys
|
||||
|
||||
start_index = end_index
|
||||
@ -440,9 +440,9 @@ class CausalLMBatch(Batch):
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the past values to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
padded_past_values[
|
||||
start_index:end_index, :, -past_seq_len:, :
|
||||
] = past_values[:, :, -past_seq_len:, :]
|
||||
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
|
||||
past_values[:, :, -past_seq_len:, :]
|
||||
)
|
||||
del past_values
|
||||
|
||||
# Update values
|
||||
|
@ -1017,9 +1017,9 @@ class FlashCausalLM(Model):
|
||||
# Copy batch.input_ids to prefill_token_indices
|
||||
if prefill_logprobs:
|
||||
if len(batch) > 1:
|
||||
prefill_tokens_indices[
|
||||
out_start_index : out_end_index - 1
|
||||
] = batch.input_ids[start_index + 1 : start_index + out_length]
|
||||
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
|
||||
batch.input_ids[start_index + 1 : start_index + out_length]
|
||||
)
|
||||
else:
|
||||
# Set prefill_tokens_indices to the correct slice
|
||||
prefill_tokens_indices = batch.input_ids[
|
||||
|
Loading…
Reference in New Issue
Block a user