Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-24 12:32:11 +00:00)
Move input_ids to hpu and remove disposal of adapter_meta (#3237)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Parent: e32528792c
Commit: 9e7e546923
@ -90,6 +90,8 @@ class Seqlen:
|
|||||||
def _async_h2d_tensor_copy(source, device="hpu"):
|
def _async_h2d_tensor_copy(source, device="hpu"):
|
||||||
if source is None:
|
if source is None:
|
||||||
return None
|
return None
|
||||||
|
if source.device.type == "hpu":
|
||||||
|
return source
|
||||||
assert source.device.type == "cpu", "Source tensor is not present in host memory!"
|
assert source.device.type == "cpu", "Source tensor is not present in host memory!"
|
||||||
target = torch.empty(source.shape, dtype=source.dtype, device=device)
|
target = torch.empty(source.shape, dtype=source.dtype, device=device)
|
||||||
target.copy_(source, non_blocking=True)
|
target.copy_(source, non_blocking=True)
|
||||||
|
@ -634,14 +634,16 @@ class FlashCausalLMBatch(Batch):
|
|||||||
# Index into tensors
|
# Index into tensors
|
||||||
input_ids = self.input_ids[indices]
|
input_ids = self.input_ids[indices]
|
||||||
position_ids = self.position_ids[indices]
|
position_ids = self.position_ids[indices]
|
||||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
|
||||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||||
cache_lengths_tensor = self.cache_lengths_tensor[indices]
|
cache_lengths_tensor = self.cache_lengths_tensor[indices]
|
||||||
|
|
||||||
# Move to GPU now that we have the whole tensor
|
# Move to GPU now that we have the whole tensor
|
||||||
slot_indices = slot_indices.to(device)
|
slot_indices = slot_indices.to(device)
|
||||||
|
if self.adapter_meta is not None:
|
||||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||||
|
adapter_segments, adapter_segment_indices = find_segments(
|
||||||
|
adapter_indices
|
||||||
|
)
|
||||||
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
|
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
|
||||||
adapter_meta = AdapterBatchMetadata(
|
adapter_meta = AdapterBatchMetadata(
|
||||||
adapter_indices=adapter_indices,
|
adapter_indices=adapter_indices,
|
||||||
@ -649,6 +651,8 @@ class FlashCausalLMBatch(Batch):
|
|||||||
adapter_segments=adapter_segments,
|
adapter_segments=adapter_segments,
|
||||||
segment_indices=adapter_segment_indices,
|
segment_indices=adapter_segment_indices,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
adapter_meta = None
|
||||||
htorch.core.mark_step()
|
htorch.core.mark_step()
|
||||||
return type(self)(
|
return type(self)(
|
||||||
batch_id=self.batch_id,
|
batch_id=self.batch_id,
|
||||||
@ -710,6 +714,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
max_length = 0
|
max_length = 0
|
||||||
max_input_length = 0
|
max_input_length = 0
|
||||||
max_current_length = 0
|
max_current_length = 0
|
||||||
|
ADAPTER_TO_INDEX = get_adapter_to_index()
|
||||||
for b in batches:
|
for b in batches:
|
||||||
total_batch_size += len(b)
|
total_batch_size += len(b)
|
||||||
max_blocks = max(max_blocks, b.max_blocks)
|
max_blocks = max(max_blocks, b.max_blocks)
|
||||||
@ -763,6 +768,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
|
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
|
||||||
total_batch_size
|
total_batch_size
|
||||||
)
|
)
|
||||||
|
if ADAPTER_TO_INDEX:
|
||||||
total_indices_size = sum(
|
total_indices_size = sum(
|
||||||
b.adapter_meta.adapter_indices.shape[0] for b in batches
|
b.adapter_meta.adapter_indices.shape[0] for b in batches
|
||||||
)
|
)
|
||||||
@ -821,9 +827,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
start_index = cumulative_batch_size
|
start_index = cumulative_batch_size
|
||||||
end_index = cumulative_batch_size + valid_bsize
|
end_index = cumulative_batch_size + valid_bsize
|
||||||
|
|
||||||
index = torch.tensor(
|
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
|
||||||
list(range(start_index, end_index)), device=batch.input_ids.device
|
|
||||||
)
|
|
||||||
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
|
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
|
||||||
all_input_ids_tensor[
|
all_input_ids_tensor[
|
||||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||||
@ -847,7 +851,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not prefilling:
|
if not prefilling:
|
||||||
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
|
input_ids.index_copy_(
|
||||||
|
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
|
||||||
|
)
|
||||||
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
|
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
|
||||||
slot_indices.index_copy_(
|
slot_indices.index_copy_(
|
||||||
0, index, batch.slot_indices + cumulative_slots
|
0, index, batch.slot_indices + cumulative_slots
|
||||||
@ -858,6 +864,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
cache_lengths_tensor.index_copy_(
|
cache_lengths_tensor.index_copy_(
|
||||||
0, index, batch.cache_lengths_tensor[:valid_bsize]
|
0, index, batch.cache_lengths_tensor[:valid_bsize]
|
||||||
)
|
)
|
||||||
|
if ADAPTER_TO_INDEX:
|
||||||
adapter_start_index = cumulative_adapter_indices_size
|
adapter_start_index = cumulative_adapter_indices_size
|
||||||
adapter_end_index = (
|
adapter_end_index = (
|
||||||
cumulative_adapter_indices_size
|
cumulative_adapter_indices_size
|
||||||
@ -914,7 +921,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
else:
|
else:
|
||||||
speculative_ids = None
|
speculative_ids = None
|
||||||
|
|
||||||
if adapter_segment_builder is not None:
|
if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
|
||||||
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
|
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
|
||||||
adapter_meta = AdapterBatchMetadata(
|
adapter_meta = AdapterBatchMetadata(
|
||||||
adapter_indices=adapter_indices,
|
adapter_indices=adapter_indices,
|
||||||
@ -961,7 +968,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
num_blocks=num_blocks,
|
num_blocks=num_blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
speculative_ids=speculative_ids,
|
speculative_ids=speculative_ids,
|
||||||
adapter_meta=adapter_meta,
|
adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
|
||||||
hpu_attn_meta=None,
|
hpu_attn_meta=None,
|
||||||
next_token_logits=None,
|
next_token_logits=None,
|
||||||
speculative_logits=None,
|
speculative_logits=None,
|
||||||
@ -1037,6 +1044,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
# need extra pad to match warmup seq
|
# need extra pad to match warmup seq
|
||||||
extra_pad = max_padded_input_len - self.max_input_length
|
extra_pad = max_padded_input_len - self.max_input_length
|
||||||
extra_pad_bs = max_padded_bs - len(self)
|
extra_pad_bs = max_padded_bs - len(self)
|
||||||
|
device = self.all_input_ids_tensor.device
|
||||||
if isinstance(self.input_ids, list) and len(self) > 1:
|
if isinstance(self.input_ids, list) and len(self) > 1:
|
||||||
input_ids_padded_length = []
|
input_ids_padded_length = []
|
||||||
input_ids = []
|
input_ids = []
|
||||||
@ -1047,12 +1055,12 @@ class FlashCausalLMBatch(Batch):
|
|||||||
input_ids.append(input_id)
|
input_ids.append(input_id)
|
||||||
input_ids_padded_length.append(padded)
|
input_ids_padded_length.append(padded)
|
||||||
input_ids = np.concatenate(input_ids, dtype=np.int64)
|
input_ids = np.concatenate(input_ids, dtype=np.int64)
|
||||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
elif isinstance(self.input_ids, list):
|
elif isinstance(self.input_ids, list):
|
||||||
input_ids = self.input_ids[0]
|
input_ids = self.input_ids[0]
|
||||||
input_ids_padded_length.append(extra_pad)
|
input_ids_padded_length.append(extra_pad)
|
||||||
input_ids = [0] * extra_pad + input_ids
|
input_ids = [0] * extra_pad + input_ids
|
||||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
else:
|
else:
|
||||||
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
|
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
|
||||||
input_ids_padded_length.extend([extra_pad] * len(self))
|
input_ids_padded_length.extend([extra_pad] * len(self))
|
||||||
@ -1245,7 +1253,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
self.slot_indices = slot_indices
|
self.slot_indices = slot_indices
|
||||||
|
|
||||||
self.prefill_cu_outlens = prefill_cu_outlens
|
self.prefill_cu_outlens = prefill_cu_outlens
|
||||||
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
|
self.prefill_cache_indices = torch.zeros_like(
|
||||||
|
self.input_ids, dtype=torch.bool, device="cpu"
|
||||||
|
)
|
||||||
self.prefill_cache_indices[prefill_cache_indices] = True
|
self.prefill_cache_indices[prefill_cache_indices] = True
|
||||||
|
|
||||||
if all_prefill_logprobs:
|
if all_prefill_logprobs:
|
||||||
@ -1301,9 +1311,12 @@ class FlashCausalLMBatch(Batch):
|
|||||||
fsm_grammar_states,
|
fsm_grammar_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ADAPTER_TO_INDEX:
|
||||||
if adapter_set:
|
if adapter_set:
|
||||||
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
|
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
|
||||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
adapter_segments, adapter_segment_indices = find_segments(
|
||||||
|
adapter_indices
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_indices = torch.zeros_like(self.input_ids)
|
adapter_indices = torch.zeros_like(self.input_ids)
|
||||||
adapter_segments = [0, len(adapter_indices)]
|
adapter_segments = [0, len(adapter_indices)]
|
||||||
@ -1941,11 +1954,11 @@ class FlashCausalLM(Model):
|
|||||||
# This makes sure the max_s for the decode pass is correct.
|
# This makes sure the max_s for the decode pass is correct.
|
||||||
max_s = min(self.max_past(), max_s)
|
max_s = min(self.max_past(), max_s)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
slots_pad = torch.zeros_like(input_ids)
|
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||||
slots_pad[batch.prefill_cache_indices] = slots
|
slots_pad[batch.prefill_cache_indices] = slots
|
||||||
slots = slots_pad
|
slots = slots_pad
|
||||||
else:
|
else:
|
||||||
slots_pad = torch.zeros_like(input_ids)
|
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||||
slots_pad[: slots.shape[0]] = slots
|
slots_pad[: slots.shape[0]] = slots
|
||||||
slots = slots_pad
|
slots = slots_pad
|
||||||
seqlen = Seqlen(
|
seqlen = Seqlen(
|
||||||
@ -1965,7 +1978,7 @@ class FlashCausalLM(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
input_ids=input_ids,
|
||||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
@ -2059,15 +2072,16 @@ class FlashCausalLM(Model):
|
|||||||
batch.position_ids = batch.position_ids[indices]
|
batch.position_ids = batch.position_ids[indices]
|
||||||
|
|
||||||
batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
|
batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
|
||||||
|
if batch.adapter_meta is not None:
|
||||||
batch.adapter_meta.adapter_indices = (
|
batch.adapter_meta.adapter_indices = (
|
||||||
batch.adapter_meta.adapter_indices[indices]
|
batch.adapter_meta.adapter_indices[indices]
|
||||||
)
|
)
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
accepted_ids = accepted_ids.cpu()
|
|
||||||
|
if batch.speculative_logits is not None:
|
||||||
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
||||||
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
||||||
if batch.speculative_logits is not None:
|
|
||||||
for i in range(len(batch)):
|
for i in range(len(batch)):
|
||||||
batch.all_input_ids_tensor[
|
batch.all_input_ids_tensor[
|
||||||
i,
|
i,
|
||||||
@ -2076,6 +2090,20 @@ class FlashCausalLM(Model):
|
|||||||
+ batch.input_lengths[i]
|
+ batch.input_lengths[i]
|
||||||
+ accepted_ids[i],
|
+ accepted_ids[i],
|
||||||
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
|
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
|
||||||
|
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||||
|
accepted_ids = accepted_ids.cpu()
|
||||||
|
if batch.position_ids.dim() == 2:
|
||||||
|
# Qwen2_vl case:
|
||||||
|
batch.position_ids += accepted_ids.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
batch.position_ids += accepted_ids
|
||||||
|
batch.cache_lengths_tensor += (
|
||||||
|
batch.input_lengths_tensor + accepted_ids - 1
|
||||||
|
)
|
||||||
|
batch.input_lengths_tensor = torch.ones_like(
|
||||||
|
batch.input_lengths_tensor
|
||||||
|
)
|
||||||
|
batch.slot_indices += accepted_ids[: len(batch)]
|
||||||
else:
|
else:
|
||||||
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
|
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
|
||||||
index = index.to(batch.all_input_ids_tensor.device)
|
index = index.to(batch.all_input_ids_tensor.device)
|
||||||
@ -2088,22 +2116,18 @@ class FlashCausalLM(Model):
|
|||||||
batch.all_input_ids_tensor.index_put_(
|
batch.all_input_ids_tensor.index_put_(
|
||||||
(batch_idx, index.long()), next_input_ids
|
(batch_idx, index.long()), next_input_ids
|
||||||
)
|
)
|
||||||
next_input_ids = next_input_ids.cpu()
|
batch.input_ids = next_input_ids
|
||||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
batch.position_ids += 1
|
||||||
batch.speculative_ids = speculative_ids
|
batch.cache_lengths_tensor += batch.input_lengths_tensor
|
||||||
if batch.position_ids.dim() == 2:
|
batch.input_lengths_tensor = torch.ones_like(
|
||||||
# Qwen2_vl case:
|
batch.input_lengths_tensor
|
||||||
batch.position_ids += accepted_ids.unsqueeze(-1)
|
|
||||||
else:
|
|
||||||
batch.position_ids += accepted_ids
|
|
||||||
batch.cache_lengths_tensor += (
|
|
||||||
batch.input_lengths_tensor + accepted_ids - 1
|
|
||||||
)
|
)
|
||||||
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
batch.slot_indices += 1
|
||||||
batch.slot_indices += accepted_ids[: len(batch)]
|
|
||||||
|
batch.speculative_ids = speculative_ids
|
||||||
|
|
||||||
# Does a HPU <-> CPU sync internally
|
# Does a HPU <-> CPU sync internally
|
||||||
if prefill:
|
if prefill and batch.adapter_meta is not None:
|
||||||
# adjust segment lengths to account for all request lengths being 1 during decoding
|
# adjust segment lengths to account for all request lengths being 1 during decoding
|
||||||
adapter_segments, _ = find_segments(
|
adapter_segments, _ = find_segments(
|
||||||
batch.adapter_meta.adapter_indices
|
batch.adapter_meta.adapter_indices
|
||||||
@ -2194,6 +2218,7 @@ class FlashCausalLM(Model):
|
|||||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||||
# Update adapter indices for speculative tokens (if present)
|
# Update adapter indices for speculative tokens (if present)
|
||||||
adapter_meta = batch.adapter_meta
|
adapter_meta = batch.adapter_meta
|
||||||
|
if adapter_meta is not None:
|
||||||
if batch.speculative_ids is not None:
|
if batch.speculative_ids is not None:
|
||||||
B, speculative_length = batch.speculative_ids.shape
|
B, speculative_length = batch.speculative_ids.shape
|
||||||
new_length = speculative_length + 1
|
new_length = speculative_length + 1
|
||||||
@ -2218,6 +2243,8 @@ class FlashCausalLM(Model):
|
|||||||
prefill,
|
prefill,
|
||||||
batch.prefill_head_indices,
|
batch.prefill_head_indices,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
adapter_data = None
|
||||||
|
|
||||||
out, speculative_logits = self.forward(batch, adapter_data)
|
out, speculative_logits = self.forward(batch, adapter_data)
|
||||||
|
|
||||||
|
@ -627,11 +627,11 @@ class FlashVlmCausalLM(FlashCausalLM):
|
|||||||
batch.prefilling, seqlen, batch_size
|
batch.prefilling, seqlen, batch_size
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
slots_pad = torch.zeros_like(input_ids)
|
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||||
slots_pad[batch.prefill_cache_indices] = slots
|
slots_pad[batch.prefill_cache_indices] = slots
|
||||||
slots = slots_pad
|
slots = slots_pad
|
||||||
else:
|
else:
|
||||||
slots_pad = torch.zeros_like(input_ids)
|
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||||
slots_pad[: slots.shape[0]] = slots
|
slots_pad[: slots.shape[0]] = slots
|
||||||
slots = slots_pad
|
slots = slots_pad
|
||||||
|
|
||||||
@ -639,7 +639,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
|||||||
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
||||||
)
|
)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
input_ids=input_ids,
|
||||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
|
@ -190,7 +190,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
|
|||||||
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
|
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
|
||||||
else:
|
else:
|
||||||
input_ids = batch.input_ids[0]
|
input_ids = batch.input_ids[0]
|
||||||
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)
|
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
||||||
|
|
||||||
@ -537,11 +537,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
slots_pad = torch.zeros_like(input_ids)
|
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||||
slots_pad[batch.prefill_cache_indices] = slots
|
slots_pad[batch.prefill_cache_indices] = slots
|
||||||
slots = slots_pad
|
slots = slots_pad
|
||||||
else:
|
else:
|
||||||
slots_pad = torch.zeros_like(input_ids)
|
slots_pad = torch.zeros_like(input_ids, device=slots.device)
|
||||||
slots_pad[: slots.shape[0]] = slots
|
slots_pad[: slots.shape[0]] = slots
|
||||||
slots = slots_pad
|
slots = slots_pad
|
||||||
orig_bs = len(batch)
|
orig_bs = len(batch)
|
||||||
@ -570,7 +570,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
|||||||
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
||||||
)
|
)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
input_ids=input_ids,
|
||||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
|
Loading…
Reference in New Issue
Block a user