From ab4037c640eb0519acea0090e91f4d916907b650 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 16 May 2023 21:24:53 +0200
Subject: [PATCH] fix naming

---
 .../models/causal_lm.py                       | 56 +++++++++---------
 .../models/flash_causal_lm.py                 | 58 +++++++++----------
 .../models/galactica.py                       | 14 +++--
 .../models/seq2seq_lm.py                      | 56 +++++++++---------
 4 files changed, 93 insertions(+), 91 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 3a45ae06..9d8ae254 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[int]
-    token_offsets: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -102,8 +102,8 @@ class CausalLMBatch(Batch):
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
-            offsets.append(0)
-            token_offsets.append(input_len)
+            prefix_offsets.append(0)
+            read_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -132,8 +132,8 @@ class CausalLMBatch(Batch):
             past_key_values=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -153,8 +153,8 @@ class CausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         max_input_length = 0
 
@@ -169,8 +169,8 @@ class CausalLMBatch(Batch):
             requests_idx_mapping[r.id] = i
             keep_indices.append(idx)
 
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])
 
             request_input_length = self.input_lengths[idx]
@@ -227,8 +227,8 @@ class CausalLMBatch(Batch):
         self.position_ids = position_ids
         self.all_input_ids = all_input_ids
         self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
@@ -253,8 +253,8 @@ class CausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -272,8 +272,8 @@ class CausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -430,8 +430,8 @@ class CausalLMBatch(Batch):
             past_key_values=past_key_values,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
@@ -529,8 +529,8 @@ class CausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -541,8 +541,8 @@ class CausalLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -560,8 +560,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_input_ids[:, 0], offset, token_offset
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_input_ids[:, 0], prefix_offset, read_offset
             )
 
             # Evaluate stopping criteria
@@ -629,8 +629,8 @@ class CausalLM(Model):
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
         # We finished all generations in the batch; there is no next batch
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 28376729..aee0480d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
 
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    prefix_offsets: List[Optional[int]]
+    read_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
 
@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
 
             input_lengths.append(input_length)
-            offsets.append(0)
-            token_offsets.append(input_length)
+            prefix_offsets.append(0)
+            read_offsets.append(input_length)
 
             all_input_ids.append(tokenized_input)
 
@@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=[],
             next_token_choosers=next_token_choosers,
@@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = []
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
 
             input_lengths.append(request_input_length)
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
 
             next_token_choosers.append(self.next_token_choosers[idx])
@@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = []
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
 
             input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
 
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -640,8 +640,8 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -654,8 +654,8 @@ class FlashCausalLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -670,10 +670,10 @@ class FlashCausalLM(Model):
             all_input_ids.append(next_token_id)
 
             # Generated token
-            next_token_text, offset, token_offset = self.decode_token(
+            next_token_text, prefix_offset, read_offset = self.decode_token(
                 all_input_ids,
-                offset,
-                token_offset,
+                prefix_offset,
+                read_offset,
             )
 
             # Evaluate stopping criteria
@@ -739,8 +739,8 @@ class FlashCausalLM(Model):
 
             # Update values
             batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
             batch.max_seqlen = batch.max_seqlen + 1
             cumulative_length += input_length
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index b34489d8..24c37c19 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            offsets.append(None)
-            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for _ in pb.requests:
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            prefix_offsets.append(0)
+            read_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             past_key_values=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index ac7b9cdd..4f55b22f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    offsets: List[int]
-    token_offsets: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch):
         stopping_criterias = []
 
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -122,8 +122,8 @@ class Seq2SeqLMBatch(Batch):
             .view(-1, 1)
         )
         for _ in pb.requests:
-            offsets.append(0)
-            token_offsets.append(1)
+            prefix_offsets.append(0)
+            read_offsets.append(1)
 
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
         max_tokens = len(inputs) * max_input_length + max_decode_tokens
@@ -141,8 +141,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -166,8 +166,8 @@ class Seq2SeqLMBatch(Batch):
         requests_idx_mapping = {}
         input_lengths = []
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
 
         all_decoder_input_ids = []
 
@@ -185,8 +185,8 @@ class Seq2SeqLMBatch(Batch):
             requests_idx_mapping[r.id] = i
             keep_indices.append(idx)
 
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
 
             all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
 
@@ -249,8 +249,8 @@ class Seq2SeqLMBatch(Batch):
         self.all_decoder_input_ids = all_decoder_input_ids
         self.input_lengths = input_lengths
         self.decoder_input_lengths = decoder_input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
@@ -284,8 +284,8 @@ class Seq2SeqLMBatch(Batch):
         all_decoder_input_ids = []
         input_lengths = []
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
         max_tokens = 0
@@ -307,8 +307,8 @@ class Seq2SeqLMBatch(Batch):
             all_decoder_input_ids.extend(batch.all_decoder_input_ids)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -483,8 +483,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
@@ -608,8 +608,8 @@ class Seq2SeqLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -621,8 +621,8 @@ class Seq2SeqLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -643,8 +643,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_decoder_input_ids, offset, token_offset
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_decoder_input_ids, prefix_offset, read_offset
            )
 
             # Evaluate stopping criteria
@@ -702,8 +702,8 @@ class Seq2SeqLM(Model):
             batch.all_decoder_input_ids[i] = all_decoder_input_ids
             batch.input_lengths[i] = input_length
             batch.decoder_input_lengths[i] = new_decoder_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, input_length)
             batch.max_decoder_input_length = max(
                 batch.max_decoder_input_length, new_decoder_input_length
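
A note on the semantics behind the rename: prefix_offsets and read_offsets drive
incremental detokenization during streaming. For each sequence, prefix_offset marks
the start of a window of already-emitted ids that is re-decoded only to give the
tokenizer stable left context (byte-level and SentencePiece tokenizers may change
leading whitespace depending on surrounding ids), while read_offset marks the first
id whose text has not yet been surfaced to the client. The diff shows decode_token
taking the ids plus both offsets and returning the new text plus updated offsets;
the helper below is a minimal standalone sketch of how such a function can work,
assuming a Hugging Face tokenizer, and is not necessarily the repository's exact
implementation.

from typing import List, Tuple

from transformers import PreTrainedTokenizerBase


def decode_token(
    tokenizer: PreTrainedTokenizerBase,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode the tail window twice: once up to the last emitted id (the
    # stabilizing "prefix") and once including the newest ids.
    prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_input_ids[prefix_offset:])

    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Emit only the freshly decoded suffix and slide both offsets
        # forward for the next step.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)

    # A trailing replacement character means the newest ids form an
    # incomplete multi-byte sequence; hold them back until it completes.
    return "", prefix_offset, read_offset

This also explains the initial values set in the patch: prefix_offsets starts at 0
and read_offsets at the prompt length (or 1 for the seq2seq decoder, which begins
with a single start token), so the first emitted suffix is exactly the text of the
first generated token.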