Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-24 12:32:11 +00:00)

Move input_ids to hpu and remove disposal of adapter_meta (#3237)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

commit 9e7e546923 (parent e32528792c)
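The change has two independent parts: batch tensors such as input_ids now live on the HPU device (so the forward pass no longer stages them with an explicit host-to-device copy), and adapter metadata is only built or re-indexed when adapters are actually configured, so adapter_meta may be None instead of being constructed and then thrown away. A minimal sketch of the second pattern, with hypothetical names standing in for AdapterBatchMetadata (this is an illustration, not the repository's code):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Meta:
    # Hypothetical stand-in for AdapterBatchMetadata: one adapter index per token.
    adapter_indices: List[int]

def filter_meta(meta: Optional[Meta], keep: List[int]) -> Optional[Meta]:
    # Guard-then-index pattern used throughout the diff:
    # only touch the metadata when it exists, otherwise propagate None.
    if meta is None:
        return None
    return Meta(adapter_indices=[meta.adapter_indices[i] for i in keep])

print(filter_meta(Meta([0, 0, 1, 1]), [1, 3]))  # Meta(adapter_indices=[0, 1])
print(filter_meta(None, [1, 3]))                # None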
@@ -90,6 +90,8 @@ class Seqlen:
 def _async_h2d_tensor_copy(source, device="hpu"):
     if source is None:
         return None
+    if source.device.type == "hpu":
+        return source
     assert source.device.type == "cpu", "Source tensor is not present in host memory!"
     target = torch.empty(source.shape, dtype=source.dtype, device=device)
     target.copy_(source, non_blocking=True)
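For context, _async_h2d_tensor_copy stages a host tensor into device memory without blocking the host, and the new early return lets tensors that already sit on the HPU (such as input_ids after this change) pass through untouched. A rough sketch of the same control flow with a generic torch device standing in for "hpu" (assumption: this mirrors the helper's behaviour, it is not the repository's function):

import torch

def async_h2d_copy(source, device="cpu"):
    # Pass through None and tensors already on the target device,
    # otherwise start a non-blocking copy into a freshly allocated buffer.
    if source is None:
        return None
    if source.device.type == torch.device(device).type:
        return source
    target = torch.empty(source.shape, dtype=source.dtype, device=device)
    target.copy_(source, non_blocking=True)  # can overlap with host work when supported
    return target

ids = torch.arange(8)
print(async_h2d_copy(ids, "cpu") is ids)  # True: no copy when already on the device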
@@ -634,21 +634,25 @@ class FlashCausalLMBatch(Batch):
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         cache_lengths_tensor = self.cache_lengths_tensor[indices]

         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)

-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+        if self.adapter_meta is not None:
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            adapter_segments, adapter_segment_indices = find_segments(
+                adapter_indices
+            )
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
+        else:
+            adapter_meta = None
         htorch.core.mark_step()
         return type(self)(
             batch_id=self.batch_id,
@@ -710,6 +714,7 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
         max_input_length = 0
         max_current_length = 0
+        ADAPTER_TO_INDEX = get_adapter_to_index()
         for b in batches:
             total_batch_size += len(b)
             max_blocks = max(max_blocks, b.max_blocks)
@@ -763,14 +768,15 @@ class FlashCausalLMBatch(Batch):
         cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
             total_batch_size
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_segment_builder = SegmentConcatBuilder()
-        adapter_set = set()
+        if ADAPTER_TO_INDEX:
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()

         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
@@ -821,9 +827,7 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + valid_bsize

-            index = torch.tensor(
-                list(range(start_index, end_index)), device=batch.input_ids.device
-            )
+            index = torch.tensor(list(range(start_index, end_index)), device="cpu")
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -847,7 +851,9 @@ class FlashCausalLMBatch(Batch):
             )

             if not prefilling:
-                input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
+                input_ids.index_copy_(
+                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
@@ -858,20 +864,21 @@ class FlashCausalLMBatch(Batch):
                 cache_lengths_tensor.index_copy_(
                     0, index, batch.cache_lengths_tensor[:valid_bsize]
                 )
-                adapter_start_index = cumulative_adapter_indices_size
-                adapter_end_index = (
-                    cumulative_adapter_indices_size
-                    + batch.adapter_meta.adapter_indices.shape[0]
-                )
-                adapter_indices[adapter_start_index:adapter_end_index] = (
-                    batch.adapter_meta.adapter_indices
-                )
-                cumulative_adapter_indices_size = adapter_end_index
-                adapter_set.update(batch.adapter_meta.adapter_set)
-                adapter_segment_builder.concat(
-                    batch.adapter_meta.adapter_segments,
-                    batch.adapter_meta.segment_indices,
-                )
+                if ADAPTER_TO_INDEX:
+                    adapter_start_index = cumulative_adapter_indices_size
+                    adapter_end_index = (
+                        cumulative_adapter_indices_size
+                        + batch.adapter_meta.adapter_indices.shape[0]
+                    )
+                    adapter_indices[adapter_start_index:adapter_end_index] = (
+                        batch.adapter_meta.adapter_indices
+                    )
+                    cumulative_adapter_indices_size = adapter_end_index
+                    adapter_set.update(batch.adapter_meta.adapter_set)
+                    adapter_segment_builder.concat(
+                        batch.adapter_meta.adapter_segments,
+                        batch.adapter_meta.segment_indices,
+                    )
             else:
                 if isinstance(batch.input_ids, torch.Tensor):
                     batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -914,7 +921,7 @@ class FlashCausalLMBatch(Batch):
         else:
             speculative_ids = None

-        if adapter_segment_builder is not None:
+        if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
             adapter_meta = AdapterBatchMetadata(
                 adapter_indices=adapter_indices,
@@ -961,7 +968,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=adapter_meta,
+            adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
             hpu_attn_meta=None,
             next_token_logits=None,
             speculative_logits=None,
@@ -1037,6 +1044,7 @@ class FlashCausalLMBatch(Batch):
         # need extra pad to match warmup seq
         extra_pad = max_padded_input_len - self.max_input_length
         extra_pad_bs = max_padded_bs - len(self)
+        device = self.all_input_ids_tensor.device
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1047,12 +1055,12 @@ class FlashCausalLMBatch(Batch):
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
             input_ids = [0] * extra_pad + input_ids
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
             self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
             input_ids_padded_length.extend([extra_pad] * len(self))
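The padding hunk above now allocates the padded input_ids directly on the device that already holds all_input_ids_tensor, so no separate transfer is needed afterwards. A small illustration of that difference in plain torch, where the device variable simply stands in for whatever device the batch tensors live on:

import torch

device = torch.device("cpu")  # stand-in for the device of all_input_ids_tensor

ids = [5, 6, 7]
extra_pad = 2

# Two-step version: build on the host, then move.
padded = torch.tensor([0] * extra_pad + ids, dtype=torch.int64).to(device)

# One-step version used after the change: allocate on the target device directly.
padded_direct = torch.tensor([0] * extra_pad + ids, dtype=torch.int64, device=device)

assert torch.equal(padded, padded_direct)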
@@ -1245,7 +1253,9 @@ class FlashCausalLMBatch(Batch):
         self.slot_indices = slot_indices

         self.prefill_cu_outlens = prefill_cu_outlens
-        self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
+        self.prefill_cache_indices = torch.zeros_like(
+            self.input_ids, dtype=torch.bool, device="cpu"
+        )
         self.prefill_cache_indices[prefill_cache_indices] = True

         if all_prefill_logprobs:
@@ -1301,21 +1311,24 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states,
         )

-        if adapter_set:
-            adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
-            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        else:
-            adapter_indices = torch.zeros_like(self.input_ids)
-            adapter_segments = [0, len(adapter_indices)]
-            adapter_segment_indices = [len(adapter_indices) - 1]
+        if ADAPTER_TO_INDEX:
+            if adapter_set:
+                adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
+                adapter_segments, adapter_segment_indices = find_segments(
+                    adapter_indices
+                )
+            else:
+                adapter_indices = torch.zeros_like(self.input_ids)
+                adapter_segments = [0, len(adapter_indices)]
+                adapter_segment_indices = [len(adapter_indices) - 1]

-        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
-        self.adapter_meta = AdapterBatchMetadata(
-            adapter_indices=adapter_indices,
-            adapter_set=adapter_set,
-            adapter_segments=adapter_segments,
-            segment_indices=adapter_segment_indices,
-        )
+            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+            self.adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

     def __len__(self):
         return len(self.requests)
@@ -1941,11 +1954,11 @@ class FlashCausalLM(Model):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
         seqlen = Seqlen(
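In these forward paths input_ids may now already sit on the HPU while slots is still a host tensor. torch.zeros_like inherits the device of its argument, so the explicit device=slots.device keeps slots_pad co-located with slots before the masked assignment. A quick check of that behaviour with plain CPU tensors and made-up shapes:

import torch

input_ids = torch.arange(6)          # imagine this living on the accelerator
slots = torch.tensor([3, 1, 4])      # host-side slot ids

# zeros_like follows input_ids' device unless told otherwise...
pad_default = torch.zeros_like(input_ids)
# ...so the diff pins the buffer to slots' device explicitly.
pad_pinned = torch.zeros_like(input_ids, device=slots.device)

pad_pinned[: slots.shape[0]] = slots  # same indexing as the non-prefill branch
print(pad_pinned.tolist())            # [3, 1, 4, 0, 0, 0]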
@@ -1965,7 +1978,7 @@ class FlashCausalLM(Model):
         )

         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
@@ -2059,15 +2072,16 @@ class FlashCausalLM(Model):
             batch.position_ids = batch.position_ids[indices]

             batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
-            batch.adapter_meta.adapter_indices = (
-                batch.adapter_meta.adapter_indices[indices]
-            )
+            if batch.adapter_meta is not None:
+                batch.adapter_meta.adapter_indices = (
+                    batch.adapter_meta.adapter_indices[indices]
+                )
         # For each member of the batch
         # Cumulative length
-        accepted_ids = accepted_ids.cpu()
-        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
-        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])

         if batch.speculative_logits is not None:
+            cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+            torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
@@ -2076,6 +2090,20 @@ class FlashCausalLM(Model):
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+            accepted_ids = accepted_ids.cpu()
+            if batch.position_ids.dim() == 2:
+                # Qwen2_vl case:
+                batch.position_ids += accepted_ids.unsqueeze(-1)
+            else:
+                batch.position_ids += accepted_ids
+            batch.cache_lengths_tensor += (
+                batch.input_lengths_tensor + accepted_ids - 1
+            )
+            batch.input_lengths_tensor = torch.ones_like(
+                batch.input_lengths_tensor
+            )
+            batch.slot_indices += accepted_ids[: len(batch)]
         else:
             index = batch.cache_lengths_tensor + batch.input_lengths_tensor
             index = index.to(batch.all_input_ids_tensor.device)
@@ -2088,22 +2116,18 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor.index_put_(
                 (batch_idx, index.long()), next_input_ids
             )
             next_input_ids = next_input_ids.cpu()
-            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+            batch.input_ids = next_input_ids
+            batch.position_ids += 1
+            batch.cache_lengths_tensor += batch.input_lengths_tensor
+            batch.input_lengths_tensor = torch.ones_like(
+                batch.input_lengths_tensor
+            )
+            batch.slot_indices += 1

         batch.speculative_ids = speculative_ids
-        if batch.position_ids.dim() == 2:
-            # Qwen2_vl case:
-            batch.position_ids += accepted_ids.unsqueeze(-1)
-        else:
-            batch.position_ids += accepted_ids
-        batch.cache_lengths_tensor += (
-            batch.input_lengths_tensor + accepted_ids - 1
-        )
-        batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
-        batch.slot_indices += accepted_ids[: len(batch)]

         # Does a HPU <-> CPU sync internally
-        if prefill:
+        if prefill and batch.adapter_meta is not None:
             # adjust segment lengths to account for all request lengths being 1 during decoding
             adapter_segments, _ = find_segments(
                 batch.adapter_meta.adapter_indices
@@ -2194,30 +2218,33 @@ class FlashCausalLM(Model):
         prefill_logprobs = batch.prefill_next_token_indices is not None
         # Update adapter indices for speculative tokens (if present)
         adapter_meta = batch.adapter_meta
-        if batch.speculative_ids is not None:
-            B, speculative_length = batch.speculative_ids.shape
-            new_length = speculative_length + 1
-            adapter_indices = (
-                adapter_meta.adapter_indices.unsqueeze(-1)
-                .expand(B, new_length)
-                .reshape(-1)
-            )
-            adapter_segments = adapter_meta.adapter_segments * new_length
-            adapter_meta = AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_meta.adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_meta.segment_indices,
-            )
+        if adapter_meta is not None:
+            if batch.speculative_ids is not None:
+                B, speculative_length = batch.speculative_ids.shape
+                new_length = speculative_length + 1
+                adapter_indices = (
+                    adapter_meta.adapter_indices.unsqueeze(-1)
+                    .expand(B, new_length)
+                    .reshape(-1)
+                )
+                adapter_segments = adapter_meta.adapter_segments * new_length
+                adapter_meta = AdapterBatchMetadata(
+                    adapter_indices=adapter_indices,
+                    adapter_set=adapter_meta.adapter_set,
+                    adapter_segments=adapter_segments,
+                    segment_indices=adapter_meta.segment_indices,
+                )

-        # Assign pointers to adapter weights
-        # TODO(travis): don't update this if indices haven't changed
-        adapter_data = AdapterBatchData.from_meta(
-            adapter_meta,
-            self.layer_to_adapter_weights,
-            prefill,
-            batch.prefill_head_indices,
-        )
+            # Assign pointers to adapter weights
+            # TODO(travis): don't update this if indices haven't changed
+            adapter_data = AdapterBatchData.from_meta(
+                adapter_meta,
+                self.layer_to_adapter_weights,
+                prefill,
+                batch.prefill_head_indices,
+            )
+        else:
+            adapter_data = None

         out, speculative_logits = self.forward(batch, adapter_data)
@@ -627,11 +627,11 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch.prefilling, seqlen, batch_size
         )
         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
@@ -639,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
@@ -190,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
             input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
         else:
             input_ids = batch.input_ids[0]
-        batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)
+        batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

         batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
@@ -537,11 +537,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )

         if batch.prefill_cache_indices is not None:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
         else:
-            slots_pad = torch.zeros_like(input_ids)
+            slots_pad = torch.zeros_like(input_ids, device=slots.device)
             slots_pad[: slots.shape[0]] = slots
             slots = slots_pad
         orig_bs = len(batch)
@@ -570,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
         logits, speculative_logits = self.model.forward(
-            input_ids=_async_h2d_tensor_copy(input_ids),
+            input_ids=input_ids,
             position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,