Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit f2f78e17d1: better implem
Parent commit: a8446a5a31
@@ -37,7 +37,6 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
     max_sequence_length: int
-    max_potential_length: int
     padding_right_offset: int
 
     # Past metadata
@@ -64,7 +63,6 @@ class CausalLMBatch(Batch):
 
         # Parse batch
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -75,10 +73,9 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_sequence_length = max(max_sequence_length, r.input_length)
-            potential_length = r.input_length + stopping_criteria.max_new_tokens
-            if max_potential_length < potential_length:
-                max_potential_length = potential_length
-                padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
 
         tokenized_inputs = tokenizer(
             inputs,
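For context, a minimal sketch of the arithmetic behind the new bookkeeping (the (input_length, max_new_tokens) pairs below are invented stand-ins for pb.requests, not real protobuf objects):

# Hypothetical stand-ins for pb.requests: (input_length, max_new_tokens) pairs.
requests = [(30, 8), (7, 64)]

max_sequence_length = 0
padding_right_offset = 0
for input_length, max_new_tokens in requests:
    max_sequence_length = max(max_sequence_length, input_length)
    # New scheme: track the largest number of tokens any request may still
    # generate; the mask is then allocated max_sequence_length +
    # padding_right_offset columns, so the prompt region and the reserved
    # right-padding region tile the mask exactly.
    padding_right_offset = max(padding_right_offset, max_new_tokens)

print(max_sequence_length + padding_right_offset)  # 94
# Old scheme: max_potential_length = max(input_length + max_new_tokens) = 71,
# which does not equal max_sequence_length + padding_right_offset when the
# longest prompt and the largest max_new_tokens come from different requests.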
@@ -89,7 +86,9 @@ class CausalLMBatch(Batch):
 
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
-        attention_mask = input_ids.new_zeros((pb.size, max_potential_length))
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
 
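A small, self-contained sketch of the allocation pattern above, assuming torch is installed (shapes and values are illustrative only, not the real code path):

import torch

batch_size, max_sequence_length, padding_right_offset = 2, 5, 3

# Stand-in for the tokenizer output: every prompt already padded to
# max_sequence_length columns.
tokenizer_attention_mask = torch.ones((batch_size, max_sequence_length), dtype=torch.long)
input_ids = torch.ones((batch_size, max_sequence_length), dtype=torch.long)

# Allocate once for the whole generation: prompt columns on the left,
# zeroed columns on the right reserved for tokens generated later.
attention_mask = input_ids.new_zeros((batch_size, max_sequence_length + padding_right_offset))
attention_mask[:, :max_sequence_length] = tokenizer_attention_mask

print(attention_mask.shape)  # torch.Size([2, 8])
print(attention_mask[0])     # tensor([1, 1, 1, 1, 1, 0, 0, 0])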
@@ -110,7 +109,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=pb.size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -120,14 +118,11 @@ class CausalLMBatch(Batch):
         # Used for padding
         total_batch_size = 0
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
             max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-                padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
@@ -170,21 +165,21 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
             attention_mask[
                 start_index:end_index,
-                -(
-                    batch.max_sequence_length + padding_right_offset
-                ) : -padding_right_offset,
+                left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                -(
-                    batch.max_sequence_length + batch.padding_right_offset
-                ) : -batch.padding_right_offset,
+                batch_left_offset : -batch.padding_right_offset,
             ]
 
             # Create empty tensor
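To check the two offsets introduced above, a small worked example with made-up sizes (plain Python, no tensors):

# Target (merged) batch geometry.
max_sequence_length = 10       # longest sequence across all batches being merged
padding_right_offset = 6       # largest remaining generation budget across batches

# One source batch being copied in.
batch_max_sequence_length = 7
batch_padding_right_offset = 4
batch_mask_width = 13          # batch.attention_mask.shape[1]

# Destination slice: skip the extra left padding the merged batch needs.
left_offset = max_sequence_length - batch_max_sequence_length  # 3
# Source slice: skip the left padding already present in the batch's own mask.
batch_left_offset = batch_mask_width - batch_max_sequence_length - batch_padding_right_offset  # 2

dest_width = (max_sequence_length + padding_right_offset) - left_offset - padding_right_offset
src_width = batch_mask_width - batch_left_offset - batch_padding_right_offset
assert dest_width == src_width == batch_max_sequence_length  # both slices are 7 columns wide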
@@ -263,7 +258,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
@@ -332,10 +326,7 @@ class CausalLM(Model):
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
-        if batch.padding_right_offset != 0:
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        else:
-            attention_mask = batch.attention_mask
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
         logits, past = self.forward(
             batch.input_ids,
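Dropping the if/else above seems safe because padding_right_offset starts as the largest max_new_tokens in the batch and only decreases by one per step, so it stays at least 1 while any request can still generate. A tiny sketch of the slice itself (illustrative tensor, assuming torch is installed):

import torch

mask = torch.tensor([[1, 1, 1, 0, 0]])
padding_right_offset = 2
# With padding_right_offset >= 1 the negative slice never collapses to mask[:, :0].
print(mask[:, :-padding_right_offset])  # tensor([[1, 1, 1]])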
@@ -355,7 +346,6 @@ class CausalLM(Model):
         # Metadata
         next_batch_size = 0
         next_batch_max_sequence_length = 0
-        next_batch_max_potential_length = 0
 
         # Results
         generations: List[Generation] = []
@@ -428,13 +418,6 @@ class CausalLM(Model):
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
                 )
-                # potential length is input_length + max_new_tokens but we need to remove generated tokens
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    new_input_length
-                    + stopping_criteria.max_new_tokens
-                    - stopping_criteria.current_tokens,
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -518,7 +501,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )

@@ -106,12 +106,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             )
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
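Dropping pad_to_multiple_of here looks consistent with the pre-allocated mask: the copy attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"] assumes the tokenizer output is exactly max_sequence_length columns wide, which rounding up to a multiple of 8 would break. A toy illustration with hypothetical numbers:

# If the longest prompt in the batch has 12 tokens:
max_sequence_length = 12
# padding=True alone pads every row to 12 columns, matching the copy above;
# pad_to_multiple_of=8 would instead pad rows to 16 columns, and the slice
# attention_mask[:, :max_sequence_length] would no longer line up.
width_without_multiple = 12
width_with_multiple_of_8 = ((12 + 7) // 8) * 8  # 16
assert width_without_multiple == max_sequence_length
assert width_with_multiple_of_8 != max_sequence_length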
@@ -42,7 +42,6 @@ class Seq2SeqLMBatch(Batch):
     size: int
     max_input_length: int
     max_decoder_input_length: int
-    max_potential_length: int
     padding_right_offset: int
 
     def to_pb(self) -> generate_pb2.Batch:
@@ -71,7 +70,6 @@ class Seq2SeqLMBatch(Batch):
 
         # Parse batch
         max_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -85,18 +83,15 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_input_length = max(max_input_length, r.input_length)
-            if max_potential_length < stopping_criteria.max_new_tokens + 1:
-                # +1 because we have the bos token
-                max_potential_length = stopping_criteria.max_new_tokens + 1
-                padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
@@ -118,7 +113,6 @@ class Seq2SeqLMBatch(Batch):
             size=len(pb.requests),
             max_input_length=max(input_lengths),
             max_decoder_input_length=1,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -131,7 +125,6 @@ class Seq2SeqLMBatch(Batch):
         total_batch_size = 0
         max_input_length = 0
         max_decoder_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
@@ -139,9 +132,7 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
             )
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-                padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
@@ -200,29 +191,28 @@ class Seq2SeqLMBatch(Batch):
             if decoder_attention_mask is None:
                 # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                 decoder_attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
+            left_offset = max_decoder_input_length - batch.max_decoder_input_length
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = 1
             # If it exists, we need to index
             else:
+                batch_left_offset = (
+                    batch.decoder_attention_mask.shape[1]
+                    - batch.max_decoder_input_length - batch.padding_right_offset
+                )
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = batch.decoder_attention_mask[
                     :,
-                    -(
-                        batch.max_decoder_input_length + batch.padding_right_offset
-                    ) : -batch.padding_right_offset,
+                    batch_left_offset : -batch.padding_right_offset,
                 ]
 
             # Create padded tensor
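The decoder-side merge mirrors the encoder logic earlier in the diff, with max_decoder_input_length playing the role of max_sequence_length. A quick check of the never-concatenated branch with made-up sizes:

max_decoder_input_length = 6          # across all batches being merged
padding_right_offset = 5
batch_max_decoder_input_length = 4    # every row of this batch has exactly 4 decoder tokens

left_offset = max_decoder_input_length - batch_max_decoder_input_length  # 2
mask_width = max_decoder_input_length + padding_right_offset             # 11
# decoder_attention_mask[start:end, left_offset:-padding_right_offset] covers
# 4 columns, one per existing decoder token, so filling the slice with ones is
# enough when the batch was never concatenated.
assert mask_width - left_offset - padding_right_offset == batch_max_decoder_input_length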
@@ -308,7 +298,6 @@ class Seq2SeqLMBatch(Batch):
             size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -387,12 +376,9 @@ class Seq2SeqLM(Model):
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
-            if batch.padding_right_offset != 0:
-                decoder_attention_mask = batch.decoder_attention_mask[
-                    :, : -batch.padding_right_offset
-                ]
-            else:
-                decoder_attention_mask = batch.decoder_attention_mask
+            decoder_attention_mask = batch.decoder_attention_mask[
+                :, : -batch.padding_right_offset
+            ]
         else:
             decoder_attention_mask = None
 
@@ -431,7 +417,6 @@ class Seq2SeqLM(Model):
         next_batch_size = 0
         next_batch_max_input_length = 0
         next_batch_max_decoder_input_length = 0
-        next_batch_max_potential_length = 0
 
         # Finished requests
         generations: List[Generation] = []
@@ -506,11 +491,6 @@ class Seq2SeqLM(Model):
                 next_batch_max_decoder_input_length = max(
                     next_batch_max_decoder_input_length, new_decoder_input_length
                 )
-                # +1 because of the bos token
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    stopping_criteria.max_new_tokens + 1,
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -598,7 +578,6 @@ class Seq2SeqLM(Model):
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
         )
         return generations, next_batch