From b52164d38ae164548e7dccc8aac5aaa214755f22 Mon Sep 17 00:00:00 2001
From: kaixuanliu
Date: Thu, 30 Jan 2025 17:19:13 +0800
Subject: [PATCH] Complete padding of `CausalLMBatch` when there exists batch
 bucketing (#261)

Signed-off-by: kaixuanliu
---
 .../models/causal_lm.py | 406 ++++++++++++------
 1 file changed, 267 insertions(+), 139 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 44662388..5b105224 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -53,21 +53,26 @@ from text_generation_server.utils.debug import dbg_trace
 from text_generation_server.utils.speculate import get_speculate
 
 tracer = trace.get_tracer(__name__)
-MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
+MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
-LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
-BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
+LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
+BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
+PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
+
 
 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
         return func
-    return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
+    return torch.compile(
+        func, backend="hpu_backend", options={"keep_input_mutations": True}
+    )
+
 
 def round_up(number, k):
     return (number + k - 1) // k * k
 
+
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
 
@@ -96,9 +101,11 @@ def grouped_pad(tensor_groups, dims, values):
     for tensors, dim, value in zip(tensor_groups, dims, values):
         padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
         if padding > 0:
-            assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
+            assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}"
             pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
-            result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
+            result = [
+                torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors
+            ]
         else:
             result = [t for t in tensors]
         grouped_result.append(result)
@@ -117,7 +124,10 @@ def roll(tensor, chunk, dim, merge_graphs):
 
 
 def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
-    tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
+    tensor_groups = [
+        [roll(t, chunk, dim, merge_graphs) for t in tensors]
+        for tensors, dim in zip(tensor_groups, dims)
+    ]
     if merge_graphs:
         htorch.core.mark_step()
     return tensor_groups
@@ -167,7 +177,10 @@ def extend_batch(tensors, target_bs, dim):
 
 
 def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
-    tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
+    tensor_groups = [
+        extend_batch(tensors, target_bs, dim)
+        for tensors, dim in zip(tensor_groups, bs_dims)
+    ]
     return tensor_groups
 
 
@@ -220,15 +233,20 @@ class CausalLMRequest:
     all_input_ids: torch.Tensor
 
     @classmethod
-    def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase):
+    def from_pb(
+        cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase
+    ):
         return cls(
             idx=idx,
             data=data,
             input_length=None,
             prefix_offset=None,
             read_offset=None,
-            stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer),
-            all_input_ids=None,)
+            stopping_criteria=StoppingCriteria.from_pb(
+                data.stopping_parameters, tokenizer
+            ),
+            all_input_ids=None,
+        )
 
     def update_idx(self, new_idx):
         prev = self.idx
@@ -289,7 +307,11 @@ class CausalLMBatch(Batch):
         # Very simple heuristic to determine whether we should merge tensors
         # this needs tuning for other models/scenarios
         small_bs = len(self.past_key_values) > self.batch_size
-        if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
+        if (
+            not self.merged_kv_cache
+            and small_bs
+            and (pad_needed or shift_needed or expand_needed)
+        ):
             past_keys, past_values = self.detach_kv_cache()
             past_keys = merge(past_keys)
             past_values = merge(past_values)
@@ -309,7 +331,13 @@ class CausalLMBatch(Batch):
         seq_dim = -1
         key_dim = -2 if self.keys_head_dim_last else -1
         value_dim = -2
-        tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
+        tensors = [
+            [self.input_ids],
+            [self.attention_mask],
+            [self.position_ids],
+            past_keys,
+            past_values,
+        ]
         # We don't need to align position_ids
         seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
         bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
@@ -350,13 +378,17 @@ class CausalLMBatch(Batch):
         dst_tensors, _, dst_dims = self.get_tensor_groups()
         free_indices_gen = self.free_indices_generator()
         for src_b in src_batches:
-            dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
+            dst_indices = to_tensor_indices(
+                src_b.update_indices(free_indices_gen), self.input_ids.device
+            )
             src_tensors, _, src_dims = src_b.get_tensor_groups()
             grouped_move(dst_tensors, dst_indices, src_tensors)
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
+    def recombine(
+        cls, batches: List["CausalLMBatch"], pad_token_id: int
+    ) -> "CausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
 
@@ -375,31 +407,39 @@ class CausalLMBatch(Batch):
 
         # For prefill there is a space allocated only for first token
         # Need to add padding to the max total tokens before first decode
-        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        moves_needed = [
+            total_requests - len(b) if b.batch_size == new_bs else total_requests
+            for b in batches
+        ]
         dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
-        reshape = (batches[dst_batch_idx].batch_size < new_bs)
+        reshape = batches[dst_batch_idx].batch_size < new_bs
 
         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
-            scenario = 'CONCAT'
+            scenario = "CONCAT"
         elif reshape:
-            scenario = 'RESHAPE'
+            scenario = "RESHAPE"
         elif cur_padding[dst_batch_idx] <= 0:
-            scenario = 'SHIFT'
-            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            scenario = "SHIFT"
+            offsets = [
+                biggest_single_chunk(b.max_input_length - max_input_length)
+                for b in batches
+            ]
             max_input_length = max_input_length + offsets[dst_batch_idx]
         else:
             # Nothing to do
             return batches[0]
 
         dbg_trace(
-            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
-                      f' reqs:{[len(b) for b in batches]}'
-                      f' offsets:{offsets}'
-                      f' input_lengths:{input_lengths}'
-                      f' cur_padding:{cur_padding}'
-                      f' dst_batch:{dst_batch_idx}')
+            scenario,
+            f"bs:{[b.batch_size for b in batches]}->{new_bs}"
+            f" reqs:{[len(b) for b in batches]}"
+            f" offsets:{offsets}"
+            f" input_lengths:{input_lengths}"
+            f" cur_padding:{cur_padding}"
+            f" dst_batch:{dst_batch_idx}",
+        )
 
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
@@ -410,10 +450,15 @@ class CausalLMBatch(Batch):
             batches[i].realign(target_bs, offsets[i], pad_token_id)
             batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
         batches[dst_batch_idx].expand_bs(new_bs)
-        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+        batches[dst_batch_idx].move_data(
+            [batches[i] for i in range(len(batches)) if i != dst_batch_idx]
+        )
 
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens.extend([-1] * (new_bs - total_requests))
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         parameters = [r.data.parameters for r in flat_requests]
         # append the dummy parameters for dummy requests
@@ -424,7 +469,9 @@ class CausalLMBatch(Batch):
         fsm_grammar_states = [0] * batch_size
         for batch in batches:
             for i, req in enumerate(batch.requests):
-                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+                fsm_grammar_states[req.idx] = (
+                    batch.next_token_chooser.fsm_grammar_states[i]
+                )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
@@ -465,8 +512,11 @@ class CausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
-        requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
+        dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
+        requests = [
+            CausalLMRequest.from_pb(idx, req, tokenizer)
+            for idx, req in enumerate(pb.requests)
+        ]
 
         inputs = []
         top_n_tokens = []
@@ -476,10 +526,10 @@ class CausalLMBatch(Batch):
             inputs.append(concat_text_chunks(r.input_chunks.chunks))
             top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
-
+
         max_input_length = max_truncation
         if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
-            max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
+            max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
         max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
 
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
@@ -501,7 +551,7 @@ class CausalLMBatch(Batch):
         )
 
         tokenized_inputs = tokenizer(
-            inputs+dummy_inputs,
+            inputs + dummy_inputs,
             return_tensors="pt",
             padding="longest",
             return_token_type_ids=False,
@@ -514,7 +564,9 @@ class CausalLMBatch(Batch):
             bucket_size = max_input_length
             left_padding = max_input_length - input_len
             if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-                assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+                assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
+                    "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+                )
                 rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
                 if rounded_seq_len <= max_input_length:
                     bucket_size = rounded_seq_len - 1
@@ -547,7 +599,8 @@ class CausalLMBatch(Batch):
 
         position_ids = attention_mask.long().cumsum(-1) - 1
         position_ids.masked_fill_(attention_mask == 0, 1)
-
+        old_bs = len(requests)
+        top_n_tokens.extend([-1] * (new_bs - old_bs))
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
         )
@@ -568,14 +621,16 @@ class CausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
-        dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
+        dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}")
         request_ids = set(request_ids)
         self.requests = [req for req in self.requests if req.data.id in request_ids]
         return self
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+    def concatenate(
+        cls, batches: List["CausalLMBatch"], pad_token_id: int = 0
+    ) -> "CausalLMBatch":
         return cls.recombine(batches, pad_token_id)
 
     def __len__(self):
@@ -618,9 +673,7 @@ class CausalLM(Model):
         tokenizer_class=AutoTokenizer,
         config_class=AutoConfig,
         batch_class=CausalLMBatch,
-
     ):
-
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
@@ -646,18 +699,14 @@ class CausalLM(Model):
         htorch.core.hpu_set_env()
 
         if world_size > 1:
-            model = self.get_deepspeed_model(
-                model_id, dtype, revision
-            )
+            model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
             # Check support for rope scaling
             model_kwargs = {}
-            config = AutoConfig.from_pretrained(
-                model_id
-            )
+            config = AutoConfig.from_pretrained(model_id)
             if hasattr(config, "rope_scaling"):
                 model_kwargs["rope_scaling"] = self.get_rope_scaling()
 
@@ -666,26 +715,34 @@ class CausalLM(Model):
                 revision=revision,
                 torch_dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                **model_kwargs
+                **model_kwargs,
             )
             model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)
 
-        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+        self.enable_hpu_graph = (
+            os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+        )
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
 
-        if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+        if model.config.model_type not in [
+            "gpt_bigcode"
+        ]:  # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
             model = remove_kv_cache_from_output(model)
 
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         else:
             if LAZY_MODE == 0:
                 # It is said that "keep_input_mutations" is safe for inference to be done
-                dbg_trace(
-                    "TORCH COMPILE", f'Torch compiling of model')
-                model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
+                dbg_trace("TORCH COMPILE", f"Torch compiling of model")
+                model.model = torch.compile(
+                    model.model,
+                    backend="hpu_backend",
+                    options={"keep_input_mutations": True},
+                )
 
         model = hq_env.setup_quantization(model)
 
@@ -714,8 +771,14 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gpt_bigcode"]:
+        if model.config.model_type in [
+            "llama",
+            "mistral",
+            "starcoder2",
+            "qwen2",
+            "falcon",
+            "gpt_bigcode",
+        ]:
             if model.config.model_type not in ["falcon", "gpt_bigcode"]:
                 self.kwargs["attn_softmax_bf16"] = True
 
@@ -740,11 +803,15 @@ class CausalLM(Model):
         )
 
         # Create profiler
-        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
+        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
         record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
         output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
-        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
-        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_warmup_steps = (
+            int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+        )
+        self.profiling_steps = (
+            int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        )
         self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
         if self.profiling_steps > 0:
             self.hb_profiler = HabanaProfile(
@@ -752,7 +819,7 @@ class CausalLM(Model):
                 warmup=self.profiling_warmup_steps,
                 active=self.profiling_steps,
                 output_dir=output_dir,
-                record_shapes=record_shapes
+                record_shapes=record_shapes,
             )
             self.hb_profiler.start()
         else:
@@ -760,23 +827,20 @@ class CausalLM(Model):
         self.step = 0
 
     def get_deepspeed_model(
-        self,
-        model_id: str,
-        dtype: torch.dtype,
-        revision: Optional[str] = None
+        self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None
    ) -> torch.nn.Module:
         import deepspeed
         from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
 
         world_size, rank, local_rank = initialize_distributed_hpu()
-        model_kwargs = {
-            "revision": revision
-        }
+        model_kwargs = {"revision": revision}
 
         # Initialize process(es) for DeepSpeed
         deepspeed.init_distributed(dist_backend="hccl")
         logger.info(
-            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
+            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
+                world_size, rank, local_rank
+            )
         )
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
@@ -794,7 +858,9 @@ class CausalLM(Model):
             get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
-                model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_id, torch_dtype=dtype, **model_kwargs
+                )
         model = model.eval()
 
         # Initialize the model
@@ -817,16 +883,16 @@ class CausalLM(Model):
             return None
 
         rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
-        return {
-            'type': rope_scaling, 'factor': float(rope_factor)
-        }
+        return {"type": rope_scaling, "factor": float(rope_factor)}
 
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
 
     def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
 
     def decode_token(
         self,
@@ -835,7 +901,9 @@ class CausalLM(Model):
         read_offset: int = 0,
     ) -> Tuple[str, int, int]:
         if is_tokenizer_transparent(self.tokenizer):
-            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
+            new_text = self.tokenizer.decode(
+                all_input_ids[read_offset:], skip_special_tokens=False
+            )
             return new_text, read_offset, len(all_input_ids)
         else:
             return super().decode_token(all_input_ids, prefix_offset, read_offset)
@@ -858,7 +926,7 @@ class CausalLM(Model):
         }
 
         # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
-        if self.model.config.model_type == "llama" :
+        if self.model.config.model_type == "llama":
             kwargs["lazy_mode"] = LAZY_MODE == 1
 
         if self.has_position_ids:
@@ -869,7 +937,9 @@ class CausalLM(Model):
 
         kwargs.update(self.kwargs)
 
-        if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
+        if past_key_values is not None and self.model.config.model_type not in [
+            "gpt_bigcode"
+        ]:
             return self.model.forward(**kwargs)
         else:
             outputs = self.model.forward(**kwargs)
@@ -896,18 +966,26 @@ class CausalLM(Model):
                 token_idx_scalar = batch.attention_mask.shape[-1] - 1
                 token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx_scalar = (
+                    batch.attention_mask.shape[-1] - batch.right_padding
+                )
                 token_idx = torch.tensor(token_idx_scalar).to(self.device)
 
             # Select next token
             input_length = batch.input_length
             if logits.shape[-2] > 1:
-                next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
-                    batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
+                next_token_ids, next_token_logprobs, logprobs, _, _ = (
+                    batch.next_token_chooser(
+                        batch.input_ids,
+                        logits[:, input_length - 1 : input_length, :].squeeze(-2),
+                        self.speculate,
+                    )
                 )
             else:
-                next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
-                    batch.input_ids, logits.squeeze(-2), self.speculate
+                next_token_ids, next_token_logprobs, logprobs, _, _ = (
+                    batch.next_token_chooser(
+                        batch.input_ids, logits.squeeze(-2), self.speculate
+                    )
                 )
 
             # Speculation is not active for causal
             accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
@@ -918,24 +996,29 @@ class CausalLM(Model):
                 accepted_ids,
             )
 
-            prev_batches.append({
-                'next_token_ids': next_token_ids,
-                'next_token_logprobs': next_token_logprobs,
-            })
+            prev_batches.append(
+                {
+                    "next_token_ids": next_token_ids,
+                    "next_token_logprobs": next_token_logprobs,
+                }
+            )
 
             for req_idx, req in enumerate(batch.requests):
-                requests_to_generate.append({
-                    'req': req,
-                    'prev_req_idx': req.idx,
-                    'batch_id': batch_id,
-                    'seed': batch.next_token_chooser.seeds[req_idx],
-                    'do_sample': batch.next_token_chooser.do_sample[req_idx],
-                    'top_n_tokens': batch.top_n_tokens[req_idx],
-                    'top_token_ids': batch_top_token_ids[req_idx],
-                    'top_token_logprobs': batch_top_token_logprobs[req_idx],
-                    'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
-
-                })
+                requests_to_generate.append(
+                    {
+                        "req": req,
+                        "prev_req_idx": req.idx,
+                        "batch_id": batch_id,
+                        "seed": batch.next_token_chooser.seeds[req_idx],
+                        "do_sample": batch.next_token_chooser.do_sample[req_idx],
+                        "top_n_tokens": batch.top_n_tokens[req_idx],
+                        "top_token_ids": batch_top_token_ids[req_idx],
+                        "top_token_logprobs": batch_top_token_logprobs[req_idx],
+                        "grammar_state": batch.next_token_chooser.fsm_grammar_states[
+                            req.idx
+                        ],
+                    }
+                )
 
             htorch.core.mark_step()
 
@@ -950,7 +1033,9 @@ class CausalLM(Model):
 
             # Update position_ids
             if prefill:
-                batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+                batch.position_ids = (
+                    torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+                )
             else:
                 batch.position_ids += 1
             # Update past key values
@@ -971,13 +1056,19 @@ class CausalLM(Model):
         if not prefill:
             batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
 
-        scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
+        scenario = "PREFILL" if prefill else "GENERATE"
+        if (
+            self.enable_hpu_graph
+            and self.limit_hpu_graph
+            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+        ):
             self.model.clear_cache()
             self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
         dbg_trace(
-            scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
-        assert batch.right_padding > 0, 'No more room for next token!'
+            scenario,
+            f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
+        )
+        assert batch.right_padding > 0, "No more room for next token!"
 
         # Execute batch
         if prefill:
@@ -989,14 +1080,18 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
+                bypass_hpu_graph=prefill and self.limit_hpu_graph
+                if self.enable_hpu_graph
+                else None,
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
             # Don't schedule next forward if max_new_tokens for all requests equals 1
             # - we've already generated the first and only needed token in the prefill phase
             pass
         else:
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            token_idx = torch.tensor(
+                batch.attention_mask.shape[-1] - batch.right_padding
+            ).to(self.device)
             input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
             logits = self.forward(
                 input_ids,
@@ -1004,7 +1099,9 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
+                bypass_hpu_graph=prefill and self.limit_hpu_graph
+                if self.enable_hpu_graph
+                else None,
             )
             if self.model.config.model_type in ["gpt_bigcode"]:
                 batch.logits, batch.past = logits
@@ -1018,40 +1115,45 @@ class CausalLM(Model):
         # Stage 3. Finish and return previous generations
         stopped = len(requests_to_generate) > 0
         for prev_batch in prev_batches:
-            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
-            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
+            prev_batch["next_token_logprobs"] = prev_batch[
+                "next_token_logprobs"
+            ].tolist()
+            prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
         htorch.core.mark_step()
 
         for req_data in requests_to_generate:
-            req = req_data['req']
-            i = req_data['prev_req_idx']
-            prev_batch_id = req_data['batch_id']
+            req = req_data["req"]
+            i = req_data["prev_req_idx"]
+            prev_batch_id = req_data["batch_id"]
             assert len(prev_batches) > prev_batch_id
-            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
-            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
+            next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
+            next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
 
             request = req.data
             input_length = req.input_length
             prefix_offset = req.prefix_offset
             read_offset = req.read_offset
-            do_sample = req_data['do_sample']
-            seed = req_data['seed']
+            do_sample = req_data["do_sample"]
+            seed = req_data["seed"]
             stopping_criteria = req.stopping_criteria
             all_input_ids = req.all_input_ids
             next_token_id = next_token_ids_cpu[i]
             next_token_logprob = next_token_logprobs[i]
-            top_n_tokens = req_data['top_n_tokens']
-            top_token_ids = req_data['top_token_ids']
-            top_token_logprobs = req_data['top_token_logprobs']
-            grammar_state = req_data['grammar_state']
+            top_n_tokens = req_data["top_n_tokens"]
+            top_token_ids = req_data["top_token_ids"]
+            top_token_logprobs = req_data["top_token_logprobs"]
+            grammar_state = req_data["grammar_state"]
 
             # Append next token to all tokens
             all_input_ids[input_length] = next_token_id
             new_input_length = input_length + 1
 
             # Generated token
-            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
-                next_token_text = ''
+            if (
+                is_tokenizer_transparent(self.tokenizer)
+                and len(stopping_criteria.stop_sequence_criterias) == 0
+            ):
+                next_token_text = ""
             else:
                 next_token_text, prefix_offset, read_offset = self.decode_token(
                     all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
@@ -1075,7 +1177,11 @@ class CausalLM(Model):
                     output_text = None
                 else:
                     output_text = self.decode(
-                        all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
+                        all_input_ids[
+                            new_input_length
+                            - stopping_criteria.current_tokens : new_input_length,
+                            0,
+                        ]
                     )
                 generated_text = GeneratedText(
                     output_text,
@@ -1090,7 +1196,7 @@ class CausalLM(Model):
                 if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                     # Remove generated token to only have prefill and add nan for first prompt token
                     prefill_logprobs = [float("nan")] + next_token_logprobs
-                    prefill_token_ids = all_input_ids[0: new_input_length - 1]
+                    prefill_token_ids = all_input_ids[0 : new_input_length - 1]
                     prefill_texts = self.tokenizer.batch_decode(
                         prefill_token_ids,
                         clean_up_tokenization_spaces=False,
@@ -1159,7 +1265,12 @@ class CausalLM(Model):
         htorch.core.mark_step()
         self.step = self.step + 1
         if self.hb_profiler is not None:
-            if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
+            if (
+                self.step
+                > self.profiling_wait_steps
+                + self.profiling_warmup_steps
+                + self.profiling_steps
+            ):
                 self.hb_profiler.stop()
             else:
                 self.hb_profiler.step()
@@ -1178,11 +1289,12 @@ class CausalLM(Model):
 
         return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
 
-
     def warmup(self, request) -> None:
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
-        batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
+        batch = self.batch_type.from_pb(
+            request.batch, self.tokenizer, self.dtype, self.device
+        )
         max_prefill_batch_size = batch.input_ids.shape[0]
         try:
             # max prefill batch size warmup
@@ -1199,14 +1311,21 @@ class CausalLM(Model):
         max_input_length = request.max_input_length
         prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
         prefill_batch_size_list.append(max_prefill_batch_size)
-        prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
+        prefill_seqlen_list = [
+            seq
+            for seq in range(
+                PAD_SEQUENCE_TO_MULTIPLE_OF,
+                max_input_length,
+                PAD_SEQUENCE_TO_MULTIPLE_OF,
+            )
+        ]
         prefill_seqlen_list.append(max_input_length)
         prefill_batch_size_list.sort(reverse=True)
         prefill_seqlen_list.sort(reverse=True)
         try:
             for batch_size in prefill_batch_size_list:
                 for seq_len in prefill_seqlen_list:
-                    batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
+                    batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                     _, prefill_batch, _ = self.generate_token([batch])
         except:
             prefill_batch_size_list.sort()
@@ -1227,24 +1346,33 @@ class CausalLM(Model):
             f"Memory stats: {mem_stats} "
         )
 
-        #warmup decode batch size
+        # warmup decode batch size
        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
-        decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
+        decode_batch_size_list = [
+            i
+            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+        ]
         decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)
 
         try:
             for batch_size in decode_batch_size_list:
-                batches= []
-                iters = math.floor(batch_size/max_prefill_batch_size)
+                batches = []
+                iters = math.floor(batch_size / max_prefill_batch_size)
                 for i in range(iters):
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
+                    batch = self.generate_warmup_batch(
+                        request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size
+                    )
                     _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)
 
                 if batch_size % max_prefill_batch_size != 0:
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
+                    batch = self.generate_warmup_batch(
+                        request,
+                        PAD_SEQUENCE_TO_MULTIPLE_OF - 1,
+                        batch_size % max_prefill_batch_size,
+                    )
                     _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)
 
@@ -1254,10 +1382,10 @@ class CausalLM(Model):
                 batches.clear()
 
         except:
-            raise RuntimeError(
-                f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
-                f"You need to decrease `--max-batch-total-tokens`"
-            )
+            raise RuntimeError(
+                f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
+                f"You need to decrease `--max-batch-total-tokens`"
+            )
 
         decode_batch_size_list.sort()
         MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
@@ -1268,4 +1396,4 @@ class CausalLM(Model):
             f"Memory stats: {mem_stats} "
         )
 
-        return MAX_BATCH_TOTAL_TOKENS
\ No newline at end of file
+        return MAX_BATCH_TOTAL_TOKENS
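
Note on the fix, for readers skimming the diff: with batch bucketing, the server rounds every batch up to a bucketed size and fills the tail with dummy requests, so per-request metadata such as `top_n_tokens` must be padded to the same bucketed length before it is turned into a tensor; that is what the added `top_n_tokens.extend([-1] * (new_bs - total_requests))` and `top_n_tokens.extend([-1] * (new_bs - old_bs))` lines do. Below is a minimal, self-contained sketch of that idea. It is illustrative only: `pad_top_n_tokens` is a hypothetical helper, not part of the patched module, and -1 mirrors the sentinel the patch uses for dummy slots.

    import torch

    BATCH_BUCKET_SIZE = 8  # server default; normally read from the env var of the same name

    def round_up(number: int, k: int) -> int:
        # Same helper as in causal_lm.py: round `number` up to a multiple of `k`.
        return (number + k - 1) // k * k

    def pad_top_n_tokens(top_n_tokens):
        # Hypothetical helper: round the real request count up to the bucketed
        # batch size and pad the tail with -1, the dummy-request sentinel.
        new_bs = round_up(len(top_n_tokens), BATCH_BUCKET_SIZE)
        padded = top_n_tokens + [-1] * (new_bs - len(top_n_tokens))
        return torch.tensor(padded, dtype=torch.int64)

    # Example: 5 real requests in a bucket of 8 -> 3 dummy slots padded with -1.
    print(pad_top_n_tokens([0, 2, 1, 0, 3]))  # tensor([ 0,  2,  1,  0,  3, -1, -1, -1])

Without this padding, the `top_n_tokens_tensor` built after bucketing would have the bucketed batch size while the Python list kept only the real request count, so the two would disagree whenever the batch was rounded up.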