Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-21 17:52:09 +00:00)
forward and tokenize chooser use the same shape (#3196)
* forward and tokenize chooser use the same shape

  Concatenation and filtering are done on CPU tensors to avoid dynamic shapes on HPU.

  Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* use hpu set seed

  Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 51a0b9d11c
commit 533eee50dc
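Before the diff itself, here is a minimal sketch of the pattern the commit message describes: batch bookkeeping tensors are built, filtered, concatenated, and padded as CPU tensors at bucket-padded (static) shapes, and only then moved to the HPU with a single non-blocking host-to-device copy. The `_async_h2d_tensor_copy` helper is taken from the diff below; `pad_to_bucket` and the example values are illustrative assumptions rather than code from the repository, and actually running the copy to `"hpu"` requires a Gaudi device with habana_frameworks installed.

import torch
import torch.nn.functional as F


def _async_h2d_tensor_copy(source, device="hpu"):
    # Helper added by this commit: non-blocking host-to-device copy.
    if source is None:
        return None
    assert source.device.type == "cpu", "Source tensor is not present in host memory!"
    target = torch.empty(source.shape, dtype=source.dtype, device=device)
    target.copy_(source, non_blocking=True)
    return target


def pad_to_bucket(lengths, padded_bs):
    # Illustrative helper (not from the repo): pad a per-request CPU tensor up to
    # the bucketed batch size so the shape seen by the HPU graph never changes.
    extra_pad = padded_bs - lengths.shape[0]
    return F.pad(lengths, (0, extra_pad), value=0)


# Example: three live requests, bucketed to a batch size of 8.
input_lengths = torch.tensor([5, 9, 2], dtype=torch.int32)  # built and kept on CPU
padded = pad_to_bucket(input_lengths, padded_bs=8)          # still CPU, static shape
input_lengths_hpu = _async_h2d_tensor_copy(padded)          # one async H2D copy

In the diff, the same helper is applied to `input_ids`, `position_ids`, `slots`, and the `Seqlen` metadata right before `model.forward`, so the HPU only ever sees tensors whose shapes were fixed on the host.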
@@ -3,6 +3,7 @@ from .common import (
     HPUPagedAttentionMetadata,
     trim_attn_metadata,
     trim_seqlen_metadata,
+    _async_h2d_tensor_copy,
 )
 
 from .hpu import (
@@ -25,4 +26,5 @@ __all__ = [
     "HPUPagedAttentionMetadata",
     "trim_seqlen_metadata",
     "trim_attn_metadata",
+    "_async_h2d_tensor_copy",
 ]
@@ -75,42 +75,27 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
 @dataclass
 class Seqlen:
     input_lengths: torch.Tensor
-    cache_lengths: torch.Tensor
-    cu_seqlen_q: Optional[torch.Tensor]
-    cu_seqlen_k: Optional[torch.Tensor]
 
     def __init__(
         self,
         input_lengths,
-        cache_lengths,
-        cu_seqlen_q=None,
     ):
         self.input_lengths = input_lengths
-        self.cache_lengths = cache_lengths
-        device = self.input_lengths.device
-        shape = self.input_lengths.shape
-        if cu_seqlen_q is None:
-            cu_seqlen_q = torch.arange(
-                shape[0] + 1,
-                device=device,
-                dtype=torch.int32,
-            )
-        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-        # cuda graphs don't like this and this is necessary to clamp within mistral
-        # Although FA2 might not want the clamping
-        # cu_seqlen_k[0] = 0
-        total = self.input_lengths + self.cache_lengths
-        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-        self.cu_seqlen_q = cu_seqlen_q
-        self.cu_seqlen_k = cu_seqlen_k
 
     def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self
 
 
+def _async_h2d_tensor_copy(source, device="hpu"):
+    if source is None:
+        return None
+    assert source.device.type == "cpu", "Source tensor is not present in host memory!"
+    target = torch.empty(source.shape, dtype=source.dtype, device=device)
+    target.copy_(source, non_blocking=True)
+    return target
+
+
 def trim_seqlen_metadata(metadata: Seqlen) -> object:
     # NOTE(kzawora): To anyone working on this in the future:
     # Trimming metadata is required when using HPUGraphs.
@@ -137,9 +122,6 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object:
         "TrimmedSeqlen",
         [
             "input_lengths",
-            "cache_lengths",
-            "cu_seqlen_q",
-            "cu_seqlen_k",
         ],
     )
     return attention_metadata
@@ -36,6 +36,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    pad_next_token_chooser_parameters,
 )
 from text_generation_server.models.types import (
     Batch,
@@ -56,6 +57,7 @@ from text_generation_server.layers.attention import (
     HPUPagedAttentionMetadata,
     trim_attn_metadata,
     trim_seqlen_metadata,
+    _async_h2d_tensor_copy,
 )
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -141,18 +143,23 @@ def prepare_for_decode(
     block_groups = pad_list(block_groups, block_bucket_size, -1)
     block_usage = pad_list(block_usage, block_bucket_size, 1)
 
-    block_list = torch.tensor(block_list, dtype=torch.int, device=device)
-    block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
-    block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
-    block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
+    block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
+    block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
+    block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
+    block_list_device = _async_h2d_tensor_copy(block_list)
+    block_groups_device = _async_h2d_tensor_copy(block_groups)
+    block_usage_device = _async_h2d_tensor_copy(block_usage)
+    block_mapping = torch.nn.functional.one_hot(
+        block_groups_device, num_classes=batch_size
+    )
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
     mask = mask >= block_usage.unsqueeze(-1)
     attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
-            block_list=block_list,
-            block_groups=block_groups,
-            block_usage=block_usage,
+            block_list=block_list_device,
+            block_groups=block_groups_device,
+            block_usage=block_usage_device,
             block_mapping=block_mapping.to(dtype),
             attn_bias=attn_bias,
         )
@@ -248,6 +255,7 @@ class FlashCausalLMBatch(Batch):
 
     next_token_logits: Optional[torch.Tensor]
     speculative_logits: Optional[torch.Tensor]
+    valid_indices: Optional[List[int]]
 
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
@@ -417,32 +425,23 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
         # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
+        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
 
-        top_n_tokens_tensor = torch.tensor(
-            top_n_tokens, device=device, dtype=torch.int64
-        )
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
 
-        block_tables_ragged = torch.tensor(
-            block_tables_ragged, device=device, dtype=torch.int32
-        )
-        cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
+        block_tables_ragged = torch.tensor(block_tables_ragged, dtype=torch.int32)
+        cu_blocks = torch.tensor(cu_blocks, dtype=torch.int64)
         block_tables_tensor = torch.empty(
             (len(block_tables), max_blocks),
-            device=device,
             dtype=torch.int32,
         )
 
         for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
 
-        prompt_lengths_tensor = torch.tensor(
-            prompt_lengths, dtype=torch.int32, device=device
-        )
+        prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32)
 
-        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(slots, dtype=torch.int64)
         cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
 
         return cls(
@@ -488,6 +487,7 @@ class FlashCausalLMBatch(Batch):
             hpu_attn_meta=None,
             next_token_logits=None,
             speculative_logits=None,
+            valid_indices=None,
         )
 
     @classmethod
@@ -519,9 +519,7 @@ class FlashCausalLMBatch(Batch):
         indices = []
 
         # slots to keep after filtering
-        slot_filtering_indices = torch.zeros(
-            self.slots.shape[0], dtype=torch.bool, device=device
-        )
+        slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)
 
         # Create on CPU to only move to GPU once instead of at every copy
         slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@@ -544,7 +542,6 @@ class FlashCausalLMBatch(Batch):
         prefill_logprob_tokens = []
 
         stopping_criterias = []
-        top_n_tokens = []
         adapter_set = set()
 
         num_blocks = 0
@@ -582,7 +579,6 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
-            top_n_tokens.append(self.top_n_tokens[idx])
             prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
 
             ADAPTER_TO_INDEX = get_adapter_to_index()
@@ -614,19 +610,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))
             max_slots = max(max_slots, slot_length)
 
-        all_input_ids_tensor = self.all_input_ids_tensor[indices]
-        next_token_logits = self.next_token_logits[indices]
-        speculative_logits = (
-            self.speculative_logits[indices]
-            if self.speculative_logits is not None
-            else None
-        )
         block_tables_tensor = self.block_tables_tensor[indices]
-        next_token_chooser = self.next_token_chooser.filter(indices)
-        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
-        speculative_ids = (
-            self.speculative_ids[indices] if self.speculative_ids is not None else None
-        )
         prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
 
         cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
@@ -652,16 +636,14 @@ class FlashCausalLMBatch(Batch):
         slot_indices = slot_indices.to(device)
 
         adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(
-            adapter_segments, dtype=torch.int32, device=device
-        )
+        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
         adapter_meta = AdapterBatchMetadata(
             adapter_indices=adapter_indices,
             adapter_set=adapter_set,
             adapter_segments=adapter_segments,
             segment_indices=adapter_segment_indices,
         )
+        htorch.core.mark_step()
         return type(self)(
             batch_id=self.batch_id,
             requests=requests,
@@ -692,18 +674,19 @@ class FlashCausalLMBatch(Batch):
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_chooser=next_token_chooser,
+            all_input_ids_tensor=self.all_input_ids_tensor,
+            next_token_chooser=self.next_token_chooser,
             stopping_criterias=stopping_criterias,
-            top_n_tokens=top_n_tokens,
-            top_n_tokens_tensor=top_n_tokens_tensor,
+            top_n_tokens=self.top_n_tokens,
+            top_n_tokens_tensor=self.top_n_tokens_tensor,
             num_blocks=num_blocks,
             max_blocks=max_blocks,
-            speculative_ids=speculative_ids,
+            speculative_ids=self.speculative_ids,
             adapter_meta=adapter_meta,
             hpu_attn_meta=None,
-            next_token_logits=next_token_logits,
-            speculative_logits=speculative_logits,
+            valid_indices=indices,
+            next_token_logits=self.next_token_logits,
+            speculative_logits=self.speculative_logits,
         )
 
     @classmethod
@@ -820,6 +803,7 @@ class FlashCausalLMBatch(Batch):
 
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
+            valid_bsize = len(batch)
 
             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -829,16 +813,15 @@ class FlashCausalLMBatch(Batch):
                     requests_idx_mapping[k] = v + cumulative_batch_size
 
             start_index = cumulative_batch_size
-            end_index = cumulative_batch_size + len(batch)
+            end_index = cumulative_batch_size + valid_bsize
 
-            # Copy tensors (HPU)
             index = torch.tensor(
                 list(range(start_index, end_index)), device=batch.input_ids.device
             )
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor[:, :max_length]
+            ] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
 
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
@@ -847,19 +830,28 @@ class FlashCausalLMBatch(Batch):
 
             slots_start_index = cumulative_slots
             slots_end_index = cumulative_slots + len(batch.slots)
-            slots[slots_start_index:slots_end_index] = batch.slots
+            slot_index = torch.tensor(
+                list(range(slots_start_index, slots_end_index)),
+                device=batch.slots.device,
+            )
+
+            slots.index_copy_(0, slot_index, batch.slots)
             cu_slots[start_index + 1 : end_index + 1] = (
                 batch.cu_slots[1:] + cumulative_slots
             )
 
             if not prefilling:
-                input_ids.index_copy_(0, index, batch.input_ids)
-                position_ids.index_copy_(0, index, batch.position_ids)
+                input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
+                position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
                 )
-                input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor)
-                cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor)
+                input_lengths_tensor.index_copy_(
+                    0, index, batch.input_lengths_tensor[:valid_bsize]
+                )
+                cache_lengths_tensor.index_copy_(
+                    0, index, batch.cache_lengths_tensor[:valid_bsize]
+                )
                 adapter_start_index = cumulative_adapter_indices_size
                 adapter_end_index = (
                     cumulative_adapter_indices_size
@@ -967,6 +959,7 @@ class FlashCausalLMBatch(Batch):
             hpu_attn_meta=None,
             next_token_logits=None,
             speculative_logits=None,
+            valid_indices=None,
        )
 
     def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
@@ -982,27 +975,53 @@ class FlashCausalLMBatch(Batch):
             padded_bs = self.input_ids.shape[0]
         slots = self.slots[self.slot_indices]
         extra_pad = padded_bs - self.input_ids.shape[0]
-        if extra_pad != 0:
-            slots = F.pad(slots, (0, extra_pad), value=0)
-            block_tables.extend([[0]] * extra_pad)
 
         self.hpu_attn_meta = prepare_for_decode(
             dtype,
             use_contiguous_pa,
-            self.block_tables_tensor.device,
-            slots.cpu(),
+            "hpu",
+            slots,
             block_tables,
             padded_bs,
             bucketing_ctx,
         )
+        self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
+        self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
+        self.input_lengths_tensor = F.pad(
+            self.input_lengths_tensor, (0, extra_pad), value=0
+        )
+        self.cache_lengths_tensor = F.pad(
+            self.cache_lengths_tensor, (0, extra_pad), value=0
+        )
+        self.all_input_ids_tensor = F.pad(
+            self.all_input_ids_tensor,
+            (0, 0, 0, extra_pad),
+            value=0,
+        )
+        next_token_chooser_parameters = []
+        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+        pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
+        # update past grammar states
+        fsm_grammar_states = [0] * padded_bs
 
-    def prepare_for_prefill(self, max_padded_input_len):
+        for i, req in enumerate(self.requests):
+            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            next_token_chooser_parameters,
+            self.next_token_chooser.dtype,
+            self.next_token_chooser.device,
+            self.next_token_chooser.tokenizer,
+            fsm_grammar_states,
+        )
+
+    def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
         # it simplifies everything
         assert self.speculative_ids is None
 
-        device = self.block_tables_tensor.device
+        # device = self.block_tables_tensor.device
 
         # hpu does not support varlen for prefill, use sdpa instead. so need to pad input_tensor, position
         # padding to left to work with sliding window
@@ -1011,6 +1030,7 @@ class FlashCausalLMBatch(Batch):
         input_ids_padded_length = []
         # need extra pad to match warmup seq
         extra_pad = max_padded_input_len - self.max_input_length
+        extra_pad_bs = max_padded_bs - len(self)
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1021,24 +1041,32 @@ class FlashCausalLMBatch(Batch):
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
             input_ids = [0] * extra_pad + input_ids
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
         else:
             self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
-            input_ids_padded_length.append(extra_pad)
+            input_ids_padded_length.extend([extra_pad] * len(self))
 
-        self.input_lengths_tensor = torch.tensor(
-            self.input_lengths, dtype=torch.int32, device=device
+        self.input_ids = F.pad(
+            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
         )
-        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
+        self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
+
+        self.input_lengths_tensor = F.pad(
+            self.input_lengths_tensor, (0, extra_pad_bs), value=0
+        )
+
+        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1)
         torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
         self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
-        self.cache_lengths_tensor = torch.tensor(
-            self.cache_lengths, dtype=torch.int32, device=device
+        self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32)
+        self.cache_lengths_tensor = F.pad(
+            self.cache_lengths_tensor, (0, extra_pad_bs), value=0
         )
 
         sliding_window = get_sliding_windows()
@@ -1171,7 +1199,7 @@ class FlashCausalLMBatch(Batch):
                     torch.arange(
                         cumulative_length,
                         cumulative_length + input_length,
-                        dtype=torch.int64,
+                        dtype=torch.int32,
                     )
                 )
                 prefill_next_token_indices.append(
@@ -1182,7 +1210,7 @@ class FlashCausalLMBatch(Batch):
                 prefill_head_indices.append(
                     torch.tensor(
                         [cumulative_length + input_length - 1],
-                        dtype=torch.int64,
+                        dtype=torch.int32,
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -1204,12 +1232,15 @@ class FlashCausalLMBatch(Batch):
             slot_indices = slot_indices[0]
             prefill_cache_indices = prefill_cache_indices[0]
 
-        self.position_ids = position_ids.to(device)
-        self.slot_indices = slot_indices.to(device)
+        self.position_ids = position_ids
+        self.position_ids = F.pad(
+            self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1
+        )
+        self.slot_indices = slot_indices
 
         self.prefill_cu_outlens = prefill_cu_outlens
         self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
-        self.prefill_cache_indices[prefill_cache_indices.to(device)] = True
+        self.prefill_cache_indices[prefill_cache_indices] = True
 
         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -1218,16 +1249,19 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
             prefill_next_token_indices = None
         else:
-            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
+            prefill_head_indices = torch.cat(prefill_head_indices)
             prefill_next_token_indices = torch.tensor(
-                prefill_next_token_indices, dtype=torch.int64, device=device
+                prefill_next_token_indices, dtype=torch.int64
             )
 
         self.prefill_head_indices = prefill_head_indices
         self.prefill_next_token_indices = prefill_next_token_indices
         input_ids_padded_length_tensor = torch.cumsum(
-            torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
+            torch.tensor(input_ids_padded_length, dtype=torch.int32),
             dim=-1,
+        ).to(torch.int32)
+        input_ids_padded_length_tensor = F.pad(
+            input_ids_padded_length_tensor, (0, extra_pad_bs), value=0
         )
         if self.prefill_head_indices is not None:
             self.prefill_head_indices = (
@@ -1239,19 +1273,37 @@ class FlashCausalLMBatch(Batch):
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
 
+        self.all_input_ids_tensor = F.pad(
+            self.all_input_ids_tensor,
+            (0, 0, 0, extra_pad_bs),
+            value=0,
+        )
+        next_token_chooser_parameters = []
+        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+        pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
+        # update past grammar states
+        fsm_grammar_states = [0] * max_padded_bs
+
+        for i, req in enumerate(self.requests):
+            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            next_token_chooser_parameters,
+            self.next_token_chooser.dtype,
+            self.next_token_chooser.device,
+            self.next_token_chooser.tokenizer,
+            fsm_grammar_states,
+        )
+
         if adapter_set:
-            adapter_indices = torch.cat(adapter_indices_list).to(
-                dtype=torch.int64, device=device
-            )
+            adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
             adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
         else:
             adapter_indices = torch.zeros_like(self.input_ids)
             adapter_segments = [0, len(adapter_indices)]
             adapter_segment_indices = [len(adapter_indices) - 1]
 
-        adapter_segments = torch.tensor(
-            adapter_segments, dtype=torch.int32, device=device
-        )
+        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
         self.adapter_meta = AdapterBatchMetadata(
             adapter_indices=adapter_indices,
             adapter_set=adapter_set,
@@ -1392,6 +1444,9 @@ class FlashCausalLM(Model):
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
         )
+        self.limit_hpu_graphs = (
+            os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
+        )
         super().__init__(
             model_id=model_id,
             model=model,
@@ -1509,8 +1564,17 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
+        if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
+            os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
+        if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
+            max_total_blocks = (
+                math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
+            )
+            os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
+
         self.bucketing_ctx = HPUBucketingContext(
-            os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO
+            max_num_seqs,
             os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO
             BLOCK_SIZE,
             num_blocks * BLOCK_SIZE,
@@ -1536,6 +1600,7 @@ class FlashCausalLM(Model):
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
@@ -1552,62 +1617,51 @@ class FlashCausalLM(Model):
     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
     ):
-        input_ids = torch.zeros(
-            prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(batch_size)
-        position_ids = torch.arange(
-            prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(batch_size)
+        input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
+            batch_size
+        )
+        position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
+            batch_size
+        )
         max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
-        block_tables = torch.arange(
-            max_bt, dtype=torch.int32, device=self.device
-        ).reshape(batch_size, -1)
+        block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
         slot_acc = []
         for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
             slot_acc.extend(slots[:prompt_len])
-        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
+        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
 
-        input_lengths = (
-            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
-        )
-        cache_lengths_tensor = torch.zeros(
-            batch_size, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.zeros(
-            batch_size + 1, device=self.device, dtype=torch.int32
-        )
+        input_lengths = torch.ones(batch_size, dtype=torch.int32) * prompt_len
+        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
 
         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
         lm_head_indices = input_lengths - 1
+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
+            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=_async_h2d_tensor_copy(slots),
             seqlen=trim_seqlen_metadata(seqlen),
-            lm_head_indices=lm_head_indices,
+            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
             adapter_data=None,
             hpu_attention_meta=None,
+            **kwargs,
         )
 
     def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
-        input_ids = torch.zeros(
-            batch_size, dtype=batch.input_ids.dtype, device=self.device
-        )
-        position_ids = torch.arange(
-            batch_size, dtype=batch.position_ids.dtype, device=self.device
-        )
+        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
+        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
         blocks = [block_num // batch_size for _ in range(batch_size)]
         blocks[0] += block_num % batch_size
         past_len = []
@@ -1622,19 +1676,12 @@ class FlashCausalLM(Model):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = torch.tensor(
-            past_len, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.zeros(
-            batch_size + 1, device=self.device, dtype=torch.int32
-        )
+        input_lengths = torch.ones(batch_size, dtype=torch.int32)
+        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
 
         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
         )
 
         hpu_attention_meta = prepare_for_decode(
@@ -1646,18 +1693,22 @@ class FlashCausalLM(Model):
             batch_size,
             bucketing_ctx=None,
         )
-        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = False
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots_tensor,
+            slots=_async_h2d_tensor_copy(slots_tensor),
             seqlen=trim_seqlen_metadata(seqlen),
             lm_head_indices=None,
             adapter_data=None,
             hpu_attention_meta=hpu_attention_meta,
+            **kwargs,
         )
 
     def forward(
@@ -1699,9 +1750,6 @@ class FlashCausalLM(Model):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            cache_lengths_tensor = (
-                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
-            ).reshape(-1)
 
             # Add Copy the block tables for all members
             block_tables = (
@@ -1722,7 +1770,6 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            cache_lengths_tensor = batch.cache_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
@@ -1735,80 +1782,34 @@ class FlashCausalLM(Model):
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
-        if self.bucketing_ctx is not None:
-            if batch.prefilling:
-                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
-                    input_lengths.shape[0]
-                )
-            else:
-                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
-                    input_lengths.shape[0]
-                )
         else:
-            padded_bs = input_lengths.shape[0]
-        orig_bs = input_lengths.shape[0]
-        if padded_bs != input_lengths.shape[0]:
-            padded_input_lengths = F.pad(
-                input_lengths,
-                (0, padded_bs - orig_bs),
-                value=0,
-            )
-            padded_cache_lengths_tensor = F.pad(
-                cache_lengths_tensor,
-                (0, padded_bs - orig_bs),
-                value=0,
-            )
-            if cu_seqlen_prefill is not None:
-                cu_seqlen_prefill = torch.zeros(
-                    padded_bs + 1, device=self.device, dtype=torch.int32
-                )
-                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
-            seqlen = Seqlen(
-                input_lengths=padded_input_lengths,
-                cache_lengths=padded_cache_lengths_tensor,
-                cu_seqlen_q=cu_seqlen_prefill,
-            )
-            input_seq = input_ids.view(orig_bs, -1)
-            input_ids = F.pad(
-                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
-            )
-            position_ids = F.pad(
-                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
-            )
-            slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
-            )
-            if lm_head_indices is not None:
-                lm_head_indices = F.pad(
-                    lm_head_indices, (0, padded_bs - orig_bs), value=0
-                )
-        else:
-            seqlen = Seqlen(
-                input_lengths=input_lengths,
-                cache_lengths=cache_lengths_tensor,
-                cu_seqlen_q=cu_seqlen_prefill,
-            )
+            slots_pad = torch.zeros_like(input_ids)
+            slots_pad[: slots.shape[0]] = slots
+            slots = slots_pad
+        seqlen = Seqlen(
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
+        )
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = batch.prefilling
+            kwargs["bypass_hpu_graphs"] = (
+                batch.prefilling if self.limit_hpu_graphs else False
+            )
 
         logits, speculative_logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
+            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
-            slots=slots,
+            slots=_async_h2d_tensor_copy(slots),
            seqlen=trim_seqlen_metadata(seqlen),
-            lm_head_indices=lm_head_indices,
+            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
            # TODO not support adapter now, need the add in the future
             adapter_data=None,
             hpu_attention_meta=batch.hpu_attn_meta,
             **kwargs,
         )
-        return logits[:orig_bs], (
-            speculative_logits[:orig_bs] if speculative_logits is not None else None
-        )
+        return logits, speculative_logits
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -1817,6 +1818,7 @@ class FlashCausalLM(Model):
 
         # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
         # Stage 1. Collect next token ids of any previously started generations
+        start = time.time_ns()
         prev_batches = []
         requests_to_generate = []
         for batch_id, batch in enumerate(batches):
@@ -1834,7 +1836,9 @@ class FlashCausalLM(Model):
                     accepted_ids,
                     speculative_ids,
                 ) = batch.next_token_chooser(
-                    batch.all_input_ids_tensor[:, : batch.max_current_length],
+                    _async_h2d_tensor_copy(
+                        batch.all_input_ids_tensor[:, : batch.max_current_length]
+                    ),
                     batch.next_token_logits,
                     speculate,
                     batch.speculative_ids,
@@ -1843,10 +1847,39 @@ class FlashCausalLM(Model):
 
                 batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                     batch.top_n_tokens,
-                    batch.top_n_tokens_tensor,
+                    _async_h2d_tensor_copy(batch.top_n_tokens_tensor),
                     logprobs,
                     accepted_ids,
                 )
+                if batch.valid_indices is not None:
+                    next_input_ids = next_input_ids.cpu()
+                    next_token_logprobs = next_token_logprobs.cpu()
+                    accepted_ids = accepted_ids.cpu()
+                    batch.all_input_ids_tensor = batch.all_input_ids_tensor[
+                        batch.valid_indices
+                    ]
+                    next_input_ids = next_input_ids[batch.valid_indices]
+                    next_token_logprobs = next_token_logprobs[batch.valid_indices]
+                    accepted_ids = accepted_ids[batch.valid_indices]
+                    if speculative_ids is not None:
+                        speculative_ids = speculative_ids[batch.valid_indices]
+                    batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
+                        batch.valid_indices
+                    ]
+                    top_n_tokens = []
+                    batch_top_token_ids_v = []
+                    batch_top_token_logprobs_v = []
+                    for i in batch.valid_indices:
+                        top_n_tokens.append(batch.top_n_tokens[i])
+                        batch_top_token_ids_v.append(batch_top_token_ids[i])
+                        batch_top_token_logprobs_v.append(batch_top_token_logprobs[i])
+                    batch_top_token_ids = batch_top_token_ids_v
+                    batch_top_token_logprobs = batch_top_token_logprobs_v
+                    batch.top_n_tokens = top_n_tokens
+                    batch.next_token_chooser = batch.next_token_chooser.filter(
+                        batch.valid_indices
+                    )
+                    batch.valid_indices = None
+
                 # Since we are done prefilling, all the tensors that were concatenating values for all the requests
                 # instantly become of shape [BATCH_SIZE]
@@ -1860,14 +1893,16 @@ class FlashCausalLM(Model):
                 else:
                     batch.position_ids = batch.position_ids[indices]
 
-                batch.slot_indices = batch.slot_indices[indices]
+                batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
                 batch.adapter_meta.adapter_indices = (
                     batch.adapter_meta.adapter_indices[indices]
                 )
                 # For each member of the batch
                 # Cumulative length
+                accepted_ids = accepted_ids.cpu()
                 cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
                 torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+                next_input_ids = next_input_ids.cpu()
                 if batch.speculative_logits is not None:
                     for i in range(len(batch)):
                         batch.all_input_ids_tensor[
@@ -1879,16 +1914,16 @@ class FlashCausalLM(Model):
                         ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
                 else:
                     index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+                    index = index.to(batch.all_input_ids_tensor)
                     batch_idx = torch.arange(
                         0,
                         batch.all_input_ids_tensor.shape[0],
                         dtype=torch.long,
-                        device=batch.input_lengths_tensor.device,
+                        device=batch.all_input_ids_tensor.device,
                     )
                     batch.all_input_ids_tensor.index_put_(
                         (batch_idx, index.long()), next_input_ids
                     )
 
                 batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
                 batch.speculative_ids = speculative_ids
                 if batch.position_ids.dim() == 2:
@@ -1900,7 +1935,7 @@ class FlashCausalLM(Model):
                     batch.input_lengths_tensor + accepted_ids - 1
                 )
                 batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
-                batch.slot_indices += accepted_ids
+                batch.slot_indices += accepted_ids[: len(batch)]
 
                 # Does a HPU <-> CPU sync internally
                 if prefill:
@@ -1921,8 +1956,6 @@ class FlashCausalLM(Model):
                     }
                 )
                 idx = len(prev_batches) - 1
-                if batch.speculative_logits is not None:
-                    accepted_ids_cpu = accepted_ids.cpu()
 
                 for req_idx, req in enumerate(batch.requests):
                     new_input_length = 1
@@ -1930,7 +1963,7 @@ class FlashCausalLM(Model):
                         new_cache_length = (
                             batch.cache_lengths[req_idx]
                             + batch.input_lengths[req_idx]
-                            + accepted_ids_cpu[req_idx]
+                            + accepted_ids[req_idx]
                            - 1
                         )
                     else:
@@ -1978,15 +2011,17 @@ class FlashCausalLM(Model):
             batch = self.batch_type.concatenate(batches)
         else:
             batch = batches[0]
-        start = time.time_ns()
         prefill = batch.prefilling
         if prefill:
             if self.bucketing_ctx is not None:
                 batch.prepare_for_prefill(
-                    self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length)
+                    self.bucketing_ctx.get_padded_prompt_seq_len(
+                        batch.max_input_length
+                    ),
+                    self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                 )
             else:
-                batch.prepare_for_prefill(batch.max_input_length)
+                batch.prepare_for_prefill(batch.max_input_length, len(batch))
         else:
             batch.prepare_for_decode(
                 self.dtype, self.use_contiguous_pa, self.bucketing_ctx
@@ -2037,14 +2072,15 @@ class FlashCausalLM(Model):
         batch.speculative_logits = speculative_logits
 
         # HPU->CPU sync
+        htorch.core.mark_step()
+        start_decode = time.time_ns()
         for prev_batch in prev_batches:
             prev_batch["next_token_logprobs"] = prev_batch[
                 "next_token_logprobs"
             ].tolist()
             prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
             prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
-        start_decode = time.time_ns()
+        htorch.core.mark_step()
         # Stage 3. Finish and return previous generations
         # Results
         generations: List[Generation] = []
@@ -2186,7 +2222,7 @@ class FlashCausalLM(Model):
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
+        htorch.core.mark_step()
         if stopped:
             # No need to return a batch if we know that all requests stopped
             forward_ns = start_decode - start
@@ -17,12 +17,15 @@ from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
-from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+from text_generation_server.layers.attention import (
+    Seqlen,
+    trim_seqlen_metadata,
+    _async_h2d_tensor_copy,
+)
 import habana_frameworks.torch as htorch
 from text_generation_server.utils.import_utils import (
     synchronize,
 )
-import torch.nn.functional as F

 tracer = trace.get_tracer(__name__)

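With _async_h2d_tensor_copy imported, every index-like tensor is assembled on the CPU and only copied to the device right before the forward call. A hedged sketch of the resulting call pattern (shapes and values are made up):

    import torch
    from text_generation_server.layers.attention import (
        Seqlen,
        trim_seqlen_metadata,
        _async_h2d_tensor_copy,
    )

    input_lengths = torch.tensor([5, 7, 7, 7], dtype=torch.int32)  # host tensor
    seqlen = Seqlen(input_lengths=_async_h2d_tensor_copy(input_lengths))
    input_ids = _async_h2d_tensor_copy(torch.zeros(4 * 7, dtype=torch.int64))
    # model.forward(input_ids=input_ids, seqlen=trim_seqlen_metadata(seqlen), ...)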
@@ -383,12 +386,8 @@ class FlashVlmCausalLM(FlashCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
     ):
-        input_ids = torch.zeros(
-            batch_size, dtype=batch.input_ids.dtype, device=self.device
-        )
-        position_ids = torch.arange(
-            batch_size, dtype=batch.position_ids.dtype, device=self.device
-        )
+        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
+        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
         if batch.position_ids is not None and batch.position_ids.dim() == 2:
             # qwen2_vl and qwen2_5_vl case
             position_ids = position_ids.unsqueeze(-1).repeat(
@@ -408,19 +407,10 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = torch.tensor(
-            past_len, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.zeros(
-            batch_size + 1, device=self.device, dtype=torch.int32
-        )
-        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+        input_lengths = torch.ones(batch_size, dtype=torch.int32)

         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
         )

         hpu_attention_meta = prepare_for_decode(
@@ -432,14 +422,14 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch_size,
             bucketing_ctx=None,
         )
-        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots_tensor,
+            slots=_async_h2d_tensor_copy(slots_tensor),
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
@@ -498,9 +488,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            cache_lengths_tensor = (
-                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
-            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -521,7 +508,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            cache_lengths_tensor = batch.cache_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

@@ -546,78 +532,23 @@ class FlashVlmCausalLM(FlashCausalLM):
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
-        if self.bucketing_ctx is not None:
-            if batch.prefilling:
-                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
-                    input_lengths.shape[0]
-                )
-            else:
-                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
-                    input_lengths.shape[0]
-                )
         else:
-            padded_bs = input_lengths.shape[0]
-        orig_bs = input_lengths.shape[0]
-        if padded_bs != input_lengths.shape[0]:
-            padded_input_lengths = F.pad(
-                input_lengths,
-                (0, padded_bs - orig_bs),
-                value=0,
-            )
-            padded_cache_lengths_tensor = F.pad(
-                cache_lengths_tensor,
-                (0, padded_bs - orig_bs),
-                value=0,
-            )
-            if cu_seqlen_prefill is not None:
-                cu_seqlen_prefill = torch.zeros(
-                    padded_bs + 1, device=self.device, dtype=torch.int32
-                )
-                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
-            seqlen = Seqlen(
-                input_lengths=padded_input_lengths,
-                cache_lengths=padded_cache_lengths_tensor,
-                cu_seqlen_q=cu_seqlen_prefill,
-            )
-            input_seq = input_ids.view(orig_bs, -1)
-            input_ids = F.pad(
-                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
-            )
-            if position_ids.dim() == 2:
-                # qwen2_vl and qwen2_5_vl case
-                position_ids = F.pad(
-                    position_ids,
-                    (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
-                    value=1,
-                )
-            else:
-                position_ids = F.pad(
-                    position_ids,
-                    (0, (padded_bs - orig_bs) * input_seq.shape[-1]),
-                    value=1,
-                )
-            slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
-            )
-            if lm_head_indices is not None:
-                lm_head_indices = F.pad(
-                    lm_head_indices, (0, padded_bs - orig_bs), value=0
-                )
-        else:
-            seqlen = Seqlen(
-                input_lengths=input_lengths,
-                cache_lengths=cache_lengths_tensor,
-                cu_seqlen_q=cu_seqlen_prefill,
-            )
+            slots_pad = torch.zeros_like(input_ids)
+            slots_pad[: slots.shape[0]] = slots
+            slots = slots_pad
+        seqlen = Seqlen(
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
+        )
         logits, speculative_logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
+            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
-            slots=slots,
+            slots=_async_h2d_tensor_copy(slots),
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
-            lm_head_indices=lm_head_indices,
+            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
             pixel_values=batch.pixel_values,
             pixel_attention_mask=batch.pixel_attention_mask,
             image_sizes=batch.image_sizes,
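When there are no prefill cache indices, the slot mapping is right-padded with zeros on the host so its length always matches the already padded input_ids; only the padded tensor is then copied to the device. A standalone sketch of that else branch:

    import torch

    input_ids = torch.zeros(8, dtype=torch.int64)        # already padded to the bucket size
    slots = torch.tensor([3, 4, 5], dtype=torch.int64)   # one slot per real token
    slots_pad = torch.zeros_like(input_ids)
    slots_pad[: slots.shape[0]] = slots
    slots = slots_pad                                     # fixed length, ready for the H2D copy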
@@ -632,6 +563,4 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch.image_sizes = None
         if batch.image_grid_thw is not None:
             batch.image_grid_thw = None
-        return logits[:orig_bs], (
-            speculative_logits[:orig_bs] if speculative_logits is not None else None
-        )
+        return logits, speculative_logits
@@ -19,7 +19,11 @@ from text_generation_server.models.flash_vlm_causal_lm import (
     FlashVlmCausalLM,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+from text_generation_server.layers.attention import (
+    Seqlen,
+    trim_seqlen_metadata,
+    _async_h2d_tensor_copy,
+)
 import habana_frameworks.torch as htorch
 from loguru import logger
 from text_generation_server.models.globals import BLOCK_SIZE
@@ -183,7 +187,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
                 input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
             else:
                 input_ids = batch.input_ids[0]
-            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)

         batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)

@@ -206,33 +210,26 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):


 def generate_cross_attention_states(
-    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
+    cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
 ):
     if cross_attention_states is None:
         return None, None, None
-    device = cross_attention_states.device
     indices_list = []
     if prefilling:
         for i in image_indices:
-            indices_list.append(
-                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
-            )
+            indices_list.append(torch.arange(pad_seq_len * i, pad_seq_len * (i + 1)))
         indices = torch.cat(indices_list, dim=0)
     else:
         indices = image_indices[:]
-    return indices, seqlen.input_lengths.index_select(0, image_indices)
+    return indices, input_lengths.index_select(0, image_indices)


 class FlashMllamaCausalLM(FlashVlmCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
     ):
-        input_ids = torch.zeros(
-            batch_size, dtype=batch.input_ids.dtype, device=self.device
-        )
-        position_ids = torch.arange(
-            batch_size, dtype=batch.position_ids.dtype, device=self.device
-        )
+        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
+        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
         blocks = [block_num // batch_size for _ in range(batch_size)]
         blocks[0] += block_num % batch_size
         past_len = []
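generate_cross_attention_states now takes the raw host-side input_lengths tensor instead of a Seqlen, so the cross-attention bookkeeping also stays on the CPU. A usage sketch with made-up shapes (two images, prompts padded to 16 tokens):

    import torch
    from text_generation_server.models.flash_mllama_causal_lm import (
        generate_cross_attention_states,
    )

    cross_attention_states = torch.zeros(2, 4, 32)             # (images, tiles, hidden), illustrative
    image_indices = torch.tensor([0, 1])
    input_lengths = torch.tensor([16, 16], dtype=torch.int32)  # plain CPU tensor, not a Seqlen

    indices, cross_attention_len = generate_cross_attention_states(
        cross_attention_states, image_indices, input_lengths, 16, True
    )
    # indices covers the padded token span of each image; cross_attention_len is
    # input_lengths gathered at the image positions -- both still host tensors.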
@@ -247,19 +244,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = torch.tensor(
-            past_len, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.zeros(
-            batch_size + 1, device=self.device, dtype=torch.int32
-        )
-        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+        input_lengths = torch.ones(batch_size, dtype=torch.int32)

         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
         )

         hpu_attention_meta = prepare_for_decode(
@@ -272,87 +260,86 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             bucketing_ctx=None,
         )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = torch.tensor(batch.image_indices)
         image_indices = image_indices.repeat(batch_size)
         cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
         indices, cross_attention_len = generate_cross_attention_states(
-            cross_attention_states, image_indices, seqlen, 1, False
+            cross_attention_states, image_indices, input_lengths, 1, False
         )
-        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
         self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots_tensor,
+            slots=_async_h2d_tensor_copy(slots_tensor),
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
             adapter_data=None,
             cross_attention_states=cross_attention_states,
-            indices=indices,
-            cross_attention_len=cross_attention_len,
+            indices=_async_h2d_tensor_copy(indices),
+            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
         )

     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
     ):
-        input_ids = torch.zeros(
-            prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(batch_size)
-        position_ids = torch.arange(
-            prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(batch_size)
+        input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
+            batch_size
+        )
+        position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
+            batch_size
+        )
         max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
-        block_tables = torch.arange(
-            max_bt, dtype=torch.int32, device=self.device
-        ).reshape(batch_size, -1)
+        block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
         slot_acc = []
         for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
             slot_acc.extend(slots[:prompt_len])
-        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
+        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)

         input_lengths = (
-            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
-        )
-        cache_lengths_tensor = torch.zeros(
-            batch_size, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.zeros(
-            batch_size + 1, device=self.device, dtype=torch.int32
+            torch.ones(
+                batch_size,
+                dtype=torch.int32,
+            )
+            * prompt_len
         )
+        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
         lm_head_indices = input_lengths - 1

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = torch.tensor(batch.image_indices)
         image_indices = image_indices.repeat(batch_size)
         cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
         indices, cross_attention_len = generate_cross_attention_states(
-            cross_attention_states, image_indices, seqlen, prompt_len, True
+            cross_attention_states, image_indices, input_lengths, prompt_len, True
         )
+        seqlen = Seqlen(
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
+        )
+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
         self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
+            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=_async_h2d_tensor_copy(slots),
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=None,
-            lm_head_indices=lm_head_indices,
+            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
             adapter_data=None,
             cross_attention_states=cross_attention_states,
-            indices=indices,
-            cross_attention_len=cross_attention_len,
+            indices=_async_h2d_tensor_copy(indices),
+            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
+            **kwargs,
         )

     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
@@ -410,9 +397,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            cache_lengths_tensor = (
-                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
-            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -433,7 +417,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            cache_lengths_tensor = batch.cache_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

@@ -455,100 +438,58 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = batch.prefilling
+            kwargs["bypass_hpu_graphs"] = (
+                batch.prefilling if self.limit_hpu_graphs else False
+            )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
-
-        if self.bucketing_ctx is not None:
-            if batch.prefilling:
-                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
-                    input_lengths.shape[0]
-                )
-            else:
-                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
-                    input_lengths.shape[0]
-                )
         else:
-            padded_bs = input_lengths.shape[0]
-        orig_bs = input_lengths.shape[0]
-        padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
-        image_indices = torch.tensor(batch.image_indices, device=self.device)
-        if padded_bs != input_lengths.shape[0]:
-            padded_input_lengths = F.pad(
-                input_lengths,
-                (0, padded_bs - orig_bs),
+            slots_pad = torch.zeros_like(input_ids)
+            slots_pad[: slots.shape[0]] = slots
+            slots = slots_pad
+        orig_bs = len(batch)
+        padded_bs = batch.input_lengths_tensor.shape[0]
+        padded_input_len = input_ids.view(padded_bs, -1).shape[-1]
+        image_indices = torch.tensor(batch.image_indices)
+        if cross_attention_states is not None:
+            cross_attention_states = F.pad(
+                cross_attention_states,
+                (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
                 value=0,
             )
-            padded_cache_lengths_tensor = F.pad(
-                cache_lengths_tensor,
-                (0, padded_bs - orig_bs),
-                value=0,
-            )
-            if cu_seqlen_prefill is not None:
-                cu_seqlen_prefill = torch.zeros(
-                    padded_bs + 1, device=self.device, dtype=torch.int32
-                )
-                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
-            seqlen = Seqlen(
-                input_lengths=padded_input_lengths,
-                cache_lengths=padded_cache_lengths_tensor,
-                cu_seqlen_q=cu_seqlen_prefill,
-            )
-
-            input_ids = F.pad(
-                input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
-            )
-            position_ids = F.pad(
-                position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
-            )
-            slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
-            if lm_head_indices is not None:
-                lm_head_indices = F.pad(
-                    lm_head_indices, (0, padded_bs - orig_bs), value=0
-                )
-            if cross_attention_states is not None:
-                cross_attention_states = F.pad(
-                    cross_attention_states,
-                    (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
-                    value=0,
-                )
-            if len(image_indices) != 0:
-                pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
-                image_indices = torch.cat((image_indices, pad_indices), dim=0)
-        else:
-            seqlen = Seqlen(
-                input_lengths=input_lengths,
-                cache_lengths=cache_lengths_tensor,
-                cu_seqlen_q=cu_seqlen_prefill,
-            )
+        if len(image_indices) != 0:
+            pad_indices = torch.arange(orig_bs, padded_bs)
+            image_indices = torch.cat((image_indices, pad_indices), dim=0)

         indices, cross_attention_len = generate_cross_attention_states(
             cross_attention_states,
             image_indices,
-            seqlen,
+            input_lengths,
             padded_input_len,
             batch.prefilling,
         )
+        seqlen = Seqlen(
+            input_lengths=_async_h2d_tensor_copy(input_lengths),
+        )
         logits, speculative_logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
+            input_ids=_async_h2d_tensor_copy(input_ids),
+            position_ids=_async_h2d_tensor_copy(position_ids),
+            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
             kv_cache=kv_cache,
-            slots=slots,
+            slots=_async_h2d_tensor_copy(slots),
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
-            lm_head_indices=lm_head_indices,
+            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
             # TODO list
             adapter_data=None,
             cross_attention_states=cross_attention_states,
-            indices=indices,
-            cross_attention_len=cross_attention_len,
+            indices=_async_h2d_tensor_copy(indices),
+            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
             **kwargs,
         )
         if batch.pixel_values is not None:
             batch.pixel_values = None
-        return logits[:orig_bs], (
-            speculative_logits[:orig_bs] if speculative_logits is not None else None
-        )
+        return logits, speculative_logits
@@ -552,8 +552,13 @@ def pad_next_token_chooser_parameters(

 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator("cpu")
-        self.generator.manual_seed(seed)
+        if device in ["hpu", torch.device("hpu")]:
+            import habana_frameworks.torch.hpu.random as htrandom
+
+            self.generator = htrandom.default_generators[0].manual_seed(seed)
+        else:
+            self.generator = torch.Generator("cpu")
+            self.generator.manual_seed(seed)
         self.seed = seed

     def __call__(self, logits):
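On Gaudi the sampling seed now goes through the HPU's default generator, so seeded sampling runs on the same device as the forward pass; on any other device the behaviour is unchanged. A small usage sketch (the import path for Sampling is assumed from the surrounding file, and the multinomial draw in __call__ is elided here):

    from text_generation_server.utils.tokens import Sampling  # assumed module path

    hpu_sampling = Sampling(seed=1234, device="hpu")  # seeds htrandom.default_generators[0]
    cpu_sampling = Sampling(seed=1234)                # unchanged CPU torch.Generator path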