diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7cd49239..1472c2a6 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -34,6 +34,7 @@ class CausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
+    offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -64,12 +65,14 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        offsets = []
 
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
+            offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -113,6 +116,7 @@ class CausalLMBatch(Batch):
             past_key_values=None,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths.tolist(),
+            offsets=offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -135,6 +139,7 @@ class CausalLMBatch(Batch):
         # Batch attributes
         requests = []
         input_lengths = []
+        offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -151,6 +156,7 @@ class CausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
+            offsets.extend(batch.offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -264,6 +270,7 @@ class CausalLMBatch(Batch):
             past_key_values=past_key_values,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
+            offsets=offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -350,6 +357,7 @@ class CausalLM(Model):
 
         # New values for next forward
         next_batch_input_lengths = []
+        next_batch_offsets = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
 
@@ -364,6 +372,7 @@ class CausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
+            batch.offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -374,6 +383,7 @@ class CausalLM(Model):
         for i, (
             request,
             input_length,
+            offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -391,10 +401,7 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode_token(
-                all_input_ids[-2, 0],
-                next_token_id_squeezed,
-            )
+            next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -424,6 +431,7 @@ class CausalLM(Model):
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
+                next_batch_offsets.append(offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, new_input_length
                 )
@@ -507,6 +515,7 @@ class CausalLM(Model):
             past_key_values=next_batch_past_key_values,
             all_input_ids=next_batch_all_input_ids,
             input_lengths=next_batch_input_lengths,
+            offsets=next_batch_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2aeac7b5..ef28ac4d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -44,6 +44,7 @@ class FlashCausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
    input_lengths: List[int]
+    offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -67,6 +68,7 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
 
         input_lengths = []
+        offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
 
@@ -84,6 +86,7 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
+            offsets.append(None)
             all_input_ids.append(tokenized_input)
 
             tokenized_input = torch.tensor(tokenized_input, device=device)
@@ -120,6 +123,7 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
+            offsets=offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -132,6 +136,7 @@ class FlashCausalLMBatch(Batch):
         # Batch attributes
         requests = []
         input_lengths = []
+        offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
         next_token_choosers = []
@@ -150,6 +155,7 @@ class FlashCausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
+            offsets.extend(batch.offsets)
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)
@@ -279,6 +285,7 @@ class FlashCausalLM(Model):
         next_batch_max_seqlen = 0
         next_batch_past_key_values = []
         next_batch_input_lengths = []
+        next_batch_offsets = []
         next_batch_all_input_ids = []
         next_batch_all_input_ids_tensor = []
 
@@ -292,6 +299,7 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
+            batch.offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -302,6 +310,7 @@ class FlashCausalLM(Model):
         for i, (
             request,
             input_length,
+            offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -334,9 +343,8 @@ class FlashCausalLM(Model):
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text = self.decode_token(
-                all_input_ids[-2],
-                next_token_id_item,
+            next_token_text, offset = self.decode_token(
+                all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
             )
 
             # Evaluate stopping criteria
@@ -377,6 +385,7 @@ class FlashCausalLM(Model):
                     next_batch_cu_seqlens[-1] + new_input_length
                 )
                 next_batch_input_lengths.append(new_input_length)
+                next_batch_offsets.append(offset)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
@@ -453,6 +462,7 @@ class FlashCausalLM(Model):
             max_seqlen=next_batch_max_seqlen,
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
+            offsets=next_batch_offsets,
             all_input_ids=next_batch_all_input_ids,
             all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index f997ab1a..f1d3e8a6 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -93,23 +93,21 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []
+        offsets = []
 
         # Parse batch
         max_truncation = 0
-        max_sequence_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            input_lengths.append(r.input_length)
+            offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
-            max_sequence_length = max(max_sequence_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -123,13 +121,17 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -144,10 +146,11 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             past_key_values=None,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
+            offsets=offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
         )
 
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 5a917f5b..cd3ac6a5 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -24,16 +24,26 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
 
-    def decode_token(self, previous_token_id: int, token_id: int) -> str:
+    def decode_token(
+        self, all_input_ids: List[int], offset: Optional[int] = None
+    ) -> Tuple[str, Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        # Decode previous token and previous token + token
+
+        # Decode all token minus last one and all tokens
         results = self.tokenizer.batch_decode(
-            [[previous_token_id], [previous_token_id, token_id]],
+            [all_input_ids[:-1], all_input_ids],
             skip_special_tokens=False,
         )
 
-        if results[0] and results[0][0] == " " and results[1][0] != " ":
-            results[0] = results[0].lstrip()
+        # default offset is only the last token
+        if offset is None:
+            offset = len(results[0])
 
-        # slice to remove previous token
-        return results[1][len(results[0]): ]
+        # get text
+        text = results[1][offset:]
+
+        # if text is utf-8
+        if text and text[-1] != "�":
+            return text, None
+        else:
+            return "", offset
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 72f694c3..99bfa991 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -38,6 +38,7 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
+    offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -71,6 +72,7 @@ class Seq2SeqLMBatch(Batch):
 
         decoder_input_ids = []
         decoder_input_lengths = []
+        offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -80,6 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
+            offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -117,6 +120,7 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -147,6 +151,7 @@ class Seq2SeqLMBatch(Batch):
         requests = []
         input_lengths = []
         decoder_input_lengths = []
+        offsets = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -166,6 +171,7 @@ class Seq2SeqLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
+            offsets.extend(batch.offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -303,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -422,6 +429,7 @@ class Seq2SeqLM(Model):
 
         # New values for next forward
         next_batch_input_lengths = []
+        next_batch_offsets = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
 
@@ -437,6 +445,7 @@ class Seq2SeqLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
+            batch.offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -448,6 +457,7 @@ class Seq2SeqLM(Model):
         for i, (
             request,
             input_length,
+            offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -466,10 +476,7 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode_token(
-                decoder_input_ids[-2],
-                next_token_id_squeezed,
-            )
+            next_token_text, offset = self.decode_token(decoder_input_ids, offset)
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)
@@ -496,6 +503,7 @@ class Seq2SeqLM(Model):
                 next_batch_size += 1
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_offsets.append(offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -581,6 +589,7 @@ class Seq2SeqLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
+            offsets=next_batch_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
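
The sketch below is not part of the patch; it is a minimal, standalone illustration of the offset-based streaming decode that the new `Model.decode_token` performs: decode the sequence with and without its newest token, remember the character offset where new text starts, and hold text back while the tail still ends in the replacement character "�" (an incomplete UTF-8 sequence from a byte-level tokenizer). The "gpt2" tokenizer and the driver loop are illustrative assumptions, not code from this repository.

# Illustrative sketch only; mirrors the decode_token logic added in model.py above.
from typing import List, Optional, Tuple

from transformers import AutoTokenizer

# Assumption: any byte-level BPE tokenizer shows the effect; "gpt2" is arbitrary.
tokenizer = AutoTokenizer.from_pretrained("gpt2")


def decode_token(
    all_input_ids: List[int], offset: Optional[int] = None
) -> Tuple[str, Optional[int]]:
    # Decode the sequence without and with its last token
    results = tokenizer.batch_decode(
        [all_input_ids[:-1], all_input_ids], skip_special_tokens=False
    )
    # On the first call for a new token, the new text starts where the prefix ends
    if offset is None:
        offset = len(results[0])
    text = results[1][offset:]
    # A trailing "�" means the decoded bytes are not yet valid UTF-8:
    # emit nothing and carry the offset over to the next generation step
    if text and text[-1] != "�":
        return text, None
    return "", offset


# Usage: stream a known sequence token by token; multi-byte characters such as
# the emoji only appear once all of their byte tokens have been generated.
ids: List[int] = []
offset: Optional[int] = None
for token_id in tokenizer("Hello, world! 🤗")["input_ids"]:
    ids.append(token_id)
    piece, offset = decode_token(ids, offset)
    if piece:
        print(piece, end="", flush=True)
print()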