mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
warmup prefill
remove model where pageattn is not used, set block table to None since it's not used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 69773767c5
commit fd70ad703e
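The gist of the change, as a minimal sketch that is not part of the diff: prefill is executed once for every (batch size, sequence length) bucket during warmup so the corresponding HPU graph shapes are compiled before serving. The wrapper function below is hypothetical; the buckets are the ones added in the diff.

# Illustrative sketch only -- warm up prefill over fixed shape buckets.
WARMUP_BATCH_SIZES = [1, 2, 4, 8]
WARMUP_SEQ_LENS = [32, 64, 128, 256, 512, 1024]

def warmup_all_prefill_shapes(model):
    # `model.warmup_prefill` is the method added later in this diff.
    for bs in WARMUP_BATCH_SIZES:
        for seqlen in WARMUP_SEQ_LENS:
            model.warmup_prefill(seqlen, bs)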
@@ -92,7 +92,6 @@ try:
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.flash_mllama import (
        FlashMllamaForConditionalGeneration,
@@ -144,7 +143,6 @@ except ImportError as e:

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IdeficsCausalLM)


class ModelType(enum.Enum):
@@ -301,12 +299,6 @@ class ModelType(enum.Enum):
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
@@ -733,15 +725,6 @@ def get_model(
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    elif model_type == IDEFICS:
        return IdeficsCausalLM(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == QWEN2_VL:
        return FlashVlmCausalLM(
            model_id=model_id,
@@ -69,6 +69,8 @@ from text_generation_server.utils.import_utils import (

import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.ops import batch2block, block2batch

tracer = trace.get_tracer(__name__)

@@ -86,6 +88,78 @@ def get_sliding_windows() -> int:
    return SLIDING_WINDOW


def prepare_for_decode(
    dtype, use_contiguous_pa, device, slot, block_tables, batch_size
):
    # Prepare values if we need to continue decoding
    # need for HPUPagedAttentionMetadata preparation
    def flatten(in_list):
        return list(itertools.chain(*in_list))

    def gather_list(input, indices, v):
        return [input[i] if i is not None else v for i in indices]

    def pad_list(input, k, v):
        input_len = len(input)
        target_len = (input_len + k - 1) // k * k
        padding = target_len - input_len
        return input + [v] * padding

    last_block_usage = slot % BLOCK_SIZE + 1
    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
    block_usage = [
        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
        for bt, lbu in zip(block_tables, last_block_usage)
        if bt
    ]

    block_list = flatten(block_tables)
    block_groups = flatten(block_groups)
    block_usage = flatten(block_usage)

    assert len(block_list) == len(block_groups)
    assert len(block_list) == len(block_usage)
    if use_contiguous_pa:
        block_bucket_size = max(max(block_list) + 1, len(block_list))
        # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
        # block_bucket_size)
        indices: List[Any]
        indices = [None] * block_bucket_size
        for i, bid in enumerate(block_list):
            indices[bid] = i
        block_list = gather_list(block_list, indices, 0)
        block_groups = gather_list(block_groups, indices, -1)
        block_usage = gather_list(block_usage, indices, 1)
    else:
        block_bucket_size = len(block_list)
        block_list = pad_list(block_list, block_bucket_size, 0)
        block_groups = pad_list(block_groups, block_bucket_size, -1)
        block_usage = pad_list(block_usage, block_bucket_size, 1)

    block_list = torch.tensor(block_list, dtype=torch.int, device=device)
    block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
    block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
    block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= block_usage.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    ones = torch.ones(
        (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
    )
    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
    block_scales = torch.reciprocal(torch.maximum(ones, sums))
    return trim_attn_metadata(
        HPUPagedAttentionMetadata(
            block_list=block_list,
            block_groups=block_groups,
            block_usage=block_usage,
            block_mapping=block_mapping.to(dtype),
            attn_bias=attn_bias,
            block_scales=block_scales,
        )
    )


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
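For readers unfamiliar with the layout this helper produces, here is a standalone sketch (not part of the commit) of the pure-Python bucketing helpers defined above; BLOCK_SIZE and the toy block tables are example values.

# Illustrative sketch of the contiguous paged-attention bucketing above.
import itertools

BLOCK_SIZE = 8

def flatten(in_list):
    return list(itertools.chain(*in_list))

def gather_list(input, indices, v):
    return [input[i] if i is not None else v for i in indices]

# Two requests: request 0 owns blocks [0, 3], request 1 owns block [5].
block_tables = [[0, 3], [5]]
last_block_usage = [4, 7]  # tokens used in each request's last block

block_list = flatten(block_tables)                                   # [0, 3, 5]
block_groups = flatten(
    [[i] * len(bt) for i, bt in enumerate(block_tables)]
)                                                                    # [0, 0, 1]
block_usage = flatten(
    [[BLOCK_SIZE] * (len(bt) - 1) + [lbu]
     for bt, lbu in zip(block_tables, last_block_usage)]
)                                                                    # [8, 4, 7]

# Contiguous layout: the entry for block id `bid` is stored at position `bid`;
# unused positions get neutral fillers (0 / -1 / 1).  The non-contiguous path
# instead pads the flat lists to a bucketed length with pad_list.
bucket = max(max(block_list) + 1, len(block_list))                   # 6
indices = [None] * bucket
for i, bid in enumerate(block_list):
    indices[bid] = i
print(gather_list(block_list, indices, 0))     # [0, 0, 0, 3, 0, 5]
print(gather_list(block_groups, indices, -1))  # [0, -1, -1, 0, -1, 1]
print(gather_list(block_usage, indices, 1))    # [8, 1, 1, 4, 1, 7]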
@@ -879,83 +953,18 @@ class FlashCausalLMBatch(Batch):
        )

    def prepare_for_decode(self, dtype, use_contiguous_pa):
        # Prepare values if we need to continue decoding
        # need for HPUPagedAttentionMetadata preparation
        import itertools
        from vllm_hpu_extension.ops import batch2block, block2batch

        def flatten(in_list):
            return list(itertools.chain(*in_list))

        def gather_list(input, indices, v):
            return [input[i] if i is not None else v for i in indices]

        def pad_list(input, k, v):
            input_len = len(input)
            target_len = (input_len + k - 1) // k * k
            padding = target_len - input_len
            return input + [v] * padding

        device = self.block_tables_tensor.device
        last_block_usage = self.slots[self.slot_indices] % BLOCK_SIZE + 1
        block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
        block_tables = []
        for i, bt in enumerate(self.block_tables):
            block_tables.append(bt[0 : block_num[i]])
        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
        block_usage = [
            [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
            for bt, lbu in zip(block_tables, last_block_usage)
            if bt
        ]

        block_list = flatten(block_tables)
        block_groups = flatten(block_groups)
        block_usage = flatten(block_usage)
        batch = self.input_ids.size(0)

        assert len(block_list) == len(block_groups)
        assert len(block_list) == len(block_usage)
        if use_contiguous_pa:
            block_bucket_size = max(max(block_list) + 1, len(block_list))
            # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
            # block_bucket_size)
            indices: List[Any]
            indices = [None] * block_bucket_size
            for i, bid in enumerate(block_list):
                indices[bid] = i
            block_list = gather_list(block_list, indices, 0)
            block_groups = gather_list(block_groups, indices, -1)
            block_usage = gather_list(block_usage, indices, 1)
        else:
            block_bucket_size = len(block_list)
            block_list = pad_list(block_list, block_bucket_size, 0)
            block_groups = pad_list(block_groups, block_bucket_size, -1)
            block_usage = pad_list(block_usage, block_bucket_size, 1)

        block_list = torch.tensor(block_list, dtype=torch.int, device=device)
        block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
        block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
        block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch)
        mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(
            0
        )
        mask = mask >= block_usage.unsqueeze(-1)
        attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
        ones = torch.ones(
            (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
        )
        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
        block_scales = torch.reciprocal(torch.maximum(ones, sums))
        self.hpu_attn_meta = trim_attn_metadata(
            HPUPagedAttentionMetadata(
                block_list=block_list,
                block_groups=block_groups,
                block_usage=block_usage,
                block_mapping=block_mapping.to(dtype),
                attn_bias=attn_bias,
                block_scales=block_scales,
            )
        self.hpu_attn_meta = prepare_for_decode(
            dtype,
            use_contiguous_pa,
            self.block_tables_tensor.device,
            self.slots[self.slot_indices],
            block_tables,
            self.input_ids.size(0),
        )

    def prepare_for_prefill(self):
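Also illustrative and not part of the commit: the slimmed-down batch method above now only derives the in-use prefix of each request's block table (from its cache length) before delegating to the module-level prepare_for_decode. A toy version of that step, with made-up values:

# Illustrative sketch only; BLOCK_SIZE matches the constant used above.
BLOCK_SIZE = 8
cache_lengths = [10, 3]             # tokens already cached per request
allocated = [[0, 3, 7, 9], [5, 6]]  # allocated block ids per request

block_tables = []
for i, bt in enumerate(allocated):
    block_num = cache_lengths[i] // BLOCK_SIZE + 1  # blocks actually in use
    block_tables.append(bt[:block_num])

print(block_tables)  # [[0, 3], [5]]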
@@ -1481,32 +1490,44 @@ class FlashCausalLM(Model):
            self.kv_cache_dtype,
            self.device,
        )
        for bs in [1, 2, 4, 8]:
            for seqlen in [32, 64, 128, 256, 512, 1024]:
                self.warmup_prefill(seqlen, bs)

        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

    def tunableop_warmup(self, seqlen: int, max_bt: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        cache_lengths_tensor = torch.zeros(
            seqlen, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.tensor(
            [0, seqlen], device=self.device, dtype=torch.int32
        )

    def warmup_prefill(self, prompt_len: int, bs: int):
        input_ids = torch.zeros(
            prompt_len, dtype=torch.int64, device=self.device
        ).repeat(bs)
        position_ids = torch.arange(
            prompt_len, dtype=torch.int32, device=self.device
        ).repeat(bs)
        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
        block_tables = torch.arange(
            max_bt, dtype=torch.int32, device=self.device
        ).repeat(seqlen)
        block_tables = block_tables.reshape((seqlen, max_bt))
        ).reshape(bs, -1)
        slot_acc = []
        for i in range(bs):
            slots = []
            for b in block_tables[i]:
                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
            slot_acc.extend(slots[:prompt_len])
        slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)

        input_lengths = (
            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
        )
        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
        )
        lm_head_indices = input_lengths - 1

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
@@ -1514,11 +1535,13 @@ class FlashCausalLM(Model):
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            seqlen=seqlen,
            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
            slots=slots,
            lm_head_indices=None,
            seqlen=trim_seqlen_metadata(seqlen),
            prefill_cache_indices=None,
            lm_head_indices=lm_head_indices,
            adapter_data=None,
            hpu_attention_meta=None,
        )

    def forward(
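An illustrative sketch, not part of the commit, of the slot indices warmup_prefill derives from its dummy block tables; BLOCK_SIZE, prompt_len and bs are example values.

# Illustrative sketch of the slot layout built during prefill warmup.
BLOCK_SIZE = 8
prompt_len, bs = 12, 2
blocks_per_seq = prompt_len // BLOCK_SIZE + 1          # 2
block_tables = [
    list(range(i * blocks_per_seq, (i + 1) * blocks_per_seq)) for i in range(bs)
]                                                      # [[0, 1], [2, 3]]

slot_acc = []
for i in range(bs):
    slots = []
    for b in block_tables[i]:
        slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
    slot_acc.extend(slots[:prompt_len])                # keep prompt_len slots per sequence

print(slot_acc)  # 0..11 for sequence 0, then 16..27 for sequence 1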
@@ -1606,7 +1629,7 @@ class FlashCausalLM(Model):
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            block_tables=None,
            slots=slots,
            seqlen=trim_seqlen_metadata(seqlen),
            prefill_cache_indices=batch.prefill_cache_indices,
@@ -1637,9 +1660,7 @@ class FlashCausalLM(Model):
            batch.prepare_for_prefill()
        else:
            batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)

        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
@@ -462,7 +462,7 @@ class FlashVlmCausalLM(FlashCausalLM):
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
            slots=slots,
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=batch.hpu_attn_meta,
@@ -288,7 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
            slots=slots,
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=batch.hpu_attn_meta,
@@ -33,13 +33,11 @@ try:
    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        FlashVlmCausalLMBatch,
        IdeficsCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):