diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 9ba9f6e0..89a43d65 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -3,6 +3,7 @@ from .common import ( HPUPagedAttentionMetadata, trim_attn_metadata, trim_seqlen_metadata, + _async_h2d_tensor_copy, ) from .hpu import ( @@ -25,4 +26,5 @@ __all__ = [ "HPUPagedAttentionMetadata", "trim_seqlen_metadata", "trim_attn_metadata", + "_async_h2d_tensor_copy", ] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index 34c77040..9bd738fc 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -75,42 +75,27 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: @dataclass class Seqlen: input_lengths: torch.Tensor - cache_lengths: torch.Tensor - cu_seqlen_q: Optional[torch.Tensor] - cu_seqlen_k: Optional[torch.Tensor] def __init__( self, input_lengths, - cache_lengths, - cu_seqlen_q=None, ): self.input_lengths = input_lengths - self.cache_lengths = cache_lengths - device = self.input_lengths.device - shape = self.input_lengths.shape - if cu_seqlen_q is None: - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) - cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) - - # cuda graphs don't like this and this is necessary to clamp within mistral - # Although FA2 might not want the clamping - # cu_seqlen_k[0] = 0 - total = self.input_lengths + self.cache_lengths - torch.cumsum(total, -1, out=cu_seqlen_k[1:]) - - self.cu_seqlen_q = cu_seqlen_q - self.cu_seqlen_k = cu_seqlen_k def clamp(self, max): # Flash decoding doesn't need to clamp return self +def _async_h2d_tensor_copy(source, device="hpu"): + if source is None: + return None + assert source.device.type == "cpu", "Source tensor is not present in host memory!" + target = torch.empty(source.shape, dtype=source.dtype, device=device) + target.copy_(source, non_blocking=True) + return target + + def trim_seqlen_metadata(metadata: Seqlen) -> object: # NOTE(kzawora): To anyone working on this in the future: # Trimming metadata is required when using HPUGraphs. 
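The core of this change is visible in the two hunks above: `Seqlen` shrinks to a single `input_lengths` field and the new `_async_h2d_tensor_copy` helper stages tensors on the host before queuing a non-blocking copy to the HPU. A minimal usage sketch, assuming a Habana PyTorch install so that the `hpu` device exists; the lengths are illustrative:

```python
import torch
from text_generation_server.layers.attention import (
    Seqlen,
    trim_seqlen_metadata,
    _async_h2d_tensor_copy,
)

# Metadata is assembled on the host first ...
input_lengths = torch.tensor([7, 12, 3], dtype=torch.int32)  # CPU tensor
# ... and only queued for a non-blocking host-to-device copy when needed.
seqlen = Seqlen(input_lengths=_async_h2d_tensor_copy(input_lengths))
# Seqlen now carries only input_lengths (cache_lengths / cu_seqlen_* were dropped),
# and trimming it yields the small namedtuple required for HPUGraphs (see the NOTE above).
hpu_meta = trim_seqlen_metadata(seqlen)
```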
@@ -137,9 +122,6 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object: "TrimmedSeqlen", [ "input_lengths", - "cache_lengths", - "cu_seqlen_q", - "cu_seqlen_k", ], ) return attention_metadata diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index ecedd4aa..ad585172 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -36,6 +36,7 @@ from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, + pad_next_token_chooser_parameters, ) from text_generation_server.models.types import ( Batch, @@ -56,6 +57,7 @@ from text_generation_server.layers.attention import ( HPUPagedAttentionMetadata, trim_attn_metadata, trim_seqlen_metadata, + _async_h2d_tensor_copy, ) from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION @@ -141,18 +143,23 @@ def prepare_for_decode( block_groups = pad_list(block_groups, block_bucket_size, -1) block_usage = pad_list(block_usage, block_bucket_size, 1) - block_list = torch.tensor(block_list, dtype=torch.int, device=device) - block_groups = torch.tensor(block_groups, dtype=torch.int, device=device) - block_usage = torch.tensor(block_usage, dtype=dtype, device=device) - block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) + block_list = torch.tensor(block_list, dtype=torch.int, device="cpu") + block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu") + block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu") + block_list_device = _async_h2d_tensor_copy(block_list) + block_groups_device = _async_h2d_tensor_copy(block_groups) + block_usage_device = _async_h2d_tensor_copy(block_usage) + block_mapping = torch.nn.functional.one_hot( + block_groups_device, num_classes=batch_size + ) mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) mask = mask >= block_usage.unsqueeze(-1) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) return trim_attn_metadata( HPUPagedAttentionMetadata( - block_list=block_list, - block_groups=block_groups, - block_usage=block_usage, + block_list=block_list_device, + block_groups=block_groups_device, + block_usage=block_usage_device, block_mapping=block_mapping.to(dtype), attn_bias=attn_bias, ) @@ -248,6 +255,7 @@ class FlashCausalLMBatch(Batch): next_token_logits: Optional[torch.Tensor] speculative_logits: Optional[torch.Tensor] + valid_indices: Optional[List[int]] def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( @@ -417,32 +425,23 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) + all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) + top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) - block_tables_ragged = torch.tensor( - block_tables_ragged, device=device, dtype=torch.int32 - ) - cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_ragged = torch.tensor(block_tables_ragged, dtype=torch.int32) + cu_blocks = torch.tensor(cu_blocks, 
dtype=torch.int64) block_tables_tensor = torch.empty( (len(block_tables), max_blocks), - device=device, dtype=torch.int32, ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - prompt_lengths_tensor = torch.tensor( - prompt_lengths, dtype=torch.int32, device=device - ) + prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32) - slots = torch.tensor(slots, dtype=torch.int64, device=device) + slots = torch.tensor(slots, dtype=torch.int64) cu_slots = torch.tensor(cu_slots, dtype=torch.int64) return cls( @@ -488,6 +487,7 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta=None, next_token_logits=None, speculative_logits=None, + valid_indices=None, ) @classmethod @@ -519,9 +519,7 @@ class FlashCausalLMBatch(Batch): indices = [] # slots to keep after filtering - slot_filtering_indices = torch.zeros( - self.slots.shape[0], dtype=torch.bool, device=device - ) + slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) @@ -544,7 +542,6 @@ class FlashCausalLMBatch(Batch): prefill_logprob_tokens = [] stopping_criterias = [] - top_n_tokens = [] adapter_set = set() num_blocks = 0 @@ -582,7 +579,6 @@ class FlashCausalLMBatch(Batch): stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() @@ -614,19 +610,7 @@ class FlashCausalLMBatch(Batch): max_blocks = max(max_blocks, len(request_block_table)) max_slots = max(max_slots, slot_length) - all_input_ids_tensor = self.all_input_ids_tensor[indices] - next_token_logits = self.next_token_logits[indices] - speculative_logits = ( - self.speculative_logits[indices] - if self.speculative_logits is not None - else None - ) block_tables_tensor = self.block_tables_tensor[indices] - next_token_chooser = self.next_token_chooser.filter(indices) - top_n_tokens_tensor = self.top_n_tokens_tensor[indices] - speculative_ids = ( - self.speculative_ids[indices] if self.speculative_ids is not None else None - ) prompt_lengths_tensor = self.prompt_lengths_tensor[indices] cu_slots = torch.tensor(cu_slots, dtype=torch.int64) @@ -652,16 +636,14 @@ class FlashCausalLMBatch(Batch): slot_indices = slot_indices.to(device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) - + htorch.core.mark_step() return type(self)( batch_id=self.batch_id, requests=requests, @@ -692,18 +674,19 @@ class FlashCausalLMBatch(Batch): prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_chooser=next_token_chooser, + all_input_ids_tensor=self.all_input_ids_tensor, + next_token_chooser=self.next_token_chooser, stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, + top_n_tokens=self.top_n_tokens, + top_n_tokens_tensor=self.top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, 
- speculative_ids=speculative_ids, + speculative_ids=self.speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, - next_token_logits=next_token_logits, - speculative_logits=speculative_logits, + valid_indices=indices, + next_token_logits=self.next_token_logits, + speculative_logits=self.speculative_logits, ) @classmethod @@ -820,6 +803,7 @@ class FlashCausalLMBatch(Batch): for i, batch in enumerate(batches): requests.extend(batch.requests) + valid_bsize = len(batch) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -829,16 +813,15 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping[k] = v + cumulative_batch_size start_index = cumulative_batch_size - end_index = cumulative_batch_size + len(batch) + end_index = cumulative_batch_size + valid_bsize - # Copy tensors (HPU) index = torch.tensor( list(range(start_index, end_index)), device=batch.input_ids.device ) top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:, :max_length] + ] = batch.all_input_ids_tensor[:valid_bsize, :max_length] block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] @@ -847,19 +830,28 @@ class FlashCausalLMBatch(Batch): slots_start_index = cumulative_slots slots_end_index = cumulative_slots + len(batch.slots) - slots[slots_start_index:slots_end_index] = batch.slots + slot_index = torch.tensor( + list(range(slots_start_index, slots_end_index)), + device=batch.slots.device, + ) + + slots.index_copy_(0, slot_index, batch.slots) cu_slots[start_index + 1 : end_index + 1] = ( batch.cu_slots[1:] + cumulative_slots ) if not prefilling: - input_ids.index_copy_(0, index, batch.input_ids) - position_ids.index_copy_(0, index, batch.position_ids) + input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize]) + position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) slot_indices.index_copy_( 0, index, batch.slot_indices + cumulative_slots ) - input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor) - cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor) + input_lengths_tensor.index_copy_( + 0, index, batch.input_lengths_tensor[:valid_bsize] + ) + cache_lengths_tensor.index_copy_( + 0, index, batch.cache_lengths_tensor[:valid_bsize] + ) adapter_start_index = cumulative_adapter_indices_size adapter_end_index = ( cumulative_adapter_indices_size @@ -967,6 +959,7 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta=None, next_token_logits=None, speculative_logits=None, + valid_indices=None, ) def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): @@ -982,27 +975,53 @@ class FlashCausalLMBatch(Batch): padded_bs = self.input_ids.shape[0] slots = self.slots[self.slot_indices] extra_pad = padded_bs - self.input_ids.shape[0] - if extra_pad != 0: - slots = F.pad(slots, (0, extra_pad), value=0) - block_tables.extend([[0]] * extra_pad) self.hpu_attn_meta = prepare_for_decode( dtype, use_contiguous_pa, - self.block_tables_tensor.device, - slots.cpu(), + "hpu", + slots, block_tables, padded_bs, bucketing_ctx, ) + self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0) + self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1) + self.input_lengths_tensor = F.pad( + self.input_lengths_tensor, (0, extra_pad), value=0 + ) + self.cache_lengths_tensor = F.pad( + self.cache_lengths_tensor, (0, extra_pad), value=0 + ) + self.all_input_ids_tensor = F.pad( + self.all_input_ids_tensor, + (0, 0, 0, 
extra_pad), + value=0, + ) + next_token_chooser_parameters = [] + next_token_chooser_parameters.extend([r.parameters for r in self.requests]) + pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs) + # update past grammar states + fsm_grammar_states = [0] * padded_bs - def prepare_for_prefill(self, max_padded_input_len): + for i, req in enumerate(self.requests): + fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i] + + self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, + self.next_token_chooser.dtype, + self.next_token_chooser.device, + self.next_token_chooser.tokenizer, + fsm_grammar_states, + ) + + def prepare_for_prefill(self, max_padded_input_len, max_padded_bs): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything assert self.speculative_ids is None - device = self.block_tables_tensor.device + # device = self.block_tables_tensor.device # hpu does not support varlen for prefill, use sdpa instead. so need to pad input_tensor, position # padding to left to work with sliding window @@ -1011,6 +1030,7 @@ class FlashCausalLMBatch(Batch): input_ids_padded_length = [] # need extra pad to match warmup seq extra_pad = max_padded_input_len - self.max_input_length + extra_pad_bs = max_padded_bs - len(self) if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1021,24 +1041,32 @@ class FlashCausalLMBatch(Batch): input_ids.append(input_id) input_ids_padded_length.append(padded) input_ids = np.concatenate(input_ids, dtype=np.int64) - self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64) elif isinstance(self.input_ids, list): input_ids = self.input_ids[0] input_ids_padded_length.append(extra_pad) input_ids = [0] * extra_pad + input_ids - self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64) else: self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) - input_ids_padded_length.append(extra_pad) + input_ids_padded_length.extend([extra_pad] * len(self)) - self.input_lengths_tensor = torch.tensor( - self.input_lengths, dtype=torch.int32, device=device + self.input_ids = F.pad( + self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0 ) - cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1) + + self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32) + + self.input_lengths_tensor = F.pad( + self.input_lengths_tensor, (0, extra_pad_bs), value=0 + ) + + cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1) torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0) self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32) - self.cache_lengths_tensor = torch.tensor( - self.cache_lengths, dtype=torch.int32, device=device + self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32) + self.cache_lengths_tensor = F.pad( + self.cache_lengths_tensor, (0, extra_pad_bs), value=0 ) sliding_window = get_sliding_windows() @@ -1171,7 +1199,7 @@ class FlashCausalLMBatch(Batch): torch.arange( cumulative_length, cumulative_length + input_length, - dtype=torch.int64, + dtype=torch.int32, ) ) prefill_next_token_indices.append( @@ -1182,7 +1210,7 @@ class FlashCausalLMBatch(Batch): prefill_head_indices.append( torch.tensor( 
[cumulative_length + input_length - 1], - dtype=torch.int64, + dtype=torch.int32, ) ) prefill_next_token_indices.append(prefill_out_cumulative_length) @@ -1204,12 +1232,15 @@ class FlashCausalLMBatch(Batch): slot_indices = slot_indices[0] prefill_cache_indices = prefill_cache_indices[0] - self.position_ids = position_ids.to(device) - self.slot_indices = slot_indices.to(device) + self.position_ids = position_ids + self.position_ids = F.pad( + self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1 + ) + self.slot_indices = slot_indices self.prefill_cu_outlens = prefill_cu_outlens self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) - self.prefill_cache_indices[prefill_cache_indices.to(device)] = True + self.prefill_cache_indices[prefill_cache_indices] = True if all_prefill_logprobs: prefill_head_indices = None @@ -1218,16 +1249,19 @@ class FlashCausalLMBatch(Batch): prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: - prefill_head_indices = torch.cat(prefill_head_indices).to(device) + prefill_head_indices = torch.cat(prefill_head_indices) prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device + prefill_next_token_indices, dtype=torch.int64 ) self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices input_ids_padded_length_tensor = torch.cumsum( - torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), + torch.tensor(input_ids_padded_length, dtype=torch.int32), dim=-1, + ).to(torch.int32) + input_ids_padded_length_tensor = F.pad( + input_ids_padded_length_tensor, (0, extra_pad_bs), value=0 ) if self.prefill_head_indices is not None: self.prefill_head_indices = ( @@ -1239,19 +1273,37 @@ class FlashCausalLMBatch(Batch): self.prefill_next_token_indices + input_ids_padded_length_tensor ) + self.all_input_ids_tensor = F.pad( + self.all_input_ids_tensor, + (0, 0, 0, extra_pad_bs), + value=0, + ) + next_token_chooser_parameters = [] + next_token_chooser_parameters.extend([r.parameters for r in self.requests]) + pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs) + # update past grammar states + fsm_grammar_states = [0] * max_padded_bs + + for i, req in enumerate(self.requests): + fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i] + + self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, + self.next_token_chooser.dtype, + self.next_token_chooser.device, + self.next_token_chooser.tokenizer, + fsm_grammar_states, + ) + if adapter_set: - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) + adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) else: adapter_indices = torch.zeros_like(self.input_ids) adapter_segments = [0, len(adapter_indices)] adapter_segment_indices = [len(adapter_indices) - 1] - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, @@ -1392,6 +1444,9 @@ class FlashCausalLM(Model): self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" ) + self.limit_hpu_graphs = ( + os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true" 
+ ) super().__init__( model_id=model_id, model=model, @@ -1509,8 +1564,17 @@ class FlashCausalLM(Model): self.device, ) + max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128)) + if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None: + os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens) + if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None: + max_total_blocks = ( + math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1 + ) + os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks) + self.bucketing_ctx = HPUBucketingContext( - os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO + max_num_seqs, os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO BLOCK_SIZE, num_blocks * BLOCK_SIZE, @@ -1536,6 +1600,7 @@ class FlashCausalLM(Model): log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") for index in range(warmup_times): self.warmup_prefill(seq_len, batch_size, batch) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) @@ -1552,62 +1617,51 @@ class FlashCausalLM(Model): def warmup_prefill( self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch ): - input_ids = torch.zeros( - prompt_len, dtype=batch.input_ids.dtype, device=self.device - ).repeat(batch_size) - position_ids = torch.arange( - prompt_len, dtype=batch.position_ids.dtype, device=self.device - ).repeat(batch_size) + input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat( + batch_size + ) + position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat( + batch_size + ) max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).reshape(batch_size, -1) + block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1) slot_acc = [] for i in range(batch_size): slots = [] for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slot_acc.extend(slots[:prompt_len]) - slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) + slots = torch.tensor(slot_acc, dtype=batch.slots.dtype) - input_lengths = ( - torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len - ) - cache_lengths_tensor = torch.zeros( - batch_size, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.zeros( - batch_size + 1, device=self.device, dtype=torch.int32 - ) + input_lengths = torch.ones(batch_size, dtype=torch.int32) * prompt_len + cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, + input_lengths=_async_h2d_tensor_copy(input_lengths), ) lm_head_indices = input_lengths - 1 + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
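The warmup path above now seeds the vLLM bucketing limits from the serving limits whenever the corresponding environment variables are unset. A small worked example of the decode-block bound, with illustrative numbers (the actual `BLOCK_SIZE` and token limits come from the deployment):

```python
import math

BLOCK_SIZE = 128          # illustrative; provided by the Gaudi backend
max_total_tokens = 4096   # illustrative serving limit
max_num_seqs = 128        # MAX_BATCH_SIZE default used above

# VLLM_DECODE_BLOCK_BUCKET_MAX default, matching the formula in warmup():
max_total_blocks = math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
# ceil(4096 / 128) * 128 + 1 = 32 * 128 + 1 = 4097
```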
self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), + cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=self.kv_cache, - slots=slots, + slots=_async_h2d_tensor_copy(slots), seqlen=trim_seqlen_metadata(seqlen), - lm_head_indices=lm_head_indices, + lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), adapter_data=None, hpu_attention_meta=None, + **kwargs, ) def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch): - input_ids = torch.zeros( - batch_size, dtype=batch.input_ids.dtype, device=self.device - ) - position_ids = torch.arange( - batch_size, dtype=batch.position_ids.dtype, device=self.device - ) + input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype) + position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size past_len = [] @@ -1622,19 +1676,12 @@ class FlashCausalLM(Model): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) - cache_lengths_tensor = torch.tensor( - past_len, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.zeros( - batch_size + 1, device=self.device, dtype=torch.int32 - ) + input_lengths = torch.ones(batch_size, dtype=torch.int32) + cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, + input_lengths=_async_h2d_tensor_copy(input_lengths), ) hpu_attention_meta = prepare_for_decode( @@ -1646,18 +1693,22 @@ class FlashCausalLM(Model): batch_size, bucketing_ctx=None, ) - slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
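The `bypass_hpu_graphs` kwarg threaded through the warmup and forward calls implements a simple policy around the new `LIMIT_HPU_GRAPHS` flag: when the flag is set, only decode steps are captured as HPU graphs and prefill runs outside graph capture. A sketch of that policy under lazy mode; the `forward_kwargs` helper name is illustrative:

```python
import os
import habana_frameworks.torch as htorch

limit_hpu_graphs = os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"

def forward_kwargs(prefilling: bool) -> dict:
    # Only meaningful under lazy mode; other execution modes ignore the flag.
    if not htorch.utils.internal.is_lazy():
        return {}
    # LIMIT_HPU_GRAPHS=true: prefill bypasses graph capture, decode is captured.
    # LIMIT_HPU_GRAPHS=false: both phases go through HPU graphs.
    return {"bypass_hpu_graphs": prefilling if limit_hpu_graphs else False}
```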
self.model.forward( - input_ids=input_ids, - position_ids=position_ids, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots_tensor, + slots=_async_h2d_tensor_copy(slots_tensor), seqlen=trim_seqlen_metadata(seqlen), lm_head_indices=None, adapter_data=None, hpu_attention_meta=hpu_attention_meta, + **kwargs, ) def forward( @@ -1699,9 +1750,6 @@ class FlashCausalLM(Model): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - cache_lengths_tensor = ( - batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) - ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -1722,7 +1770,6 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -1735,80 +1782,34 @@ class FlashCausalLM(Model): slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad - if self.bucketing_ctx is not None: - if batch.prefilling: - padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( - input_lengths.shape[0] - ) - else: - padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( - input_lengths.shape[0] - ) else: - padded_bs = input_lengths.shape[0] - orig_bs = input_lengths.shape[0] - if padded_bs != input_lengths.shape[0]: - padded_input_lengths = F.pad( - input_lengths, - (0, padded_bs - orig_bs), - value=0, - ) - padded_cache_lengths_tensor = F.pad( - cache_lengths_tensor, - (0, padded_bs - orig_bs), - value=0, - ) - if cu_seqlen_prefill is not None: - cu_seqlen_prefill = torch.zeros( - padded_bs + 1, device=self.device, dtype=torch.int32 - ) - torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) - seqlen = Seqlen( - input_lengths=padded_input_lengths, - cache_lengths=padded_cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - input_seq = input_ids.view(orig_bs, -1) - input_ids = F.pad( - input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 - ) - position_ids = F.pad( - position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 - ) - slots = F.pad( - slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 - ) - if lm_head_indices is not None: - lm_head_indices = F.pad( - lm_head_indices, (0, padded_bs - orig_bs), value=0 - ) - else: - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) + slots_pad = torch.zeros_like(input_ids) + slots_pad[: slots.shape[0]] = slots + slots = slots_pad + seqlen = Seqlen( + input_lengths=_async_h2d_tensor_copy(input_lengths), + ) kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = batch.prefilling + kwargs["bypass_hpu_graphs"] = ( + batch.prefilling if self.limit_hpu_graphs else False + ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), + cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, - slots=slots, + slots=_async_h2d_tensor_copy(slots), seqlen=trim_seqlen_metadata(seqlen), - lm_head_indices=lm_head_indices, + 
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), # TODO not support adapter now, need the add in the future adapter_data=None, hpu_attention_meta=batch.hpu_attn_meta, **kwargs, ) - return logits[:orig_bs], ( - speculative_logits[:orig_bs] if speculative_logits is not None else None - ) + return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( @@ -1817,6 +1818,7 @@ class FlashCausalLM(Model): # In order to pipeline any actions on CPU we perform the operation in 3 main stages: # Stage 1. Collect next token ids of any previously started generations + start = time.time_ns() prev_batches = [] requests_to_generate = [] for batch_id, batch in enumerate(batches): @@ -1834,7 +1836,9 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], + _async_h2d_tensor_copy( + batch.all_input_ids_tensor[:, : batch.max_current_length] + ), batch.next_token_logits, speculate, batch.speculative_ids, @@ -1843,10 +1847,39 @@ class FlashCausalLM(Model): batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch.top_n_tokens, - batch.top_n_tokens_tensor, + _async_h2d_tensor_copy(batch.top_n_tokens_tensor), logprobs, accepted_ids, ) + if batch.valid_indices is not None: + next_input_ids = next_input_ids.cpu() + next_token_logprobs = next_token_logprobs.cpu() + accepted_ids = accepted_ids.cpu() + batch.all_input_ids_tensor = batch.all_input_ids_tensor[ + batch.valid_indices + ] + next_input_ids = next_input_ids[batch.valid_indices] + next_token_logprobs = next_token_logprobs[batch.valid_indices] + accepted_ids = accepted_ids[batch.valid_indices] + if speculative_ids is not None: + speculative_ids = speculative_ids[batch.valid_indices] + batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[ + batch.valid_indices + ] + top_n_tokens = [] + batch_top_token_ids_v = [] + batch_top_token_logprobs_v = [] + for i in batch.valid_indices: + top_n_tokens.append(batch.top_n_tokens[i]) + batch_top_token_ids_v.append(batch_top_token_ids[i]) + batch_top_token_logprobs_v.append(batch_top_token_logprobs[i]) + batch_top_token_ids = batch_top_token_ids_v + batch_top_token_logprobs = batch_top_token_logprobs_v + batch.top_n_tokens = top_n_tokens + batch.next_token_chooser = batch.next_token_chooser.filter( + batch.valid_indices + ) + batch.valid_indices = None # Since we are done prefilling, all the tensors that were concatenating values for all the requests # instantly become of shape [BATCH_SIZE] @@ -1860,14 +1893,16 @@ class FlashCausalLM(Model): else: batch.position_ids = batch.position_ids[indices] - batch.slot_indices = batch.slot_indices[indices] + batch.slot_indices = batch.slot_indices[indices[: len(batch)]] batch.adapter_meta.adapter_indices = ( batch.adapter_meta.adapter_indices[indices] ) # For each member of the batch # Cumulative length + accepted_ids = accepted_ids.cpu() cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + next_input_ids = next_input_ids.cpu() if batch.speculative_logits is not None: for i in range(len(batch)): batch.all_input_ids_tensor[ @@ -1879,16 +1914,16 @@ class FlashCausalLM(Model): ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] else: index = batch.cache_lengths_tensor + batch.input_lengths_tensor + index = index.to(batch.all_input_ids_tensor) batch_idx = torch.arange( 0, batch.all_input_ids_tensor.shape[0], dtype=torch.long, - 
device=batch.input_lengths_tensor.device, + device=batch.all_input_ids_tensor.device, ) batch.all_input_ids_tensor.index_put_( (batch_idx, index.long()), next_input_ids ) - batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.speculative_ids = speculative_ids if batch.position_ids.dim() == 2: @@ -1900,7 +1935,7 @@ class FlashCausalLM(Model): batch.input_lengths_tensor + accepted_ids - 1 ) batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) - batch.slot_indices += accepted_ids + batch.slot_indices += accepted_ids[: len(batch)] # Does a HPU <-> CPU sync internally if prefill: @@ -1921,8 +1956,6 @@ class FlashCausalLM(Model): } ) idx = len(prev_batches) - 1 - if batch.speculative_logits is not None: - accepted_ids_cpu = accepted_ids.cpu() for req_idx, req in enumerate(batch.requests): new_input_length = 1 @@ -1930,7 +1963,7 @@ class FlashCausalLM(Model): new_cache_length = ( batch.cache_lengths[req_idx] + batch.input_lengths[req_idx] - + accepted_ids_cpu[req_idx] + + accepted_ids[req_idx] - 1 ) else: @@ -1978,15 +2011,17 @@ class FlashCausalLM(Model): batch = self.batch_type.concatenate(batches) else: batch = batches[0] - start = time.time_ns() prefill = batch.prefilling if prefill: if self.bucketing_ctx is not None: batch.prepare_for_prefill( - self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length) + self.bucketing_ctx.get_padded_prompt_seq_len( + batch.max_input_length + ), + self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)), ) else: - batch.prepare_for_prefill(batch.max_input_length) + batch.prepare_for_prefill(batch.max_input_length, len(batch)) else: batch.prepare_for_decode( self.dtype, self.use_contiguous_pa, self.bucketing_ctx @@ -2037,14 +2072,15 @@ class FlashCausalLM(Model): batch.speculative_logits = speculative_logits # HPU->CPU sync + htorch.core.mark_step() + start_decode = time.time_ns() for prev_batch in prev_batches: prev_batch["next_token_logprobs"] = prev_batch[ "next_token_logprobs" ].tolist() prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist() prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist() - - start_decode = time.time_ns() + htorch.core.mark_step() # Stage 3. 
Finish and return previous generations # Results generations: List[Generation] = [] @@ -2186,7 +2222,7 @@ class FlashCausalLM(Model): batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids - + htorch.core.mark_step() if stopped: # No need to return a batch if we know that all requests stopped forward_ns = start_decode - start diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index c885816b..1776b219 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -17,12 +17,15 @@ from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor -from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +from text_generation_server.layers.attention import ( + Seqlen, + trim_seqlen_metadata, + _async_h2d_tensor_copy, +) import habana_frameworks.torch as htorch from text_generation_server.utils.import_utils import ( synchronize, ) -import torch.nn.functional as F tracer = trace.get_tracer(__name__) @@ -383,12 +386,8 @@ class FlashVlmCausalLM(FlashCausalLM): def warmup_decode( self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch ): - input_ids = torch.zeros( - batch_size, dtype=batch.input_ids.dtype, device=self.device - ) - position_ids = torch.arange( - batch_size, dtype=batch.position_ids.dtype, device=self.device - ) + input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype) + position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype) if batch.position_ids is not None and batch.position_ids.dim() == 2: # qwen2_vl and qwen2_5_vl case position_ids = position_ids.unsqueeze(-1).repeat( @@ -408,19 +407,10 @@ class FlashVlmCausalLM(FlashCausalLM): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) - cache_lengths_tensor = torch.tensor( - past_len, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.zeros( - batch_size + 1, device=self.device, dtype=torch.int32 - ) - torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + input_lengths = torch.ones(batch_size, dtype=torch.int32) seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, + input_lengths=_async_h2d_tensor_copy(input_lengths), ) hpu_attention_meta = prepare_for_decode( @@ -432,14 +422,14 @@ class FlashVlmCausalLM(FlashCausalLM): batch_size, bucketing_ctx=None, ) - slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
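Two `htorch.core.mark_step()` calls are added around the HPU->CPU sync in `generate_token` (previous hunks): the current step's forward is flushed to the device before the previous step's token ids are pulled back to the host. A sketch of the resulting ordering; `run_forward` and the batch dictionaries are stand-ins, not the real signatures:

```python
import time
import habana_frameworks.torch as htorch

def overlap_host_decode(run_forward, prev_batches):
    """Sketch of the ordering generate_token now uses (names are illustrative)."""
    # Stage 2: enqueue this step's forward; nothing is read back on the host yet.
    logits, speculative_logits = run_forward()

    # Flush the lazily accumulated graph so the device starts executing ...
    htorch.core.mark_step()
    start_decode = time.time_ns()

    # ... and only then pay the HPU->CPU sync for the *previous* step's results,
    # so host-side decoding overlaps with device execution of the current step.
    for prev_batch in prev_batches:
        prev_batch["next_token_logprobs"] = prev_batch["next_token_logprobs"].tolist()
        prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
        prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
    htorch.core.mark_step()
    return logits, speculative_logits, start_decode
```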
self.model.forward( - input_ids=input_ids, - position_ids=position_ids, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots_tensor, + slots=_async_h2d_tensor_copy(slots_tensor), seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, @@ -498,9 +488,6 @@ class FlashVlmCausalLM(FlashCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - cache_lengths_tensor = ( - batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) - ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -521,7 +508,6 @@ class FlashVlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -546,78 +532,23 @@ class FlashVlmCausalLM(FlashCausalLM): slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad - if self.bucketing_ctx is not None: - if batch.prefilling: - padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( - input_lengths.shape[0] - ) - else: - padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( - input_lengths.shape[0] - ) else: - padded_bs = input_lengths.shape[0] - orig_bs = input_lengths.shape[0] - if padded_bs != input_lengths.shape[0]: - padded_input_lengths = F.pad( - input_lengths, - (0, padded_bs - orig_bs), - value=0, - ) - padded_cache_lengths_tensor = F.pad( - cache_lengths_tensor, - (0, padded_bs - orig_bs), - value=0, - ) - if cu_seqlen_prefill is not None: - cu_seqlen_prefill = torch.zeros( - padded_bs + 1, device=self.device, dtype=torch.int32 - ) - torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) - seqlen = Seqlen( - input_lengths=padded_input_lengths, - cache_lengths=padded_cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - input_seq = input_ids.view(orig_bs, -1) - input_ids = F.pad( - input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 - ) - if position_ids.dim() == 2: - # qwen2_vl and qwen2_5_vl case - position_ids = F.pad( - position_ids, - (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]), - value=1, - ) - else: - position_ids = F.pad( - position_ids, - (0, (padded_bs - orig_bs) * input_seq.shape[-1]), - value=1, - ) - slots = F.pad( - slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 - ) - if lm_head_indices is not None: - lm_head_indices = F.pad( - lm_head_indices, (0, padded_bs - orig_bs), value=0 - ) - else: - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) + slots_pad = torch.zeros_like(input_ids) + slots_pad[: slots.shape[0]] = slots + slots = slots_pad + + seqlen = Seqlen( + input_lengths=_async_h2d_tensor_copy(input_lengths), + ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), + cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, - slots=slots, + slots=_async_h2d_tensor_copy(slots), seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, - lm_head_indices=lm_head_indices, + 
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, image_sizes=batch.image_sizes, @@ -632,6 +563,4 @@ class FlashVlmCausalLM(FlashCausalLM): batch.image_sizes = None if batch.image_grid_thw is not None: batch.image_grid_thw = None - return logits[:orig_bs], ( - speculative_logits[:orig_bs] if speculative_logits is not None else None - ) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index c1ea36f2..5de9bca8 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -19,7 +19,11 @@ from text_generation_server.models.flash_vlm_causal_lm import ( FlashVlmCausalLM, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +from text_generation_server.layers.attention import ( + Seqlen, + trim_seqlen_metadata, + _async_h2d_tensor_copy, +) import habana_frameworks.torch as htorch from loguru import logger from text_generation_server.models.globals import BLOCK_SIZE @@ -183,7 +187,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): input_ids = np.concatenate(batch.input_ids, dtype=np.int64) else: input_ids = batch.input_ids[0] - batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) @@ -206,33 +210,26 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): def generate_cross_attention_states( - cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling + cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling ): if cross_attention_states is None: return None, None, None - device = cross_attention_states.device indices_list = [] if prefilling: for i in image_indices: - indices_list.append( - torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device) - ) + indices_list.append(torch.arange(pad_seq_len * i, pad_seq_len * (i + 1))) indices = torch.cat(indices_list, dim=0) else: indices = image_indices[:] - return indices, seqlen.input_lengths.index_select(0, image_indices) + return indices, input_lengths.index_select(0, image_indices) class FlashMllamaCausalLM(FlashVlmCausalLM): def warmup_decode( self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch ): - input_ids = torch.zeros( - batch_size, dtype=batch.input_ids.dtype, device=self.device - ) - position_ids = torch.arange( - batch_size, dtype=batch.position_ids.dtype, device=self.device - ) + input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype) + position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size past_len = [] @@ -247,19 +244,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) - cache_lengths_tensor = torch.tensor( - past_len, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.zeros( - batch_size + 1, device=self.device, dtype=torch.int32 - ) - torch.cumsum(input_lengths, -1, 
out=cu_seqlen_prefill[1:]) + input_lengths = torch.ones(batch_size, dtype=torch.int32) seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, + input_lengths=_async_h2d_tensor_copy(input_lengths), ) hpu_attention_meta = prepare_for_decode( @@ -272,87 +260,86 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): bucketing_ctx=None, ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. - image_indices = torch.tensor(batch.image_indices, device=self.device) + image_indices = torch.tensor(batch.image_indices) image_indices = image_indices.repeat(batch_size) cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1) indices, cross_attention_len = generate_cross_attention_states( - cross_attention_states, image_indices, seqlen, 1, False + cross_attention_states, image_indices, input_lengths, 1, False ) - slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) self.model.forward( - input_ids=input_ids, - position_ids=position_ids, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots_tensor, + slots=_async_h2d_tensor_copy(slots_tensor), seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, adapter_data=None, cross_attention_states=cross_attention_states, - indices=indices, - cross_attention_len=cross_attention_len, + indices=_async_h2d_tensor_copy(indices), + cross_attention_len=_async_h2d_tensor_copy(cross_attention_len), ) def warmup_prefill( self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch ): - input_ids = torch.zeros( - prompt_len, dtype=batch.input_ids.dtype, device=self.device - ).repeat(batch_size) - position_ids = torch.arange( - prompt_len, dtype=batch.position_ids.dtype, device=self.device - ).repeat(batch_size) + input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat( + batch_size + ) + position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat( + batch_size + ) max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).reshape(batch_size, -1) + block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1) slot_acc = [] for i in range(batch_size): slots = [] for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slot_acc.extend(slots[:prompt_len]) - slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) + slots = torch.tensor(slot_acc, dtype=batch.slots.dtype) input_lengths = ( - torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len - ) - cache_lengths_tensor = torch.zeros( - batch_size, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.zeros( - batch_size + 1, device=self.device, dtype=torch.int32 + torch.ones( + batch_size, + dtype=torch.int32, + ) + * prompt_len ) + cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) lm_head_indices = input_lengths - 1 # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
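`generate_cross_attention_states` now receives the raw `input_lengths` tensor instead of a `Seqlen`, so it can run entirely on the host before any copies are queued. A hedged usage sketch for the decode case; shapes and values are illustrative, and the import assumes the module-level helper shown earlier in this file:

```python
import torch
from text_generation_server.models.mllama_causal_lm import (
    generate_cross_attention_states,
)

# Illustrative decode-time call: batch of 4, one image index repeated per request.
input_lengths = torch.ones(4, dtype=torch.int32)       # host tensor
image_indices = torch.tensor([0]).repeat(4)
cross_attention_states = torch.zeros(4, 1, 4096)       # illustrative shape

indices, cross_attention_len = generate_cross_attention_states(
    cross_attention_states,
    image_indices,
    input_lengths,   # plain lengths now, no Seqlen / device tensor required
    1,               # pad_seq_len
    False,           # prefilling
)
# Both results stay on the host and are moved with _async_h2d_tensor_copy(...)
# right before model.forward, as in the hunks that follow.
```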
- image_indices = torch.tensor(batch.image_indices, device=self.device) + image_indices = torch.tensor(batch.image_indices) image_indices = image_indices.repeat(batch_size) cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1) indices, cross_attention_len = generate_cross_attention_states( - cross_attention_states, image_indices, seqlen, prompt_len, True + cross_attention_states, image_indices, input_lengths, prompt_len, True ) + seqlen = Seqlen( + input_lengths=_async_h2d_tensor_copy(input_lengths), + ) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), + cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=self.kv_cache, - slots=slots, + slots=_async_h2d_tensor_copy(slots), seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=None, - lm_head_indices=lm_head_indices, + lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), adapter_data=None, cross_attention_states=cross_attention_states, - indices=indices, - cross_attention_len=cross_attention_len, + indices=_async_h2d_tensor_copy(indices), + cross_attention_len=_async_h2d_tensor_copy(cross_attention_len), + **kwargs, ) def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): @@ -410,9 +397,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - cache_lengths_tensor = ( - batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) - ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -433,7 +417,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -455,100 +438,58 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = batch.prefilling + kwargs["bypass_hpu_graphs"] = ( + batch.prefilling if self.limit_hpu_graphs else False + ) if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad - - if self.bucketing_ctx is not None: - if batch.prefilling: - padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( - input_lengths.shape[0] - ) - else: - padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( - input_lengths.shape[0] - ) else: - padded_bs = input_lengths.shape[0] - orig_bs = input_lengths.shape[0] - padded_input_len = input_ids.view(orig_bs, -1).shape[-1] - image_indices = torch.tensor(batch.image_indices, device=self.device) - if padded_bs != input_lengths.shape[0]: - padded_input_lengths = F.pad( - input_lengths, - (0, padded_bs - orig_bs), + slots_pad = torch.zeros_like(input_ids) + slots_pad[: slots.shape[0]] = slots + slots = slots_pad + orig_bs = len(batch) + padded_bs = batch.input_lengths_tensor.shape[0] + padded_input_len = input_ids.view(padded_bs, -1).shape[-1] + image_indices = torch.tensor(batch.image_indices) + + if cross_attention_states is not None: + cross_attention_states = F.pad( + cross_attention_states, + (0, 0, 0, 0, 0, (padded_bs - orig_bs)), value=0, ) - 
padded_cache_lengths_tensor = F.pad( - cache_lengths_tensor, - (0, padded_bs - orig_bs), - value=0, - ) - if cu_seqlen_prefill is not None: - cu_seqlen_prefill = torch.zeros( - padded_bs + 1, device=self.device, dtype=torch.int32 - ) - torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) - seqlen = Seqlen( - input_lengths=padded_input_lengths, - cache_lengths=padded_cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - - input_ids = F.pad( - input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0 - ) - position_ids = F.pad( - position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1 - ) - slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0) - if lm_head_indices is not None: - lm_head_indices = F.pad( - lm_head_indices, (0, padded_bs - orig_bs), value=0 - ) - if cross_attention_states is not None: - cross_attention_states = F.pad( - cross_attention_states, - (0, 0, 0, 0, 0, (padded_bs - orig_bs)), - value=0, - ) - if len(image_indices) != 0: - pad_indices = torch.arange(orig_bs, padded_bs, device=self.device) - image_indices = torch.cat((image_indices, pad_indices), dim=0) - else: - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) + if len(image_indices) != 0: + pad_indices = torch.arange(orig_bs, padded_bs) + image_indices = torch.cat((image_indices, pad_indices), dim=0) indices, cross_attention_len = generate_cross_attention_states( cross_attention_states, image_indices, - seqlen, + input_lengths, padded_input_len, batch.prefilling, ) + seqlen = Seqlen( + input_lengths=_async_h2d_tensor_copy(input_lengths), + ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, + input_ids=_async_h2d_tensor_copy(input_ids), + position_ids=_async_h2d_tensor_copy(position_ids), + cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, - slots=slots, + slots=_async_h2d_tensor_copy(slots), seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, - lm_head_indices=lm_head_indices, + lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), # TODO list adapter_data=None, cross_attention_states=cross_attention_states, - indices=indices, - cross_attention_len=cross_attention_len, + indices=_async_h2d_tensor_copy(indices), + cross_attention_len=_async_h2d_tensor_copy(cross_attention_len), **kwargs, ) if batch.pixel_values is not None: batch.pixel_values = None - return logits[:orig_bs], ( - speculative_logits[:orig_bs] if speculative_logits is not None else None - ) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py index 9c44ba15..9f5ffb87 100644 --- a/backends/gaudi/server/text_generation_server/utils/tokens.py +++ b/backends/gaudi/server/text_generation_server/utils/tokens.py @@ -552,8 +552,13 @@ def pad_next_token_chooser_parameters( class Sampling: def __init__(self, seed: int, device: str = "cpu"): - self.generator = torch.Generator("cpu") - self.generator.manual_seed(seed) + if device in ["hpu", torch.device("hpu")]: + import habana_frameworks.torch.hpu.random as htrandom + + self.generator = htrandom.default_generators[0].manual_seed(seed) + else: + self.generator = torch.Generator("cpu") + self.generator.manual_seed(seed) self.seed = seed def __call__(self, logits):
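Finally, `Sampling` now seeds a device generator when running on HPU. A minimal sketch of that selection logic as a standalone helper; the helper name is illustrative, while `htrandom.default_generators` is the API the diff itself relies on:

```python
import torch

def make_generator(seed: int, device: str = "cpu"):
    # Mirrors Sampling.__init__ above: reuse the HPU default generator when
    # sampling on Gaudi, otherwise fall back to a seeded CPU generator.
    if device in ["hpu", torch.device("hpu")]:
        import habana_frameworks.torch.hpu.random as htrandom

        return htrandom.default_generators[0].manual_seed(seed)
    generator = torch.Generator("cpu")
    generator.manual_seed(seed)
    return generator

gen = make_generator(42)  # CPU fallback when no HPU is present
```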