From 941d36f3fd5cff7aa3dd9692b0d6f6f6c62d5f99 Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Tue, 27 Feb 2024 15:46:40 +0100 Subject: [PATCH] Enable deferred token generation (#44) (#75) Co-authored-by: Krzysztof Laskowski --- .../models/causal_lm.py | 198 +++++++++++------- server/text_generation_server/server.py | 10 +- 2 files changed, 125 insertions(+), 83 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ab6b034d..d60a6144 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -207,6 +207,9 @@ class CausalLMBatch(Batch): input_length: int + logits = None + past = None + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, @@ -719,8 +722,100 @@ class CausalLM(Model): return outputs.logits, outputs.past_key_values @tracer.start_as_current_span("generate_token") - def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + def generate_token(self, batches: List[CausalLMBatch]) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + # Results + generations: List[Generation] = [] + prev_batches = [] + requests_to_generate = [] + + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. Collect next token ids of any previously started generations + for batch_id, batch in enumerate(batches): + if batch.logits is not None: + logits = batch.logits + past = batch.past + prefill = batch.past_key_values is None + if self.is_optimized_for_gaudi: + if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + else: + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + else: + token_idx = None + + # Select next token + input_length = batch.input_length + if self.is_optimized_for_gaudi and logits.shape[-2] > 1: + next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( + batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2) + ) + else: + next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( + batch.input_ids[:, :token_idx], logits.squeeze(-2) + ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + ) + + prev_batches.append({ + 'next_token_ids': next_token_ids, + 'next_token_logprobs': next_token_logprobs, + }) + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append({ + 'req': req, + 'prev_req_idx': req.idx, + 'batch_id': batch_id, + 'seed': batch.next_token_chooser.seeds[req_idx], + 'do_sample': batch.next_token_chooser.do_sample[req_idx], + 'top_n_tokens': batch.top_n_tokens[req_idx], + 'top_token_ids': batch_top_token_ids[req_idx], + 'top_token_logprobs': batch_top_token_logprobs[req_idx], + }) + + htorch.core.mark_step() + + if token_idx is None: + batch.input_ids[:, 0] = next_token_ids[:, 0] + else: + batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1)) + + # Slice unused values from prefill, use it to store next token + if token_idx is None: + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + if self.is_optimized_for_gaudi: + batch.attention_mask.index_fill_(1, token_idx, 1) + else: + batch.attention_mask[:, -batch.padding_right_offset] = 1 + + # Adjust 
lengths + batch.input_length += 1 + + # Update position_ids + if prefill: + batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1 + else: + batch.position_ids += 1 + # Update past key values + if prefill: + batch.past_key_values = past + + htorch.core.mark_step() + + # Stage 2. Prepare new batch for speculative scheduling + if len(batches) > 1: + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) + else: + batch = batches[0] + prefill = batch.past_key_values is None + # Check if we need to do any bookkeeping first if not prefill: batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) @@ -729,10 +824,6 @@ class CausalLM(Model): dbg_trace( scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') assert batch.right_padding > 0, 'No more room for next token!' - self.step = self.step + 1 - if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps: - self.hb_profer.stop() - self.hb_profer_started = False if self.is_optimized_for_gaudi: if prefill: @@ -753,7 +844,7 @@ class CausalLM(Model): input_ids = batch.input_ids if prefill: - logits, past = self.forward( + batch.logits, batch.past = self.forward( input_ids, attention_mask, batch.position_ids, @@ -762,7 +853,7 @@ class CausalLM(Model): bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None ) else: - logits = self.forward( + batch.logits = self.forward( input_ids, attention_mask, batch.position_ids, @@ -771,46 +862,36 @@ class CausalLM(Model): bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None ) - # Results - generations: List[Generation] = [] - stopped = True - - # Select next token - input_length = batch.input_length - if self.is_optimized_for_gaudi and logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( - batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2) - ) - else: - next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( - batch.input_ids[:, :token_idx], logits.squeeze(-2) - ) - - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - ) - - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids_cpu = next_token_ids.cpu() htorch.core.mark_step() - for req_idx, req in enumerate(batch.requests): - i = req.idx + # Stage 3. 
Finish and return previous generations + stopped = len(requests_to_generate) > 0 + for prev_batch in prev_batches: + prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() + prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() + htorch.core.mark_step() + + for req_data in requests_to_generate: + req = req_data['req'] + i = req_data['prev_req_idx'] + prev_batch_id = req_data['batch_id'] + assert len(prev_batches) > prev_batch_id + next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] + next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] + request = req.data input_length = req.input_length prefix_offset = req.prefix_offset read_offset = req.read_offset - do_sample = batch.next_token_chooser.do_sample[req_idx] - seed = batch.next_token_chooser.seeds[req_idx] + do_sample = req_data['do_sample'] + seed = req_data['seed'] stopping_criteria = req.stopping_criteria all_input_ids = req.all_input_ids - top_n_tokens = batch.top_n_tokens[req_idx] next_token_id = next_token_ids_cpu[i] next_token_logprob = next_token_logprobs[i] - top_token_ids = batch_top_token_ids[req_idx] - top_token_logprobs = batch_top_token_logprobs[req_idx] + top_n_tokens = req_data['top_n_tokens'] + top_token_ids = req_data['top_token_ids'] + top_token_logprobs = req_data['top_token_logprobs'] # Append next token to all tokens if self.is_optimized_for_gaudi: @@ -899,42 +980,9 @@ class CausalLM(Model): req.read_offset = read_offset htorch.core.mark_step() - if token_idx is None: - batch.input_ids[:, 0] = next_token_ids[:, 0] - else: - batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1)) + self.step = self.step + 1 + if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps: + self.hb_profer.stop() + self.hb_profer_started = False - # We finished all generations in the batch; there is no next batch - if stopped: - if self.hb_profer_started == True: - self.hb_profer.step() - htorch.core.mark_step() - return generations, None - - # Slice unused values from prefill, use it to store next token - if token_idx is None: - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - if self.is_optimized_for_gaudi: - batch.attention_mask.index_fill_(1, token_idx, 1) - else: - batch.attention_mask[:, -batch.padding_right_offset] = 1 - - # Adjust lengths - batch.input_length += 1 - - # Update position_ids - if prefill: - batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1 - else: - batch.position_ids += 1 - # Update past key values - if prefill: - batch.past_key_values = past - - if self.hb_profer_started == True: - self.hb_profer.step() - htorch.core.mark_step() - - return generations, batch + return generations, batch if not stopped else None \ No newline at end of file diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 2f41ec94..dd64a93d 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -84,7 +84,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}): with self.profiler.record_event("internal", "generate_token"): - generations, next_batch = self.model.generate_token(batch) + generations, next_batch = self.model.generate_token([batch]) self.cache.set(next_batch) return generate_pb2.PrefillResponse( @@ 
-111,14 +111,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")
 
-        if len(batches) > 1:
-            with self.profiler.record_event("internal", "concatenate"):
-                batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
-        else:
-            batch = batches[0]
-
         with self.profiler.record_event("internal", "generate_token"):
-            generations, next_batch = self.model.generate_token(batch)
+            generations, next_batch = self.model.generate_token(batches)
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
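
For readers unfamiliar with the pattern, the sketch below illustrates the deferred-generation pipeline this patch introduces: generate_token() now accepts a list of batches, each batch carries the logits (and past key values) of the forward launched on the previous call, and token selection for step N happens while the device is already executing the forward of step N+1. This is a minimal, framework-free approximation, not the repository's code: the names (PipelinedGenerator, SketchBatch) are hypothetical, greedy argmax stands in for next_token_chooser, random logits stand in for the model forward, and HPU-specific calls such as htorch.core.mark_step() are omitted.

# Minimal sketch of the deferred token-generation pattern (hypothetical names,
# plain torch instead of HPU-specific code; not the repository's implementation).
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F


@dataclass
class SketchBatch:
    input_ids: torch.Tensor                  # [bs, seq_len], left-padded with 0
    generated: List[List[int]]               # generated token ids, one list per request
    logits: Optional[torch.Tensor] = None    # deferred logits from the previous call's forward


class PipelinedGenerator:
    """Defers token selection by one generate_token() call so host-side work for
    step N overlaps with the device forward of step N+1."""

    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real model: return random last-position logits, shape [bs, vocab].
        return torch.randn(input_ids.shape[0], self.vocab_size)

    def generate_token(self, batches: List[SketchBatch]) -> Tuple[List[int], SketchBatch]:
        # Stage 1: consume logits computed by the *previous* call (if any) and
        # append the selected tokens to each incoming batch.
        new_tokens: List[int] = []
        for batch in batches:
            if batch.logits is not None:
                next_ids = batch.logits.argmax(dim=-1)   # greedy stand-in for next_token_chooser
                batch.input_ids = torch.cat([batch.input_ids, next_ids[:, None]], dim=-1)
                for per_request, token in zip(batch.generated, next_ids.tolist()):
                    per_request.append(token)
                new_tokens.extend(next_ids.tolist())

        # Stage 2: merge the incoming batches (concatenation now lives inside
        # generate_token) and launch the next forward, keeping its logits on the
        # batch instead of consuming them immediately.
        if len(batches) > 1:
            max_len = max(b.input_ids.shape[-1] for b in batches)
            merged = SketchBatch(
                input_ids=torch.cat(
                    [F.pad(b.input_ids, (max_len - b.input_ids.shape[-1], 0)) for b in batches],
                    dim=0,
                ),
                generated=[reqs for b in batches for reqs in b.generated],
            )
        else:
            merged = batches[0]
        merged.logits = self.forward(merged.input_ids)

        # Stage 3: return the tokens finished in Stage 1; by now the device is
        # already busy with the Stage 2 forward.
        return new_tokens, merged

A short usage sketch, mirroring how server.py now calls the model: the prefill call launches a forward but returns no tokens yet, and each subsequent decode call returns the tokens selected from the previous forward's logits.

gen = PipelinedGenerator()
batch = SketchBatch(input_ids=torch.tensor([[11, 12, 13]]), generated=[[]])
tokens, batch = gen.generate_token([batch])   # prefill: launches forward, tokens == []
tokens, batch = gen.generate_token([batch])   # decode: returns the token deferred from prefill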