Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
rollback to torch 2.1
parent af23c432e8
commit a337182b43
@@ -39,7 +39,7 @@ RUN cargo build --release
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
 FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
 
-ARG PYTORCH_VERSION=2.2.0
+ARG PYTORCH_VERSION=2.1.1
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
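Note: the only substantive change in this hunk is the PYTORCH_VERSION build argument, rolled back from 2.2.0 to 2.1.1; the CUDA 12.1 base image, Python 3.10, and the CUDA_VERSION pin are untouched. As a hypothetical illustration, not part of the commit, a runtime check that an image built from this Dockerfile actually carries the rolled-back pins could look like:

# Hypothetical runtime check, not part of this commit: confirm an installed
# torch wheel matches the pins in the Dockerfile after the rollback.
import torch

EXPECTED_TORCH = "2.1.1"  # ARG PYTORCH_VERSION after this commit
EXPECTED_CUDA = "12.1"    # ARG CUDA_VERSION, unchanged

assert torch.__version__.startswith(EXPECTED_TORCH), torch.__version__
assert torch.version.cuda == EXPECTED_CUDA, torch.version.cuda
print(f"torch {torch.__version__} built against CUDA {torch.version.cuda}")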
@@ -415,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys
 
                 start_index = end_index
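Note: this hunk and the one that follows only reflow the slice assignment (the indexed target and its right-hand side are split across lines differently, presumably by a different formatter version); the tensors, indices, and semantics are identical on both sides. A self-contained sketch of the operation itself, with assumed shapes rather than the real batch objects, for both cache layouts named in the diff:

# Sketch with assumed shapes (not the real CausalLMBatch state): the padded
# key cache is filled batch by batch, keeping only the last past_seq_len
# positions of each incoming cache so old left-padding is dropped.
import torch

num_heads, head_dim = 2, 4
total_batch_size = 3
dest_seq_len = 7      # max_input_length - 1 across all batches being merged
past_seq_len = 5      # batch.max_input_length - 1 for the batch being copied
start_index, end_index = 0, 2

# keys_head_dim_last layout: [batch, heads, seq, head_dim]
padded_past_keys = torch.zeros(total_batch_size, num_heads, dest_seq_len, head_dim)
past_keys = torch.randn(end_index - start_index, num_heads, 6, head_dim)
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
    past_keys[:, :, -past_seq_len:, :]
)

# BLOOM layout keeps the sequence axis last: [batch, heads, head_dim, seq]
padded_past_keys_bloom = torch.zeros(total_batch_size, num_heads, head_dim, dest_seq_len)
past_keys_bloom = torch.randn(end_index - start_index, num_heads, head_dim, 6)
padded_past_keys_bloom[start_index:end_index, :, :, -past_seq_len:] = (
    past_keys_bloom[:, :, :, -past_seq_len:]
)

The next hunk applies the same reflow to padded_past_values, which only has the head-dim-last layout to handle.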
@@ -440,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values
 
                 # Update values
@@ -1017,9 +1017,9 @@ class FlashCausalLM(Model):
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
                     if len(batch) > 1:
-                        prefill_tokens_indices[
-                            out_start_index : out_end_index - 1
-                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        )
                     else:
                         # Set prefill_tokens_indices to the correct slice
                         prefill_tokens_indices = batch.input_ids[
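Note: the FlashCausalLM hunk is the same kind of reflow. What the copy computes when prefill logprobs are requested and the batch holds more than one request: each request's prompt ids are written into prefill_tokens_indices shifted left by one, since each position's logits predict the next token. A toy sketch with hypothetical prompt lengths (the variable names mirror the diff; the real per-request bookkeeping lives in the surrounding model code):

# Toy sketch, hypothetical prompt lengths: the shifted copy this hunk performs
# when prefill logprobs are requested for a multi-request batch.
import torch

# Two prompts flattened back to back: [101, 7, 8, 9] and [102, 4, 5]
input_ids = torch.tensor([101, 7, 8, 9, 102, 4, 5])
prefill_tokens_indices = input_ids.new_zeros(len(input_ids))

# (start_index, out_start_index, out_length) per request, assumed values
for start_index, out_start_index, out_length in [(0, 0, 4), (4, 4, 3)]:
    out_end_index = out_start_index + out_length
    # Target ids are the prompt shifted left by one; the final slot of each
    # request is filled elsewhere in the surrounding loop.
    prefill_tokens_indices[out_start_index : out_end_index - 1] = (
        input_ids[start_index + 1 : start_index + out_length]
    )

print(prefill_tokens_indices.tolist())  # [7, 8, 9, 0, 4, 5, 0]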