Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)

warmup prefill

Remove models where paged attention is not used, and set the block table to None since it is not used.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 69773767c5
commit fd70ad703e

@@ -92,7 +92,6 @@ try:
     from text_generation_server.models.custom_modeling.flash_phi_modeling import (
         FlashPhiForCausalLM,
     )
-    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
     from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
     from text_generation_server.models.custom_modeling.flash_mllama import (
         FlashMllamaForConditionalGeneration,
@@ -144,7 +143,6 @@ except ImportError as e:

 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
-    __all__.append(IdeficsCausalLM)


 class ModelType(enum.Enum):
@@ -301,12 +299,6 @@ class ModelType(enum.Enum):
         "name": "Gptj",
         "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
     }
-    IDEFICS = {
-        "type": "idefics",
-        "name": "Idefics",
-        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
-        "multimodal": True,
-    }
     MLLAMA = {
         "type": "mllama",
         "name": "Mllama",
@@ -733,15 +725,6 @@ def get_model(
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
-    elif model_type == IDEFICS:
-        return IdeficsCausalLM(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
     elif model_type == QWEN2_VL:
         return FlashVlmCausalLM(
             model_id=model_id,
@@ -69,6 +69,8 @@ from text_generation_server.utils.import_utils import (

 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
+import itertools
+from vllm_hpu_extension.ops import batch2block, block2batch

 tracer = trace.get_tracer(__name__)

@@ -86,6 +88,78 @@ def get_sliding_windows() -> int:
     return SLIDING_WINDOW


+def prepare_for_decode(
+    dtype, use_contiguous_pa, device, slot, block_tables, batch_size
+):
+    # Prepare values if we need to continue decoding
+    # need for HPUPagedAttentionMetadata preparation
+    def flatten(in_list):
+        return list(itertools.chain(*in_list))
+
+    def gather_list(input, indices, v):
+        return [input[i] if i is not None else v for i in indices]
+
+    def pad_list(input, k, v):
+        input_len = len(input)
+        target_len = (input_len + k - 1) // k * k
+        padding = target_len - input_len
+        return input + [v] * padding
+
+    last_block_usage = slot % BLOCK_SIZE + 1
+    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+    block_usage = [
+        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
+        for bt, lbu in zip(block_tables, last_block_usage)
+        if bt
+    ]
+
+    block_list = flatten(block_tables)
+    block_groups = flatten(block_groups)
+    block_usage = flatten(block_usage)
+
+    assert len(block_list) == len(block_groups)
+    assert len(block_list) == len(block_usage)
+    if use_contiguous_pa:
+        block_bucket_size = max(max(block_list) + 1, len(block_list))
+        # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
+        # block_bucket_size)
+        indices: List[Any]
+        indices = [None] * block_bucket_size
+        for i, bid in enumerate(block_list):
+            indices[bid] = i
+        block_list = gather_list(block_list, indices, 0)
+        block_groups = gather_list(block_groups, indices, -1)
+        block_usage = gather_list(block_usage, indices, 1)
+    else:
+        block_bucket_size = len(block_list)
+        block_list = pad_list(block_list, block_bucket_size, 0)
+        block_groups = pad_list(block_groups, block_bucket_size, -1)
+        block_usage = pad_list(block_usage, block_bucket_size, 1)
+
+    block_list = torch.tensor(block_list, dtype=torch.int, device=device)
+    block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
+    block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
+    block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
+    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+    mask = mask >= block_usage.unsqueeze(-1)
+    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+    ones = torch.ones(
+        (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
+    )
+    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+    block_scales = torch.reciprocal(torch.maximum(ones, sums))
+    return trim_attn_metadata(
+        HPUPagedAttentionMetadata(
+            block_list=block_list,
+            block_groups=block_groups,
+            block_usage=block_usage,
+            block_mapping=block_mapping.to(dtype),
+            attn_bias=attn_bias,
+            block_scales=block_scales,
+        )
+    )
+
+
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
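
A minimal standalone sketch of the block metadata layout that the new prepare_for_decode helper builds above. This is not part of the commit; the BLOCK_SIZE value, the toy block tables, and the bucket size of 4 are illustrative assumptions only.

import itertools

BLOCK_SIZE = 4  # illustrative only; the real value comes from the attention backend


def flatten(in_list):
    return list(itertools.chain(*in_list))


def pad_list(lst, k, v):
    # Pad lst up to the next multiple of k with the filler value v.
    target_len = (len(lst) + k - 1) // k * k
    return lst + [v] * (target_len - len(lst))


# Two sequences: sequence 0 owns KV-cache blocks [0, 1], sequence 1 owns block [2].
block_tables = [[0, 1], [2]]
# How full each sequence's last block is (slot % BLOCK_SIZE + 1 in the diff above).
last_block_usage = [3, 2]

block_list = flatten(block_tables)  # [0, 1, 2]
block_groups = flatten(
    [[i] * len(bt) for i, bt in enumerate(block_tables)]
)  # [0, 0, 1] -> which sequence each block belongs to
block_usage = flatten(
    [
        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
        for bt, lbu in zip(block_tables, last_block_usage)
    ]
)  # [4, 3, 2] -> how many slots of each block are valid

# Padding to a fixed bucket size (4 here, chosen only for illustration) keeps the
# metadata tensors at stable shapes, which is what lets HPU graphs be replayed.
print(pad_list(block_list, 4, 0))  # [0, 1, 2, 0]
print(pad_list(block_groups, 4, -1))  # [0, 0, 1, -1]
print(pad_list(block_usage, 4, 1))  # [4, 3, 2, 1]
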
@@ -879,83 +953,18 @@ class FlashCausalLMBatch(Batch):
         )

     def prepare_for_decode(self, dtype, use_contiguous_pa):
-        # Prepare values if we need to continue decoding
-        # need for HPUPagedAttentionMetadata preparation
-        import itertools
-        from vllm_hpu_extension.ops import batch2block, block2batch
-
-        def flatten(in_list):
-            return list(itertools.chain(*in_list))
-
-        def gather_list(input, indices, v):
-            return [input[i] if i is not None else v for i in indices]
-
-        def pad_list(input, k, v):
-            input_len = len(input)
-            target_len = (input_len + k - 1) // k * k
-            padding = target_len - input_len
-            return input + [v] * padding
-
-        device = self.block_tables_tensor.device
-        last_block_usage = self.slots[self.slot_indices] % BLOCK_SIZE + 1
         block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
-        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
-        block_usage = [
-            [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
-            for bt, lbu in zip(block_tables, last_block_usage)
-            if bt
-        ]
-
-        block_list = flatten(block_tables)
-        block_groups = flatten(block_groups)
-        block_usage = flatten(block_usage)
-        batch = self.input_ids.size(0)
-
-        assert len(block_list) == len(block_groups)
-        assert len(block_list) == len(block_usage)
-        if use_contiguous_pa:
-            block_bucket_size = max(max(block_list) + 1, len(block_list))
-            # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
-            # block_bucket_size)
-            indices: List[Any]
-            indices = [None] * block_bucket_size
-            for i, bid in enumerate(block_list):
-                indices[bid] = i
-            block_list = gather_list(block_list, indices, 0)
-            block_groups = gather_list(block_groups, indices, -1)
-            block_usage = gather_list(block_usage, indices, 1)
-        else:
-            block_bucket_size = len(block_list)
-            block_list = pad_list(block_list, block_bucket_size, 0)
-            block_groups = pad_list(block_groups, block_bucket_size, -1)
-            block_usage = pad_list(block_usage, block_bucket_size, 1)
-
-        block_list = torch.tensor(block_list, dtype=torch.int, device=device)
-        block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
-        block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
-        block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch)
-        mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(
-            0
-        )
-        mask = mask >= block_usage.unsqueeze(-1)
-        attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
-        ones = torch.ones(
-            (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
-        )
-        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
-        block_scales = torch.reciprocal(torch.maximum(ones, sums))
-        self.hpu_attn_meta = trim_attn_metadata(
-            HPUPagedAttentionMetadata(
-                block_list=block_list,
-                block_groups=block_groups,
-                block_usage=block_usage,
-                block_mapping=block_mapping.to(dtype),
-                attn_bias=attn_bias,
-                block_scales=block_scales,
-            )
+        self.hpu_attn_meta = prepare_for_decode(
+            dtype,
+            use_contiguous_pa,
+            self.block_tables_tensor.device,
+            self.slots[self.slot_indices],
+            block_tables,
+            self.input_ids.size(0),
         )

     def prepare_for_prefill(self):
@@ -1481,32 +1490,44 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        for bs in [1, 2, 4, 8]:
+            for seqlen in [32, 64, 128, 256, 512, 1024]:
+                self.warmup_prefill(seqlen, bs)

         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def tunableop_warmup(self, seqlen: int, max_bt: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = torch.zeros(
-            seqlen, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.tensor(
-            [0, seqlen], device=self.device, dtype=torch.int32
-        )
+    def warmup_prefill(self, prompt_len: int, bs: int):
+        input_ids = torch.zeros(
+            prompt_len, dtype=torch.int64, device=self.device
+        ).repeat(bs)
+        position_ids = torch.arange(
+            prompt_len, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs

         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).repeat(seqlen)
-        block_tables = block_tables.reshape((seqlen, max_bt))
+        ).reshape(bs, -1)
+        slot_acc = []
+        for i in range(bs):
+            slots = []
+            for b in block_tables[i]:
+                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+            slot_acc.extend(slots[:prompt_len])
+        slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+
+        input_lengths = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+        )
+        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
         )
+        lm_head_indices = input_lengths - 1

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
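
For reference, a minimal sketch of the dummy slot layout that warmup_prefill derives from its block tables above: each warmed-up sequence gets contiguous slots drawn from its own KV-cache blocks, truncated to the prompt length. This is not part of the commit; BLOCK_SIZE, prompt_len, and bs are illustrative values only.

BLOCK_SIZE = 4  # illustrative only
prompt_len, bs = 6, 2

blocks_per_seq = prompt_len // BLOCK_SIZE + 1  # 2 blocks per sequence
block_tables = [
    list(range(i * blocks_per_seq, (i + 1) * blocks_per_seq)) for i in range(bs)
]  # [[0, 1], [2, 3]]

slot_acc = []
for row in block_tables:
    slots = []
    for b in row:
        slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
    slot_acc.extend(slots[:prompt_len])  # keep only prompt_len slots per sequence

print(slot_acc)  # [0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13]
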
@@ -1514,11 +1535,13 @@ class FlashCausalLM(Model):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=block_tables,
-            seqlen=seqlen,
+            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
             slots=slots,
-            lm_head_indices=None,
+            seqlen=trim_seqlen_metadata(seqlen),
             prefill_cache_indices=None,
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            hpu_attention_meta=None,
         )

     def forward(
@@ -1606,7 +1629,7 @@ class FlashCausalLM(Model):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
+            block_tables=None,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
             prefill_cache_indices=batch.prefill_cache_indices,
@@ -1637,9 +1660,7 @@ class FlashCausalLM(Model):
             batch.prepare_for_prefill()
         else:
             batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)
-
         prefill_logprobs = batch.prefill_next_token_indices is not None
-
         # Update adapter indices for speculative tokens (if present)
         adapter_meta = batch.adapter_meta
         if batch.speculative_ids is not None:
@@ -462,7 +462,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
+            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
@@ -288,7 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
+            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
@@ -33,13 +33,11 @@ try:
     from text_generation_server.models.flash_vlm_causal_lm import (
         FlashVlmCausalLMBatch,
     )
-    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

     VLM_BATCH_TYPES = {
         PaliGemmaBatch,
         VlmCausalLMBatch,
         FlashVlmCausalLMBatch,
-        IdeficsCausalLMBatch,
         FlashMllamaCausalLMBatch,
     }
 except (ImportError, NotImplementedError):