Move input_ids to hpu and remove disposal of adapter_meta (#3237)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2025-05-22 15:21:31 +08:00 committed by GitHub
parent e32528792c
commit 9e7e546923
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 107 deletions

View File

@ -90,6 +90,8 @@ class Seqlen:
def _async_h2d_tensor_copy(source, device="hpu"): def _async_h2d_tensor_copy(source, device="hpu"):
if source is None: if source is None:
return None return None
if source.device.type == "hpu":
return source
assert source.device.type == "cpu", "Source tensor is not present in host memory!" assert source.device.type == "cpu", "Source tensor is not present in host memory!"
target = torch.empty(source.shape, dtype=source.dtype, device=device) target = torch.empty(source.shape, dtype=source.dtype, device=device)
target.copy_(source, non_blocking=True) target.copy_(source, non_blocking=True)

View File

@ -634,14 +634,16 @@ class FlashCausalLMBatch(Batch):
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices] cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
if self.adapter_meta is not None:
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_indices = self.adapter_meta.adapter_indices[indices]
adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
adapter_meta = AdapterBatchMetadata( adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
@ -649,6 +651,8 @@ class FlashCausalLMBatch(Batch):
adapter_segments=adapter_segments, adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices, segment_indices=adapter_segment_indices,
) )
else:
adapter_meta = None
htorch.core.mark_step() htorch.core.mark_step()
return type(self)( return type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
@ -710,6 +714,7 @@ class FlashCausalLMBatch(Batch):
max_length = 0 max_length = 0
max_input_length = 0 max_input_length = 0
max_current_length = 0 max_current_length = 0
ADAPTER_TO_INDEX = get_adapter_to_index()
for b in batches: for b in batches:
total_batch_size += len(b) total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks) max_blocks = max(max_blocks, b.max_blocks)
@ -763,6 +768,7 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size total_batch_size
) )
if ADAPTER_TO_INDEX:
total_indices_size = sum( total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches b.adapter_meta.adapter_indices.shape[0] for b in batches
) )
@ -821,9 +827,7 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + valid_bsize end_index = cumulative_batch_size + valid_bsize
index = torch.tensor( index = torch.tensor(list(range(start_index, end_index)), device="cpu")
list(range(start_index, end_index)), device=batch.input_ids.device
)
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
all_input_ids_tensor[ all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1] start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@ -847,7 +851,9 @@ class FlashCausalLMBatch(Batch):
) )
if not prefilling: if not prefilling:
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize]) input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_( slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots 0, index, batch.slot_indices + cumulative_slots
@ -858,6 +864,7 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor.index_copy_( cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize] 0, index, batch.cache_lengths_tensor[:valid_bsize]
) )
if ADAPTER_TO_INDEX:
adapter_start_index = cumulative_adapter_indices_size adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = ( adapter_end_index = (
cumulative_adapter_indices_size cumulative_adapter_indices_size
@ -914,7 +921,7 @@ class FlashCausalLMBatch(Batch):
else: else:
speculative_ids = None speculative_ids = None
if adapter_segment_builder is not None: if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata( adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
@ -961,7 +968,7 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks, num_blocks=num_blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
speculative_ids=speculative_ids, speculative_ids=speculative_ids,
adapter_meta=adapter_meta, adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
hpu_attn_meta=None, hpu_attn_meta=None,
next_token_logits=None, next_token_logits=None,
speculative_logits=None, speculative_logits=None,
@ -1037,6 +1044,7 @@ class FlashCausalLMBatch(Batch):
# need extra pad to match warmup seq # need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self) extra_pad_bs = max_padded_bs - len(self)
device = self.all_input_ids_tensor.device
if isinstance(self.input_ids, list) and len(self) > 1: if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = [] input_ids_padded_length = []
input_ids = [] input_ids = []
@ -1047,12 +1055,12 @@ class FlashCausalLMBatch(Batch):
input_ids.append(input_id) input_ids.append(input_id)
input_ids_padded_length.append(padded) input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list): elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0] input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad) input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else: else:
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
input_ids_padded_length.extend([extra_pad] * len(self)) input_ids_padded_length.extend([extra_pad] * len(self))
@ -1245,7 +1253,9 @@ class FlashCausalLMBatch(Batch):
self.slot_indices = slot_indices self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) self.prefill_cache_indices = torch.zeros_like(
self.input_ids, dtype=torch.bool, device="cpu"
)
self.prefill_cache_indices[prefill_cache_indices] = True self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs: if all_prefill_logprobs:
@ -1301,9 +1311,12 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states, fsm_grammar_states,
) )
if ADAPTER_TO_INDEX:
if adapter_set: if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
else: else:
adapter_indices = torch.zeros_like(self.input_ids) adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)] adapter_segments = [0, len(adapter_indices)]
@ -1941,11 +1954,11 @@ class FlashCausalLM(Model):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
seqlen = Seqlen( seqlen = Seqlen(
@ -1965,7 +1978,7 @@ class FlashCausalLM(Model):
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,
@ -2059,15 +2072,16 @@ class FlashCausalLM(Model):
batch.position_ids = batch.position_ids[indices] batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices[: len(batch)]] batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
if batch.adapter_meta is not None:
batch.adapter_meta.adapter_indices = ( batch.adapter_meta.adapter_indices = (
batch.adapter_meta.adapter_indices[indices] batch.adapter_meta.adapter_indices[indices]
) )
# For each member of the batch # For each member of the batch
# Cumulative length # Cumulative length
accepted_ids = accepted_ids.cpu()
if batch.speculative_logits is not None:
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
if batch.speculative_logits is not None:
for i in range(len(batch)): for i in range(len(batch)):
batch.all_input_ids_tensor[ batch.all_input_ids_tensor[
i, i,
@ -2076,6 +2090,20 @@ class FlashCausalLM(Model):
+ batch.input_lengths[i] + batch.input_lengths[i]
+ accepted_ids[i], + accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
accepted_ids = accepted_ids.cpu()
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += accepted_ids[: len(batch)]
else: else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = index.to(batch.all_input_ids_tensor.device) index = index.to(batch.all_input_ids_tensor.device)
@ -2088,22 +2116,18 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor.index_put_( batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids (batch_idx, index.long()), next_input_ids
) )
next_input_ids = next_input_ids.cpu() batch.input_ids = next_input_ids
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.position_ids += 1
batch.speculative_ids = speculative_ids batch.cache_lengths_tensor += batch.input_lengths_tensor
if batch.position_ids.dim() == 2: batch.input_lengths_tensor = torch.ones_like(
# Qwen2_vl case: batch.input_lengths_tensor
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
) )
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += 1
batch.slot_indices += accepted_ids[: len(batch)]
batch.speculative_ids = speculative_ids
# Does a HPU <-> CPU sync internally # Does a HPU <-> CPU sync internally
if prefill: if prefill and batch.adapter_meta is not None:
# adjust segment lengths to account for all request lengths being 1 during decoding # adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments( adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices batch.adapter_meta.adapter_indices
@ -2194,6 +2218,7 @@ class FlashCausalLM(Model):
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present) # Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta adapter_meta = batch.adapter_meta
if adapter_meta is not None:
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1 new_length = speculative_length + 1
@ -2218,6 +2243,8 @@ class FlashCausalLM(Model):
prefill, prefill,
batch.prefill_head_indices, batch.prefill_head_indices,
) )
else:
adapter_data = None
out, speculative_logits = self.forward(batch, adapter_data) out, speculative_logits = self.forward(batch, adapter_data)

View File

@ -627,11 +627,11 @@ class FlashVlmCausalLM(FlashCausalLM):
batch.prefilling, seqlen, batch_size batch.prefilling, seqlen, batch_size
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
@ -639,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths), input_lengths=_async_h2d_tensor_copy(input_lengths),
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,

View File

@ -190,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
input_ids = np.concatenate(batch.input_ids, dtype=np.int64) input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else: else:
input_ids = batch.input_ids[0] input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64) batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
@ -537,11 +537,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad slots = slots_pad
else: else:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots slots_pad[: slots.shape[0]] = slots
slots = slots_pad slots = slots_pad
orig_bs = len(batch) orig_bs = len(batch)
@ -570,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths), input_lengths=_async_h2d_tensor_copy(input_lengths),
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache, kv_cache=kv_cache,