Move input_ids to hpu and remove disposal of adapter_meta (#3237)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2025-05-22 15:21:31 +08:00 committed by GitHub
parent e32528792c
commit 9e7e546923
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 107 deletions

View File

@ -90,6 +90,8 @@ class Seqlen:
def _async_h2d_tensor_copy(source, device="hpu"): def _async_h2d_tensor_copy(source, device="hpu"):
if source is None: if source is None:
return None return None
if source.device.type == "hpu":
return source
assert source.device.type == "cpu", "Source tensor is not present in host memory!" assert source.device.type == "cpu", "Source tensor is not present in host memory!"
target = torch.empty(source.shape, dtype=source.dtype, device=device) target = torch.empty(source.shape, dtype=source.dtype, device=device)
target.copy_(source, non_blocking=True) target.copy_(source, non_blocking=True)

View File

@ -634,21 +634,25 @@ class FlashCausalLMBatch(Batch):
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices] cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
if self.adapter_meta is not None:
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_indices = self.adapter_meta.adapter_indices[indices]
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) adapter_segments, adapter_segment_indices = find_segments(
adapter_meta = AdapterBatchMetadata( adapter_indices
adapter_indices=adapter_indices, )
adapter_set=adapter_set, adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
adapter_segments=adapter_segments, adapter_meta = AdapterBatchMetadata(
segment_indices=adapter_segment_indices, adapter_indices=adapter_indices,
) adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
else:
adapter_meta = None
htorch.core.mark_step() htorch.core.mark_step()
return type(self)( return type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
@ -710,6 +714,7 @@ class FlashCausalLMBatch(Batch):
max_length = 0 max_length = 0
max_input_length = 0 max_input_length = 0
max_current_length = 0 max_current_length = 0
ADAPTER_TO_INDEX = get_adapter_to_index()
for b in batches: for b in batches:
total_batch_size += len(b) total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks) max_blocks = max(max_blocks, b.max_blocks)
@ -763,14 +768,15 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size total_batch_size
) )
total_indices_size = sum( if ADAPTER_TO_INDEX:
b.adapter_meta.adapter_indices.shape[0] for b in batches total_indices_size = sum(
) b.adapter_meta.adapter_indices.shape[0] for b in batches
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( )
total_indices_size adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
) total_indices_size
adapter_segment_builder = SegmentConcatBuilder() )
adapter_set = set() adapter_segment_builder = SegmentConcatBuilder()
adapter_set = set()
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
total_batch_size total_batch_size
@ -821,9 +827,7 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + valid_bsize end_index = cumulative_batch_size + valid_bsize
index = torch.tensor( index = torch.tensor(list(range(start_index, end_index)), device="cpu")
list(range(start_index, end_index)), device=batch.input_ids.device
)
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
all_input_ids_tensor[ all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1] start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@ -847,7 +851,9 @@ class FlashCausalLMBatch(Batch):
) )
if not prefilling: if not prefilling:
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize]) input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_( slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots 0, index, batch.slot_indices + cumulative_slots
@ -858,20 +864,21 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor.index_copy_( cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize] 0, index, batch.cache_lengths_tensor[:valid_bsize]
) )
adapter_start_index = cumulative_adapter_indices_size if ADAPTER_TO_INDEX:
adapter_end_index = ( adapter_start_index = cumulative_adapter_indices_size
cumulative_adapter_indices_size adapter_end_index = (
+ batch.adapter_meta.adapter_indices.shape[0] cumulative_adapter_indices_size
) + batch.adapter_meta.adapter_indices.shape[0]
adapter_indices[adapter_start_index:adapter_end_index] = ( )
batch.adapter_meta.adapter_indices adapter_indices[adapter_start_index:adapter_end_index] = (
) batch.adapter_meta.adapter_indices
cumulative_adapter_indices_size = adapter_end_index )
adapter_set.update(batch.adapter_meta.adapter_set) cumulative_adapter_indices_size = adapter_end_index
adapter_segment_builder.concat( adapter_set.update(batch.adapter_meta.adapter_set)
batch.adapter_meta.adapter_segments, adapter_segment_builder.concat(
batch.adapter_meta.segment_indices, batch.adapter_meta.adapter_segments,
) batch.adapter_meta.segment_indices,
)
else: else:
if isinstance(batch.input_ids, torch.Tensor): if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist() batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@ -914,7 +921,7 @@ class FlashCausalLMBatch(Batch):
else: else:
speculative_ids = None speculative_ids = None
if adapter_segment_builder is not None: if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata( adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
@ -961,7 +968,7 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks, num_blocks=num_blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
speculative_ids=speculative_ids, speculative_ids=speculative_ids,
adapter_meta=adapter_meta, adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
hpu_attn_meta=None, hpu_attn_meta=None,
next_token_logits=None, next_token_logits=None,
speculative_logits=None, speculative_logits=None,
@ -1037,6 +1044,7 @@ class FlashCausalLMBatch(Batch):
# need extra pad to match warmup seq # need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self) extra_pad_bs = max_padded_bs - len(self)
device = self.all_input_ids_tensor.device
if isinstance(self.input_ids, list) and len(self) > 1: if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = [] input_ids_padded_length = []
input_ids = [] input_ids = []
@ -1047,12 +1055,12 @@ class FlashCausalLMBatch(Batch):
input_ids.append(input_id) input_ids.append(input_id)
input_ids_padded_length.append(padded) input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list): elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0] input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad) input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else: else:
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
input_ids_padded_length.extend([extra_pad] * len(self)) input_ids_padded_length.extend([extra_pad] * len(self))
@ -1245,7 +1253,9 @@ class FlashCausalLMBatch(Batch):
self.slot_indices = slot_indices self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) self.prefill_cache_indices = torch.zeros_like(
self.input_ids, dtype=torch.bool, device="cpu"
)
self.prefill_cache_indices[prefill_cache_indices] = True self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs: if all_prefill_logprobs:
@ -1301,21 +1311,24 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states, fsm_grammar_states,
) )
if adapter_set: if ADAPTER_TO_INDEX:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) if adapter_set:
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
else: adapter_segments, adapter_segment_indices = find_segments(
adapter_indices = torch.zeros_like(self.input_ids) adapter_indices
adapter_segments = [0, len(adapter_indices)] )
adapter_segment_indices = [len(adapter_indices) - 1] else:
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
adapter_segment_indices = [len(adapter_indices) - 1]
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
self.adapter_meta = AdapterBatchMetadata( self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
adapter_set=adapter_set, adapter_set=adapter_set,
adapter_segments=adapter_segments, adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices, segment_indices=adapter_segment_indices,
) )
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
@ -1941,11 +1954,11 @@ class FlashCausalLM(Model):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
seqlen = Seqlen( seqlen = Seqlen(
@ -1965,7 +1978,7 @@ class FlashCausalLM(Model):
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,
@ -2059,15 +2072,16 @@ class FlashCausalLM(Model):
batch.position_ids = batch.position_ids[indices] batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices[: len(batch)]] batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
batch.adapter_meta.adapter_indices = ( if batch.adapter_meta is not None:
batch.adapter_meta.adapter_indices[indices] batch.adapter_meta.adapter_indices = (
) batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch # For each member of the batch
# Cumulative length # Cumulative length
accepted_ids = accepted_ids.cpu()
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
if batch.speculative_logits is not None: if batch.speculative_logits is not None:
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
for i in range(len(batch)): for i in range(len(batch)):
batch.all_input_ids_tensor[ batch.all_input_ids_tensor[
i, i,
@ -2076,6 +2090,20 @@ class FlashCausalLM(Model):
+ batch.input_lengths[i] + batch.input_lengths[i]
+ accepted_ids[i], + accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
accepted_ids = accepted_ids.cpu()
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += accepted_ids[: len(batch)]
else: else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = index.to(batch.all_input_ids_tensor.device) index = index.to(batch.all_input_ids_tensor.device)
@ -2088,22 +2116,18 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor.index_put_( batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids (batch_idx, index.long()), next_input_ids
) )
next_input_ids = next_input_ids.cpu() batch.input_ids = next_input_ids
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.position_ids += 1
batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += 1
batch.speculative_ids = speculative_ids batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids[: len(batch)]
# Does a HPU <-> CPU sync internally # Does a HPU <-> CPU sync internally
if prefill: if prefill and batch.adapter_meta is not None:
# adjust segment lengths to account for all request lengths being 1 during decoding # adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments( adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices batch.adapter_meta.adapter_indices
@ -2194,30 +2218,33 @@ class FlashCausalLM(Model):
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present) # Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None: if adapter_meta is not None:
B, speculative_length = batch.speculative_ids.shape if batch.speculative_ids is not None:
new_length = speculative_length + 1 B, speculative_length = batch.speculative_ids.shape
adapter_indices = ( new_length = speculative_length + 1
adapter_meta.adapter_indices.unsqueeze(-1) adapter_indices = (
.expand(B, new_length) adapter_meta.adapter_indices.unsqueeze(-1)
.reshape(-1) .expand(B, new_length)
) .reshape(-1)
adapter_segments = adapter_meta.adapter_segments * new_length )
adapter_meta = AdapterBatchMetadata( adapter_segments = adapter_meta.adapter_segments * new_length
adapter_indices=adapter_indices, adapter_meta = AdapterBatchMetadata(
adapter_set=adapter_meta.adapter_set, adapter_indices=adapter_indices,
adapter_segments=adapter_segments, adapter_set=adapter_meta.adapter_set,
segment_indices=adapter_meta.segment_indices, adapter_segments=adapter_segments,
) segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights # Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed # TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta( adapter_data = AdapterBatchData.from_meta(
adapter_meta, adapter_meta,
self.layer_to_adapter_weights, self.layer_to_adapter_weights,
prefill, prefill,
batch.prefill_head_indices, batch.prefill_head_indices,
) )
else:
adapter_data = None
out, speculative_logits = self.forward(batch, adapter_data) out, speculative_logits = self.forward(batch, adapter_data)

View File

@ -627,11 +627,11 @@ class FlashVlmCausalLM(FlashCausalLM):
batch.prefilling, seqlen, batch_size batch.prefilling, seqlen, batch_size
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
@ -639,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths), input_lengths=_async_h2d_tensor_copy(input_lengths),
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,

View File

@ -190,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
input_ids = np.concatenate(batch.input_ids, dtype=np.int64) input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else: else:
input_ids = batch.input_ids[0] input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64) batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
@ -537,11 +537,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
orig_bs = len(batch) orig_bs = len(batch)
@ -570,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths), input_lengths=_async_h2d_tensor_copy(input_lengths),
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,