forward and tokenize chooser use the same shape (#3196)

* forward and token chooser use the same shape:
concatenation and filtering are done on CPU tensors to avoid dynamic shapes on HPU

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* use the HPU set-seed path (seed habana's default HPU generator)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Commit 533eee50dc (parent 51a0b9d11c)
Authored by Wang, Yi, 2025-05-06 16:49:32 +08:00; committed by GitHub
6 changed files with 376 additions and 481 deletions
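Both bullets come down to one pattern: every tensor that feeds the forward pass or the next-token chooser is built and padded on the CPU to a bucketed, static shape, and only then moved to the device in one non-blocking copy. A minimal sketch of that flow, with illustrative sizes and a hand-picked bucket (the real code gets buckets from HPUBucketingContext; the "hpu" device needs the habana torch bridge, so substitute "cpu" to try it elsewhere):

import torch
import torch.nn.functional as F

def async_h2d(t, device="hpu"):
    # mirrors the _async_h2d_tensor_copy helper added in this commit:
    # allocate on the device, then copy without blocking the host thread
    out = torch.empty(t.shape, dtype=t.dtype, device=device)
    out.copy_(t, non_blocking=True)
    return out

# three live requests, bucketed batch size of 4 (illustrative numbers)
input_lengths = torch.tensor([5, 7, 3], dtype=torch.int32)             # CPU tensor
input_lengths = F.pad(input_lengths, (0, 4 - input_lengths.shape[0]))  # still CPU, static shape
input_lengths_hpu = async_h2d(input_lengths)                           # one async H2D copy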


@@ -3,6 +3,7 @@ from .common import (
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
from .hpu import (
@@ -25,4 +26,5 @@ __all__ = [
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",
"trim_attn_metadata",
"_async_h2d_tensor_copy",
]


@@ -75,42 +75,27 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
def _async_h2d_tensor_copy(source, device="hpu"):
if source is None:
return None
assert source.device.type == "cpu", "Source tensor is not present in host memory!"
target = torch.empty(source.shape, dtype=source.dtype, device=device)
target.copy_(source, non_blocking=True)
return target
def trim_seqlen_metadata(metadata: Seqlen) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
@@ -137,9 +122,6 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object:
"TrimmedSeqlen",
[
"input_lengths",
"cache_lengths",
"cu_seqlen_q",
"cu_seqlen_k",
],
)
return attention_metadata
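Taken together, the hunks above slim Seqlen down to input_lengths and add the _async_h2d_tensor_copy helper. A hedged usage sketch, with the import path as re-exported by the first file of this diff:

import torch
from text_generation_server.layers.attention import (
    Seqlen,
    trim_seqlen_metadata,
    _async_h2d_tensor_copy,
)

# lengths stay on CPU, already padded to the bucketed batch size
input_lengths = torch.tensor([128, 64, 32, 0], dtype=torch.int32)
seqlen = Seqlen(input_lengths=_async_h2d_tensor_copy(input_lengths))
# trimmed to a namedtuple carrying only input_lengths, as HPUGraphs require
meta = trim_seqlen_metadata(seqlen)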


@@ -36,6 +36,7 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
pad_next_token_chooser_parameters,
)
from text_generation_server.models.types import (
Batch,
@@ -56,6 +57,7 @@ from text_generation_server.layers.attention import (
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -141,18 +143,23 @@ def prepare_for_decode(
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_list = torch.tensor(block_list, dtype=torch.int, device=device)
block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)
block_mapping = torch.nn.functional.one_hot(
block_groups_device, num_classes=batch_size
)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list,
block_groups=block_groups,
block_usage=block_usage,
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
)
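In the hunk above only the one-hot expansion and the bias masking run on device; the block bookkeeping arrives as CPU tensors via _async_h2d_tensor_copy. A CPU-runnable toy of the mask/bias step, with a tiny illustrative BLOCK_SIZE (the backend uses its own constant):

import math
import torch

BLOCK_SIZE = 4  # illustrative only
block_usage = torch.tensor([4.0, 2.0, 1.0])  # valid slots per block, already padded

mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)  # True where a slot is padding
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)
# row 0 keeps all four slots, row 1 masks the last two, row 2 the last three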
@@ -248,6 +255,7 @@ class FlashCausalLMBatch(Batch):
next_token_logits: Optional[torch.Tensor]
speculative_logits: Optional[torch.Tensor]
valid_indices: Optional[List[int]]
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
@@ -417,32 +425,23 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
block_tables_ragged = torch.tensor(
block_tables_ragged, device=device, dtype=torch.int32
)
cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
block_tables_ragged = torch.tensor(block_tables_ragged, dtype=torch.int32)
cu_blocks = torch.tensor(cu_blocks, dtype=torch.int64)
block_tables_tensor = torch.empty(
(len(block_tables), max_blocks),
device=device,
dtype=torch.int32,
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
prompt_lengths_tensor = torch.tensor(
prompt_lengths, dtype=torch.int32, device=device
)
prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32)
slots = torch.tensor(slots, dtype=torch.int64, device=device)
slots = torch.tensor(slots, dtype=torch.int64)
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
return cls(
@@ -488,6 +487,7 @@ class FlashCausalLMBatch(Batch):
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
valid_indices=None,
)
@classmethod
@@ -519,9 +519,7 @@ class FlashCausalLMBatch(Batch):
indices = []
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)
# Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@@ -544,7 +542,6 @@ class FlashCausalLMBatch(Batch):
prefill_logprob_tokens = []
stopping_criterias = []
top_n_tokens = []
adapter_set = set()
num_blocks = 0
@@ -582,7 +579,6 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
ADAPTER_TO_INDEX = get_adapter_to_index()
@@ -614,19 +610,7 @@ class FlashCausalLMBatch(Batch):
max_blocks = max(max_blocks, len(request_block_table))
max_slots = max(max_slots, slot_length)
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_logits = self.next_token_logits[indices]
speculative_logits = (
self.speculative_logits[indices]
if self.speculative_logits is not None
else None
)
block_tables_tensor = self.block_tables_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = (
self.speculative_ids[indices] if self.speculative_ids is not None else None
)
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
@@ -652,16 +636,14 @@ class FlashCausalLMBatch(Batch):
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
htorch.core.mark_step()
return type(self)(
batch_id=self.batch_id,
requests=requests,
@@ -692,18 +674,19 @@ class FlashCausalLMBatch(Batch):
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
all_input_ids_tensor=self.all_input_ids_tensor,
next_token_chooser=self.next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
top_n_tokens=self.top_n_tokens,
top_n_tokens_tensor=self.top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
speculative_ids=self.speculative_ids,
adapter_meta=adapter_meta,
hpu_attn_meta=None,
next_token_logits=next_token_logits,
speculative_logits=speculative_logits,
valid_indices=indices,
next_token_logits=self.next_token_logits,
speculative_logits=self.speculative_logits,
)
@classmethod
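The return above is the "filter happened to cpu tensor" half of the commit message: filter() keeps the device tensors at their padded shape and only records valid_indices; the actual gather is deferred to generate_token, where results are synced to the host anyway. A toy illustration of that deferral (hypothetical names):

import torch

valid_indices = [0, 2]  # requests that survive filtering
next_input_ids_device = torch.tensor([11, 12, 13, 0])  # stand-in for an HPU tensor

# later, at a point where a host sync happens anyway:
next_input_ids = next_input_ids_device.cpu()[valid_indices]  # cheap CPU gather
# the device tensor never changed shape, so no HPU graph recompilation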
@@ -820,6 +803,7 @@ class FlashCausalLMBatch(Batch):
for i, batch in enumerate(batches):
requests.extend(batch.requests)
valid_bsize = len(batch)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@@ -829,16 +813,15 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
end_index = cumulative_batch_size + valid_bsize
# Copy tensors (HPU)
index = torch.tensor(
list(range(start_index, end_index)), device=batch.input_ids.device
)
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
@@ -847,19 +830,28 @@ class FlashCausalLMBatch(Batch):
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
slots[slots_start_index:slots_end_index] = batch.slots
slot_index = torch.tensor(
list(range(slots_start_index, slots_end_index)),
device=batch.slots.device,
)
slots.index_copy_(0, slot_index, batch.slots)
cu_slots[start_index + 1 : end_index + 1] = (
batch.cu_slots[1:] + cumulative_slots
)
if not prefilling:
input_ids.index_copy_(0, index, batch.input_ids)
position_ids.index_copy_(0, index, batch.position_ids)
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots
)
input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor)
cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor)
input_lengths_tensor.index_copy_(
0, index, batch.input_lengths_tensor[:valid_bsize]
)
cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize]
)
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
@@ -967,6 +959,7 @@ class FlashCausalLMBatch(Batch):
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
valid_indices=None,
)
def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
@@ -982,27 +975,53 @@ class FlashCausalLMBatch(Batch):
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]
extra_pad = padded_bs - self.input_ids.shape[0]
if extra_pad != 0:
slots = F.pad(slots, (0, extra_pad), value=0)
block_tables.extend([[0]] * extra_pad)
self.hpu_attn_meta = prepare_for_decode(
dtype,
use_contiguous_pa,
self.block_tables_tensor.device,
slots.cpu(),
"hpu",
slots,
block_tables,
padded_bs,
bucketing_ctx,
)
self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor, (0, extra_pad), value=0
)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor, (0, extra_pad), value=0
)
self.all_input_ids_tensor = F.pad(
self.all_input_ids_tensor,
(0, 0, 0, extra_pad),
value=0,
)
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
# update past grammar states
fsm_grammar_states = [0] * padded_bs
for i, req in enumerate(self.requests):
fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
self.next_token_chooser.dtype,
self.next_token_chooser.device,
self.next_token_chooser.tokenizer,
fsm_grammar_states,
)
def prepare_for_prefill(self, max_padded_input_len):
def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert self.speculative_ids is None
device = self.block_tables_tensor.device
# device = self.block_tables_tensor.device
# hpu does not support varlen for prefill, use sdpa instead. so need to pad input_tensor, position
# padding to left to work with sliding window
@@ -1011,6 +1030,7 @@ class FlashCausalLMBatch(Batch):
input_ids_padded_length = []
# need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self)
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = []
input_ids = []
@@ -1021,24 +1041,32 @@ class FlashCausalLMBatch(Batch):
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
else:
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
input_ids_padded_length.append(extra_pad)
input_ids_padded_length.extend([extra_pad] * len(self))
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
self.input_ids = F.pad(
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor, (0, extra_pad_bs), value=0
)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1)
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor, (0, extra_pad_bs), value=0
)
sliding_window = get_sliding_windows()
@@ -1171,7 +1199,7 @@ class FlashCausalLMBatch(Batch):
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64,
dtype=torch.int32,
)
)
prefill_next_token_indices.append(
@@ -1182,7 +1210,7 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int64,
dtype=torch.int32,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -1204,12 +1232,15 @@ class FlashCausalLMBatch(Batch):
slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.position_ids = position_ids
self.position_ids = F.pad(
self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1
)
self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
self.prefill_cache_indices[prefill_cache_indices.to(device)] = True
self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs:
prefill_head_indices = None
@@ -1218,16 +1249,19 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
prefill_head_indices = torch.cat(prefill_head_indices)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
prefill_next_token_indices, dtype=torch.int64
)
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
torch.tensor(input_ids_padded_length, dtype=torch.int32),
dim=-1,
).to(torch.int32)
input_ids_padded_length_tensor = F.pad(
input_ids_padded_length_tensor, (0, extra_pad_bs), value=0
)
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
@@ -1239,19 +1273,37 @@ class FlashCausalLMBatch(Batch):
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
self.all_input_ids_tensor = F.pad(
self.all_input_ids_tensor,
(0, 0, 0, extra_pad_bs),
value=0,
)
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
# update past grammar states
fsm_grammar_states = [0] * max_padded_bs
for i, req in enumerate(self.requests):
fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
self.next_token_chooser.dtype,
self.next_token_chooser.device,
self.next_token_chooser.tokenizer,
fsm_grammar_states,
)
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
else:
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
adapter_segment_indices = [len(adapter_indices) - 1]
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
@@ -1392,6 +1444,9 @@ class FlashCausalLM(Model):
self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
)
self.limit_hpu_graphs = (
os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
)
super().__init__(
model_id=model_id,
model=model,
@@ -1509,8 +1564,17 @@ class FlashCausalLM(Model):
self.device,
)
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
max_total_blocks = (
math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
)
os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
self.bucketing_ctx = HPUBucketingContext(
os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO
max_num_seqs,
os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO
BLOCK_SIZE,
num_blocks * BLOCK_SIZE,
@@ -1536,6 +1600,7 @@ class FlashCausalLM(Model):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
@@ -1552,62 +1617,51 @@ class FlashCausalLM(Model):
def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
):
input_ids = torch.zeros(
prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(batch_size)
position_ids = torch.arange(
prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(batch_size)
input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
batch_size
)
position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
batch_size
)
max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).reshape(batch_size, -1)
block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
slot_acc = []
for i in range(batch_size):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
input_lengths = (
torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
)
cache_lengths_tensor = torch.zeros(
batch_size, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
input_lengths = torch.ones(batch_size, dtype=torch.int32) * prompt_len
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
lm_head_indices = input_lengths - 1
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=self.kv_cache,
slots=slots,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=lm_head_indices,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
adapter_data=None,
hpu_attention_meta=None,
**kwargs,
)
def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
@@ -1622,19 +1676,12 @@ class FlashCausalLM(Model):
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
input_lengths = torch.ones(batch_size, dtype=torch.int32)
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
@@ -1646,18 +1693,22 @@ class FlashCausalLM(Model):
batch_size,
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots_tensor,
slots=_async_h2d_tensor_copy(slots_tensor),
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
**kwargs,
)
def forward(
@@ -1699,9 +1750,6 @@ class FlashCausalLM(Model):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
block_tables = (
@@ -1722,7 +1770,6 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
@@ -1735,80 +1782,34 @@ class FlashCausalLM(Model):
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
slots_pad = torch.zeros_like(input_ids)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
position_ids = F.pad(
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = batch.prefilling
kwargs["bypass_hpu_graphs"] = (
batch.prefilling if self.limit_hpu_graphs else False
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
slots=slots,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=lm_head_indices,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
# TODO not support adapter now, need the add in the future
adapter_data=None,
hpu_attention_meta=batch.hpu_attn_meta,
**kwargs,
)
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
return logits, speculative_logits
@tracer.start_as_current_span("generate_token")
def generate_token(
@@ -1817,6 +1818,7 @@ class FlashCausalLM(Model):
# In order to pipeline any actions on CPU we perform the operation in 3 main stages:
# Stage 1. Collect next token ids of any previously started generations
start = time.time_ns()
prev_batches = []
requests_to_generate = []
for batch_id, batch in enumerate(batches):
@@ -1834,7 +1836,9 @@ class FlashCausalLM(Model):
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_current_length],
_async_h2d_tensor_copy(
batch.all_input_ids_tensor[:, : batch.max_current_length]
),
batch.next_token_logits,
speculate,
batch.speculative_ids,
@@ -1843,10 +1847,39 @@ class FlashCausalLM(Model):
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
_async_h2d_tensor_copy(batch.top_n_tokens_tensor),
logprobs,
accepted_ids,
)
if batch.valid_indices is not None:
next_input_ids = next_input_ids.cpu()
next_token_logprobs = next_token_logprobs.cpu()
accepted_ids = accepted_ids.cpu()
batch.all_input_ids_tensor = batch.all_input_ids_tensor[
batch.valid_indices
]
next_input_ids = next_input_ids[batch.valid_indices]
next_token_logprobs = next_token_logprobs[batch.valid_indices]
accepted_ids = accepted_ids[batch.valid_indices]
if speculative_ids is not None:
speculative_ids = speculative_ids[batch.valid_indices]
batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
batch.valid_indices
]
top_n_tokens = []
batch_top_token_ids_v = []
batch_top_token_logprobs_v = []
for i in batch.valid_indices:
top_n_tokens.append(batch.top_n_tokens[i])
batch_top_token_ids_v.append(batch_top_token_ids[i])
batch_top_token_logprobs_v.append(batch_top_token_logprobs[i])
batch_top_token_ids = batch_top_token_ids_v
batch_top_token_logprobs = batch_top_token_logprobs_v
batch.top_n_tokens = top_n_tokens
batch.next_token_chooser = batch.next_token_chooser.filter(
batch.valid_indices
)
batch.valid_indices = None
# Since we are done prefilling, all the tensors that were concatenating values for all the requests
# instantly become of shape [BATCH_SIZE]
@@ -1860,14 +1893,16 @@ class FlashCausalLM(Model):
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
batch.adapter_meta.adapter_indices = (
batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch
# Cumulative length
accepted_ids = accepted_ids.cpu()
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
next_input_ids = next_input_ids.cpu()
if batch.speculative_logits is not None:
for i in range(len(batch)):
batch.all_input_ids_tensor[
@@ -1879,16 +1914,16 @@ class FlashCausalLM(Model):
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = index.to(batch.all_input_ids_tensor)
batch_idx = torch.arange(
0,
batch.all_input_ids_tensor.shape[0],
dtype=torch.long,
device=batch.input_lengths_tensor.device,
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2:
@@ -1900,7 +1935,7 @@ class FlashCausalLM(Model):
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids
batch.slot_indices += accepted_ids[: len(batch)]
# Does a HPU <-> CPU sync internally
if prefill:
@@ -1921,8 +1956,6 @@ class FlashCausalLM(Model):
}
)
idx = len(prev_batches) - 1
if batch.speculative_logits is not None:
accepted_ids_cpu = accepted_ids.cpu()
for req_idx, req in enumerate(batch.requests):
new_input_length = 1
@@ -1930,7 +1963,7 @@ class FlashCausalLM(Model):
new_cache_length = (
batch.cache_lengths[req_idx]
+ batch.input_lengths[req_idx]
+ accepted_ids_cpu[req_idx]
+ accepted_ids[req_idx]
- 1
)
else:
@@ -1978,15 +2011,17 @@ class FlashCausalLM(Model):
batch = self.batch_type.concatenate(batches)
else:
batch = batches[0]
start = time.time_ns()
prefill = batch.prefilling
if prefill:
if self.bucketing_ctx is not None:
batch.prepare_for_prefill(
self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length)
self.bucketing_ctx.get_padded_prompt_seq_len(
batch.max_input_length
),
self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
)
else:
batch.prepare_for_prefill(batch.max_input_length)
batch.prepare_for_prefill(batch.max_input_length, len(batch))
else:
batch.prepare_for_decode(
self.dtype, self.use_contiguous_pa, self.bucketing_ctx
@@ -2037,14 +2072,15 @@ class FlashCausalLM(Model):
batch.speculative_logits = speculative_logits
# HPU->CPU sync
htorch.core.mark_step()
start_decode = time.time_ns()
for prev_batch in prev_batches:
prev_batch["next_token_logprobs"] = prev_batch[
"next_token_logprobs"
].tolist()
prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
start_decode = time.time_ns()
htorch.core.mark_step()
# Stage 3. Finish and return previous generations
# Results
generations: List[Generation] = []
@@ -2186,7 +2222,7 @@ class FlashCausalLM(Model):
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
htorch.core.mark_step()
if stopped:
# No need to return a batch if we know that all requests stopped
forward_ns = start_decode - start
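The "same shape" invariant also covers sampling state: prepare_for_prefill and prepare_for_decode rebuild the HeterogeneousNextTokenChooser from parameter lists padded to the bucketed batch size, so the chooser's tensors line up with the padded logits. A simplified sketch of the padding step, assuming the helper fills padded slots with neutral greedy parameters (stand-in objects; the real code goes through pad_next_token_chooser_parameters and HeterogeneousNextTokenChooser.from_pb):

def pad_parameters(params, target_bs, neutral):
    # fill the padded slots with neutral (greedy, no-penalty) parameters so
    # they can never influence live requests
    params.extend([neutral] * (target_bs - len(params)))
    return params

params = ["p0", "p1", "p2"]     # stand-ins for generate_pb2 parameters
fsm_grammar_states = [4, 1, 7]  # live grammar states
padded_bs = 8                   # bucketed batch size
params = pad_parameters(params, padded_bs, neutral="greedy-defaults")
fsm_grammar_states += [0] * (padded_bs - len(fsm_grammar_states))  # fresh state for pads
assert len(params) == len(fsm_grammar_states) == padded_bs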


@@ -17,12 +17,15 @@ from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
import habana_frameworks.torch as htorch
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
tracer = trace.get_tracer(__name__)
@@ -383,12 +386,8 @@ class FlashVlmCausalLM(FlashCausalLM):
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
):
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
if batch.position_ids is not None and batch.position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = position_ids.unsqueeze(-1).repeat(
@@ -408,19 +407,10 @@ class FlashVlmCausalLM(FlashCausalLM):
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
input_lengths = torch.ones(batch_size, dtype=torch.int32)
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
@@ -432,14 +422,14 @@ class FlashVlmCausalLM(FlashCausalLM):
batch_size,
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots_tensor,
slots=_async_h2d_tensor_copy(slots_tensor),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
@@ -498,9 +488,6 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
block_tables = (
@@ -521,7 +508,6 @@ class FlashVlmCausalLM(FlashCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
@@ -546,78 +532,23 @@ class FlashVlmCausalLM(FlashCausalLM):
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
slots_pad = torch.zeros_like(input_ids)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = F.pad(
position_ids,
(0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
value=1,
)
else:
position_ids = F.pad(
position_ids,
(0, (padded_bs - orig_bs) * input_seq.shape[-1]),
value=1,
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
slots=slots,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=lm_head_indices,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
@@ -632,6 +563,4 @@ class FlashVlmCausalLM(FlashCausalLM):
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
return logits, speculative_logits


@@ -19,7 +19,11 @@ from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
import habana_frameworks.torch as htorch
from loguru import logger
from text_generation_server.models.globals import BLOCK_SIZE
@@ -183,7 +187,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
@@ -206,33 +210,26 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
def generate_cross_attention_states(
cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
):
if cross_attention_states is None:
return None, None, None
device = cross_attention_states.device
indices_list = []
if prefilling:
for i in image_indices:
indices_list.append(
torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
)
indices_list.append(torch.arange(pad_seq_len * i, pad_seq_len * (i + 1)))
indices = torch.cat(indices_list, dim=0)
else:
indices = image_indices[:]
return indices, seqlen.input_lengths.index_select(0, image_indices)
return indices, input_lengths.index_select(0, image_indices)
class FlashMllamaCausalLM(FlashVlmCausalLM):
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
@@ -247,19 +244,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
input_lengths = torch.ones(batch_size, dtype=torch.int32)
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
@@ -272,87 +260,86 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
bucketing_ctx=None,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
image_indices = torch.tensor(batch.image_indices, device=self.device)
image_indices = torch.tensor(batch.image_indices)
image_indices = image_indices.repeat(batch_size)
cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states, image_indices, seqlen, 1, False
cross_attention_states, image_indices, input_lengths, 1, False
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots_tensor,
slots=_async_h2d_tensor_copy(slots_tensor),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=indices,
cross_attention_len=cross_attention_len,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
)
def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
):
input_ids = torch.zeros(
prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(batch_size)
position_ids = torch.arange(
prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(batch_size)
input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
batch_size
)
position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
batch_size
)
max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).reshape(batch_size, -1)
block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
slot_acc = []
for i in range(batch_size):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
input_lengths = (
torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
)
cache_lengths_tensor = torch.zeros(
batch_size, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
input_lengths = (
torch.ones(
batch_size,
dtype=torch.int32,
)
* prompt_len
)
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
lm_head_indices = input_lengths - 1
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
image_indices = torch.tensor(batch.image_indices, device=self.device)
image_indices = torch.tensor(batch.image_indices)
image_indices = image_indices.repeat(batch_size)
cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states, image_indices, seqlen, prompt_len, True
cross_attention_states, image_indices, input_lengths, prompt_len, True
)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=self.kv_cache,
slots=slots,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=None,
lm_head_indices=lm_head_indices,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=indices,
cross_attention_len=cross_attention_len,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
@@ -410,9 +397,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
block_tables = (
@@ -433,7 +417,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
@@ -455,59 +438,22 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = batch.prefilling
kwargs["bypass_hpu_graphs"] = (
batch.prefilling if self.limit_hpu_graphs else False
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
image_indices = torch.tensor(batch.image_indices, device=self.device)
if padded_bs != input_lengths.shape[0]:
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
slots_pad = torch.zeros_like(input_ids)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
orig_bs = len(batch)
padded_bs = batch.input_lengths_tensor.shape[0]
padded_input_len = input_ids.view(padded_bs, -1).shape[-1]
image_indices = torch.tensor(batch.image_indices)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
)
position_ids = F.pad(
position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
)
slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
if cross_attention_states is not None:
cross_attention_states = F.pad(
cross_attention_states,
@@ -515,40 +461,35 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
value=0,
)
if len(image_indices) != 0:
pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
pad_indices = torch.arange(orig_bs, padded_bs)
image_indices = torch.cat((image_indices, pad_indices), dim=0)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states,
image_indices,
seqlen,
input_lengths,
padded_input_len,
batch.prefilling,
)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
slots=slots,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=lm_head_indices,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
# TODO list
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=indices,
cross_attention_len=cross_attention_len,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
if batch.pixel_values is not None:
batch.pixel_values = None
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
return logits, speculative_logits


@@ -552,6 +552,11 @@ def pad_next_token_chooser_parameters(
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
if device in ["hpu", torch.device("hpu")]:
import habana_frameworks.torch.hpu.random as htrandom
self.generator = htrandom.default_generators[0].manual_seed(seed)
else:
self.generator = torch.Generator("cpu")
self.generator.manual_seed(seed)
self.seed = seed
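The seeding bullet in isolation: on HPU the commit seeds one of the bridge's default generators instead of constructing a torch.Generator("cpu"). A small usage sketch of the Sampling class shown above (the multinomial call is illustrative, not taken from this file):

import torch

sampling = Sampling(seed=1234)  # CPU path, runnable anywhere
probs = torch.tensor([0.1, 0.2, 0.7])
next_id = torch.multinomial(probs, num_samples=1, generator=sampling.generator)

# On Gaudi, Sampling(seed, device="hpu") seeds
# habana_frameworks.torch.hpu.random.default_generators[0] instead, so runs
# with the same seed reproduce the same samples on device.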