Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
better implem

commit f2f78e17d1 (parent a8446a5a31)
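
The diff below drops the per-batch max_potential_length bookkeeping: each batch now only tracks padding_right_offset (the largest remaining max_new_tokens across its requests) and pre-allocates attention masks to max_sequence_length + padding_right_offset. A minimal sketch of that bookkeeping, using an invented FakeRequest stand-in rather than the real gRPC request objects used by from_pb:

    from dataclasses import dataclass

    import torch

    @dataclass
    class FakeRequest:
        # Illustrative stand-in for the request fields read by from_pb
        input_length: int
        max_new_tokens: int

    def allocate_mask(requests):
        # Same running-max bookkeeping as the new from_pb
        max_sequence_length = 0
        padding_right_offset = 0
        for r in requests:
            max_sequence_length = max(max_sequence_length, r.input_length)
            padding_right_offset = max(padding_right_offset, r.max_new_tokens)
        # One allocation; generation fills the right-hand columns step by step
        return torch.zeros(
            (len(requests), max_sequence_length + padding_right_offset),
            dtype=torch.int64,
        )

    mask = allocate_mask([FakeRequest(5, 3), FakeRequest(7, 10)])
    print(mask.shape)  # torch.Size([2, 17])
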
@@ -37,7 +37,6 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
     max_sequence_length: int
-    max_potential_length: int
     padding_right_offset: int
 
     # Past metadata
@@ -64,7 +63,6 @@ class CausalLMBatch(Batch):
 
         # Parse batch
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -75,10 +73,9 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_sequence_length = max(max_sequence_length, r.input_length)
-            potential_length = r.input_length + stopping_criteria.max_new_tokens
-            if max_potential_length < potential_length:
-                max_potential_length = potential_length
-                padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
 
         tokenized_inputs = tokenizer(
             inputs,
@@ -89,7 +86,9 @@ class CausalLMBatch(Batch):
 
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
-        attention_mask = input_ids.new_zeros((pb.size, max_potential_length))
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
 
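
The tokenizer mask only covers the prompt columns, so the copy above leaves the reserved right-hand columns at zero until generation writes into them. A toy illustration of that layout, with shapes made up for the example:

    import torch

    max_sequence_length, padding_right_offset = 4, 3
    # Pretend this is tokenized_inputs["attention_mask"] (left-padded prompts)
    tokenizer_mask = torch.tensor([[0, 0, 1, 1],
                                   [1, 1, 1, 1]])

    attention_mask = torch.zeros((2, max_sequence_length + padding_right_offset), dtype=torch.int64)
    attention_mask[:, :max_sequence_length] = tokenizer_mask
    print(attention_mask)
    # tensor([[0, 0, 1, 1, 0, 0, 0],
    #         [1, 1, 1, 1, 0, 0, 0]])
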
@@ -110,7 +109,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=pb.size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -120,14 +118,11 @@ class CausalLMBatch(Batch):
         # Used for padding
         total_batch_size = 0
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
             max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-                padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
@@ -170,21 +165,21 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
             attention_mask[
                 start_index:end_index,
-                -(
-                    batch.max_sequence_length + padding_right_offset
-                ) : -padding_right_offset,
+                left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                -(
-                    batch.max_sequence_length + batch.padding_right_offset
-                ) : -batch.padding_right_offset,
+                batch_left_offset : -batch.padding_right_offset,
             ]
 
             # Create empty tensor
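
In the concatenation above, left_offset skips the extra left padding of the wider merged mask and batch_left_offset skips the source batch's own left padding, so both slices end up the same width. A worked example with numbers invented purely for illustration:

    # Merged batch (destination mask is max_sequence_length + padding_right_offset wide)
    max_sequence_length = 7
    padding_right_offset = 4

    # One source batch
    batch_max_sequence_length = 5
    batch_padding_right_offset = 2
    batch_mask_width = 10  # batch.attention_mask.shape[1]

    left_offset = max_sequence_length - batch_max_sequence_length  # 2
    batch_left_offset = (
        batch_mask_width - batch_max_sequence_length - batch_padding_right_offset
    )  # 3

    # Both slices cover the same 5 "real" columns (prompt + already generated tokens)
    dest_width = (max_sequence_length + padding_right_offset) - left_offset - padding_right_offset
    src_width = batch_mask_width - batch_left_offset - batch_padding_right_offset
    assert dest_width == src_width == batch_max_sequence_length  # 5
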
@@ -263,7 +258,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
@@ -332,10 +326,7 @@ class CausalLM(Model):
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
-        if batch.padding_right_offset != 0:
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        else:
-            attention_mask = batch.attention_mask
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
         logits, past = self.forward(
             batch.input_ids,
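
Because padding_right_offset counts tokens that have not been generated yet, the forward pass only sees the leading columns. A one-line illustration with an invented mask:

    import torch

    padding_right_offset = 4
    attention_mask = torch.ones((2, 7 + padding_right_offset), dtype=torch.int64)

    # Drop the reserved (still empty) right-hand columns before calling the model
    model_mask = attention_mask[:, :-padding_right_offset]
    print(model_mask.shape)  # torch.Size([2, 7])
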
@@ -355,7 +346,6 @@ class CausalLM(Model):
         # Metadata
         next_batch_size = 0
         next_batch_max_sequence_length = 0
-        next_batch_max_potential_length = 0
 
         # Results
         generations: List[Generation] = []
@@ -428,13 +418,6 @@ class CausalLM(Model):
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
                 )
-                # potential length is input_length + max_new_tokens but we need to remove generated tokens
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    new_input_length
-                    + stopping_criteria.max_new_tokens
-                    - stopping_criteria.current_tokens,
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -518,7 +501,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
@@ -106,12 +106,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             )
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
@@ -42,7 +42,6 @@ class Seq2SeqLMBatch(Batch):
     size: int
     max_input_length: int
     max_decoder_input_length: int
-    max_potential_length: int
     padding_right_offset: int
 
     def to_pb(self) -> generate_pb2.Batch:
@@ -71,7 +70,6 @@ class Seq2SeqLMBatch(Batch):
 
         # Parse batch
         max_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -85,18 +83,15 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_input_length = max(max_input_length, r.input_length)
-            if max_potential_length < stopping_criteria.max_new_tokens + 1:
-                # +1 because we have the bos token
-                max_potential_length = stopping_criteria.max_new_tokens + 1
-                padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
@@ -118,7 +113,6 @@ class Seq2SeqLMBatch(Batch):
             size=len(pb.requests),
             max_input_length=max(input_lengths),
             max_decoder_input_length=1,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -131,7 +125,6 @@ class Seq2SeqLMBatch(Batch):
         total_batch_size = 0
         max_input_length = 0
         max_decoder_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
@@ -139,9 +132,7 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
             )
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-                padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
@@ -200,29 +191,28 @@ class Seq2SeqLMBatch(Batch):
             if decoder_attention_mask is None:
                 # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                 decoder_attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
+            left_offset = max_decoder_input_length - batch.max_decoder_input_length
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = 1
             # If it exists, we need to index
             else:
+                batch_left_offset = (
+                    batch.decoder_attention_mask.shape[1]
+                    - batch.max_decoder_input_length - batch.padding_right_offset
+                )
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = batch.decoder_attention_mask[
                     :,
-                    -(
-                        batch.max_decoder_input_length + batch.padding_right_offset
-                    ) : -batch.padding_right_offset,
+                    batch_left_offset : -batch.padding_right_offset,
                 ]
 
             # Create padded tensor
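
The decoder mask follows the same offset scheme; the only twist is that a batch which was never concatenated has no decoder mask yet, so its slice is simply filled with ones. A rough sketch with invented sizes:

    import torch

    total_batch_size, max_decoder_input_length, padding_right_offset = 3, 4, 2
    decoder_attention_mask = torch.zeros(
        (total_batch_size, max_decoder_input_length + padding_right_offset),
        dtype=torch.int64,
    )

    # Rows 0-1 come from a batch that never went through concatenate: all of its
    # generations share the same decoder length, so the slice can just be set to 1
    batch_max_decoder_input_length = 3
    left_offset = max_decoder_input_length - batch_max_decoder_input_length  # 1
    decoder_attention_mask[0:2, left_offset:-padding_right_offset] = 1
    print(decoder_attention_mask)
    # tensor([[0, 1, 1, 1, 0, 0],
    #         [0, 1, 1, 1, 0, 0],
    #         [0, 0, 0, 0, 0, 0]])
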
@@ -308,7 +298,6 @@ class Seq2SeqLMBatch(Batch):
             size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -387,12 +376,9 @@ class Seq2SeqLM(Model):
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
-            if batch.padding_right_offset != 0:
-                decoder_attention_mask = batch.decoder_attention_mask[
-                    :, : -batch.padding_right_offset
-                ]
-            else:
-                decoder_attention_mask = batch.decoder_attention_mask
+            decoder_attention_mask = batch.decoder_attention_mask[
+                :, : -batch.padding_right_offset
+            ]
         else:
             decoder_attention_mask = None
 
@@ -431,7 +417,6 @@ class Seq2SeqLM(Model):
         next_batch_size = 0
         next_batch_max_input_length = 0
         next_batch_max_decoder_input_length = 0
-        next_batch_max_potential_length = 0
 
         # Finished requests
         generations: List[Generation] = []
@@ -506,11 +491,6 @@ class Seq2SeqLM(Model):
                 next_batch_max_decoder_input_length = max(
                     next_batch_max_decoder_input_length, new_decoder_input_length
                 )
-                # +1 because of the bos token
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    stopping_criteria.max_new_tokens + 1,
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -598,7 +578,6 @@ class Seq2SeqLM(Model):
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
         )
         return generations, next_batch