From b5233f9c3cb5b937ddd5e2a0713b77301fca415a Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 5 Apr 2023 13:47:25 +0200
Subject: [PATCH] better decode

---
 benchmark/src/generation.rs                    |  6 ++++--
 k6/load_test.js                                |  9 +++------
 .../text_generation_server/models/causal_lm.py | 16 +++++++++++++++-
 .../models/flash_causal_lm.py                  | 18 ++++++++++++++++--
 .../text_generation_server/models/galactica.py |  3 +++
 server/text_generation_server/models/model.py  | 17 +++++++++++------
 .../models/seq2seq_lm.py                       | 16 +++++++++++++++-
 7 files changed, 67 insertions(+), 18 deletions(-)

diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 3a6316ab..dde429a5 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -75,7 +75,7 @@ async fn generate_runs(
         // Warmups on batch size
         for _ in 0..warmups {
             let (_, decode_batch) =
-                prefill(sequence.clone(), b, decode_length, &mut client).await?;
+                prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
             let _ = decode(decode_batch, &mut client).await?;
             // Send warmup message
             run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
@@ -83,7 +83,7 @@ async fn generate_runs(
 
         for _ in 0..n_runs {
             let (prefill, decode_batch) =
-                prefill(sequence.clone(), b, decode_length, &mut client).await?;
+                prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
             // Send prefill message
             run_sender
                 .send(Ok(Message::Prefill(prefill)))
@@ -110,6 +110,7 @@ async fn generate_runs(
 // Run a prefill step
 async fn prefill(
     sequence: String,
+    sequence_length: u32,
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
@@ -119,6 +120,7 @@
         .map(|id| Request {
             id: id.into(),
             inputs: sequence.clone(),
+            truncate: sequence_length,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,
diff --git a/k6/load_test.js b/k6/load_test.js
index 516b5666..6fae74c8 100644
--- a/k6/load_test.js
+++ b/k6/load_test.js
@@ -7,9 +7,6 @@ export const options = {
         {duration: '2m', target: 100},
         {duration: '1m', target: 0},
     ],
-    hosts: {
-        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
-    },
 };
 
 const SLEEP_DURATION = 1;
@@ -29,7 +26,7 @@ function greedy_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }
 
 function sample_example(inputs, max_new_tokens, name) {
@@ -50,7 +47,7 @@ function sample_example(inputs, max_new_tokens, name) {
             name: name
         }
     };
-    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+    return http.post('https://open-assistant.ngrok.io/generate', body, params);
 }
 
 export default function () {
@@ -95,4 +92,4 @@ export default function () {
         'is status 200': (r) => r.status === 200,
     });
     sleep(SLEEP_DURATION);
-}
\ No newline at end of file
+}
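The benchmark changes above make every Request carry an explicit truncate value equal to the prompt's sequence_length, which lines up with the max_truncation bookkeeping visible as context in the from_pb hunks below. A minimal sketch of what that clamp amounts to on the Python side, assuming a Hugging Face tokenizer; the model name and helper function are illustrative and not part of this patch:

    from typing import List

    from transformers import AutoTokenizer


    def tokenize_with_truncate(inputs: List[str], truncates: List[int]):
        # Clamp every prompt to the largest per-request `truncate` value, in
        # the spirit of how `max_truncation` is accumulated while a batch is
        # parsed (assumption: the server tokenizes the whole batch in one call).
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        max_truncation = max(truncates)
        return tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_truncation,
        )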
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 1472c2a6..2e077ca6 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -35,6 +35,7 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -66,6 +67,7 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -73,6 +75,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -117,6 +120,7 @@ class CausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
@@ -140,6 +144,7 @@ class CausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -157,6 +162,7 @@ class CausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -271,6 +277,7 @@ class CausalLMBatch(Batch):
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            offsets=offsets,
+           token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
@@ -358,6 +365,7 @@ class CausalLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
 
@@ -373,6 +381,7 @@ class CausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -384,6 +393,7 @@ class CausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -401,7 +411,9 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[:, 0], offset, token_offset
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -432,6 +444,7 @@ class CausalLM(Model):
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, new_input_length
                 )
@@ -516,6 +529,7 @@ class CausalLM(Model):
             all_input_ids=next_batch_all_input_ids,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5f0e46da..61ebe3ec 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -45,6 +45,7 @@ class FlashCausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -69,6 +70,7 @@ class FlashCausalLMBatch(Batch):
 
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
 
@@ -87,6 +89,7 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
             offsets.append(None)
+            token_offsets.append(None)
             all_input_ids.append(tokenized_input)
 
             tokenized_input = torch.tensor(tokenized_input, device=device)
@@ -124,6 +127,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -137,6 +141,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         input_lengths = []
         offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
         next_token_choosers = []
@@ -156,6 +161,7 @@ class FlashCausalLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)
@@ -189,6 +195,7 @@ class FlashCausalLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -287,6 +294,7 @@ class FlashCausalLM(Model):
         next_batch_past_key_values = []
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_all_input_ids = []
         next_batch_all_input_ids_tensor = []
 
@@ -301,6 +309,7 @@ class FlashCausalLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -312,6 +321,7 @@ class FlashCausalLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -344,8 +354,10 @@ class FlashCausalLM(Model):
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text, offset = self.decode_token(
-                all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[-(stopping_criteria.current_tokens + 1) :],
+                offset,
+                token_offset,
             )
 
             # Evaluate stopping criteria
@@ -387,6 +399,7 @@ class FlashCausalLM(Model):
                 )
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
@@ -464,6 +477,7 @@ class FlashCausalLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             all_input_ids=next_batch_all_input_ids,
             all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,
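As the flash_causal_lm.py hunks above show, only the generated tail all_input_ids[-(stopping_criteria.current_tokens + 1):] is handed to decode_token, so the decode window never grows with the prompt. A rough illustration of that slice, with made-up token ids:

    # Hypothetical ids: five prompt tokens followed by three generated tokens.
    all_input_ids = [101, 2023, 2003, 1037, 7099, 345, 678, 910]
    current_tokens = 3  # stopping_criteria.current_tokens after three decode steps

    # One extra id is kept so the newest token can be diffed against the
    # text decoded without it.
    window = all_input_ids[-(current_tokens + 1):]
    assert window == [7099, 345, 678, 910]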
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index f1d3e8a6..f1090f63 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -94,6 +94,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -102,6 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -147,6 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index cd3ac6a5..6ef0112d 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -25,13 +25,18 @@ class Model(ABC):
         raise NotImplementedError
 
     def decode_token(
-        self, all_input_ids: List[int], offset: Optional[int] = None
-    ) -> Tuple[str, Optional[int]]:
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
+    ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        if token_offset is None:
+            token_offset = len(all_input_ids) - 5
 
-        # Decode all token minus last one and all tokens
+        # Decode the tokens from token_offset, both without and with the last one
         results = self.tokenizer.batch_decode(
-            [all_input_ids[:-1], all_input_ids],
+            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
             skip_special_tokens=False,
         )
 
@@ -44,6 +49,6 @@ class Model(ABC):
 
         # if text is utf-8
         if text and text[-1] != "�":
-            return text, None
+            return text, None, None
         else:
-            return "", offset
+            return "", offset, token_offset
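The model.py hunk above is the core of the change: decode_token now decodes only a short window starting at token_offset (by default the last five ids) instead of re-decoding the whole sequence on every step, and it keeps returning an empty string plus the carried offsets while the trailing bytes still decode to the replacement character. A standalone sketch of that logic, assuming a Hugging Face tokenizer; the part between batch_decode and the UTF-8 check is not shown in the hunk and is reconstructed here as an assumption:

    from typing import List, Optional, Tuple

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")


    def decode_token(
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        # Only decode a small trailing window instead of the full sequence.
        if token_offset is None:
            token_offset = len(all_input_ids) - 5

        # Decode the window without and with the newest id.
        prefix_text, full_text = tokenizer.batch_decode(
            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
            skip_special_tokens=False,
        )

        # `offset` marks how much of the window was already emitted; on the
        # first attempt that is simply the text without the newest id
        # (assumption: this middle step is elided in the hunk above).
        if offset is None:
            offset = len(prefix_text)
        text = full_text[offset:]

        # "�" means the newest id only covers part of a multi-byte character:
        # emit nothing yet and carry both offsets so the next step retries.
        if text and text[-1] != "�":
            return text, None, None
        return "", offset, token_offset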
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 99bfa991..134ea681 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -39,6 +39,7 @@ class Seq2SeqLMBatch(Batch):
     input_lengths: List[int]
     decoder_input_lengths: List[int]
     offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -73,6 +74,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
@@ -83,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -121,6 +124,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -152,6 +156,7 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
+        token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -172,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -310,6 +316,7 @@ class Seq2SeqLMBatch(Batch):
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
             offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -430,6 +437,7 @@ class Seq2SeqLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
         next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
 
@@ -446,6 +454,7 @@ class Seq2SeqLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.offsets,
+            batch.token_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -458,6 +467,7 @@ class Seq2SeqLM(Model):
             request,
             input_length,
             offset,
+            token_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -476,7 +486,9 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset = self.decode_token(decoder_input_ids, offset)
+            next_token_text, offset, token_offset = self.decode_token(
+                decoder_input_ids, offset, token_offset
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)
@@ -504,6 +516,7 @@ class Seq2SeqLM(Model):
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
                 next_batch_offsets.append(offset)
+                next_batch_token_offsets.append(token_offset)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -590,6 +603,7 @@ class Seq2SeqLM(Model):
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
             offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
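Downstream, every model keeps one (offset, token_offset) pair per request and feeds whatever decode_token returned straight back into the next step, which is what the next_batch_offsets / next_batch_token_offsets plumbing in the causal_lm.py, flash_causal_lm.py and seq2seq_lm.py hunks amounts to. A toy loop showing that hand-off, reusing the decode_token sketch after the model.py diff above, with made-up ids in place of a real forward pass:

    # Toy driver: ids are appended by hand instead of being sampled from a model.
    prompt_ids = [198, 464, 2068, 7586, 21831]  # hypothetical prompt ids
    new_ids = [11687, 82, 625]                  # hypothetical "generated" ids

    all_input_ids = list(prompt_ids)
    offset, token_offset = None, None
    for next_id in new_ids:
        all_input_ids.append(next_id)
        next_token_text, offset, token_offset = decode_token(
            all_input_ids, offset, token_offset
        )
        if next_token_text:
            print(next_token_text, end="")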