Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)
multi-modality warmup
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent 9d85ac9485, commit 705cc0b619
@@ -1487,7 +1487,6 @@ class FlashCausalLM(Model):
         if max_input_tokens is None:
             max_input_tokens = max_total_tokens - 1
 
-        del _batch, batch
         self.kv_cache = []
         empty_cache()
 
@@ -1499,6 +1498,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+
         self.bucketing_ctx = HPUBucketingContext(
             os.getenv("DECODE_MAX_BS", 128),  # self.max_num_seqs, #TODO
             os.getenv("PREFILL_MAX_BS", 16),  # self.max_num_prefill_seqs, #TODO
@@ -1506,6 +1506,17 @@ class FlashCausalLM(Model):
             num_blocks * BLOCK_SIZE,
         )
         self.bucketing_ctx.num_hpu_blocks = num_blocks
+        if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
+            logger.info("skip warmup hpu graph, not recommended")
+            del _batch, batch
+            return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+
+        self.warmup_hpu_graph(batch)
+        del _batch, batch
+
+        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+
+    def warmup_hpu_graph(self, batch):
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
         for i, (batch_size, seq_len) in enumerate(
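Note: the hunk above makes HPU graph warmup skippable via a SKIP_WARMUP_GRAPH environment variable and moves the warmup loop into a dedicated warmup_hpu_graph method. A minimal standalone sketch of the same opt-out pattern (the maybe_warmup helper and the logging setup are illustrative, not part of the diff):

import logging
import os

logger = logging.getLogger(__name__)


def maybe_warmup(run_warmup):
    # Opt-out switch: warmup stays the default; skipping trades a faster
    # startup for slower first requests while HPU graphs compile lazily.
    if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
        logger.info("skip warmup hpu graph, not recommended")
        return
    run_warmup()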
@@ -1513,14 +1524,13 @@ class FlashCausalLM(Model):
         ):
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size)
-        self.bucketing_ctx.generate_decode_buckets(num_blocks)
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num)
         synchronize(self.device)
-        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
     def warmup_prefill(self, prompt_len: int, bs: int):
         logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
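Note: after this hunk the warmup driver simply replays every prompt bucket (batch_size, seq_len) and every decode bucket (batch_size, block_num) warmup_times times, largest shapes first, so each HPU graph shape is compiled before serving; the remaining hunks apply the same idea to the FlashVlmCausalLM and FlashMllamaCausalLM classes. A self-contained sketch of that iteration pattern, with hard-coded stand-in buckets instead of the real HPUBucketingContext:

# Stand-in buckets; the real ones come from HPUBucketingContext.
prompt_buckets = [(1, 128), (4, 256), (16, 512)]    # (batch_size, seq_len)
decode_buckets = [(1, 64), (8, 512), (128, 2048)]   # (batch_size, block_num)
warmup_times = 3


def warmup_all(prefill_fn, decode_fn):
    # Largest shapes first, mirroring the reversed() iteration in the diff.
    for batch_size, seq_len in reversed(prompt_buckets):
        for _ in range(warmup_times):
            prefill_fn(seq_len, batch_size)
    for batch_size, block_num in reversed(decode_buckets):
        for _ in range(warmup_times):
            decode_fn(batch_size, block_num)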
@@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    prepare_for_decode,
 )
-from text_generation_server.models.globals import PREFIX_CACHING
+from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
 import habana_frameworks.torch as htorch
+from text_generation_server.utils.import_utils import (
+    synchronize,
+)
+import torch.nn.functional as F
 
 tracer = trace.get_tracer(__name__)
 
@@ -375,6 +380,80 @@ class FlashVlmCausalLM(FlashCausalLM):
     def max_past(self) -> Optional[int]:
         return getattr(self.model.text_model, "max_past", None)
 
+    def warmup_decode(
+        self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
+    ):
+        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
+        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        if batch.position_ids is not None and batch.position_ids.dim() == 2:
+            # qwen2_vl and qwen2_5_vl case
+            position_ids = position_ids.unsqueeze(-1).repeat(
+                (1, batch.position_ids.shape[-1])
+            )
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
+        slots = []
+        start_idx = 0
+
+        # use the last block of each sequence to warm up block_num blocks
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
+        )
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            lm_head_indices=None,
+            adapter_data=None,
+            hpu_attention_meta=hpu_attention_meta,
+        )
+
+    def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        warmup_times = 3
+        # only warm up decode; for prefill, the image pixel size may change, which would make the warmup useless
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+        for i, (batch_size, block_num) in enumerate(
+            reversed(self.bucketing_ctx.decode_buckets)
+        ):
+            for index in range(warmup_times):
+                self.warmup_decode(batch_size, block_num, batch)
+        synchronize(self.device)
+
     def forward(
         self,
         batch: FlashVlmCausalLMBatch,
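Note: the warmup_decode added above fabricates a decode batch that occupies exactly block_num KV-cache blocks: blocks are split as evenly as possible across the batch (the remainder goes to the first sequence), each sequence's write slot is the last position of its last block, and the cached length is one token short of the full allocation. A pure-Python sketch of that bookkeeping, with an illustrative BLOCK_SIZE:

BLOCK_SIZE = 128  # illustrative; the real value comes from models.globals


def split_blocks(block_num, batch_size):
    # Spread block_num blocks over the batch; sequence 0 absorbs the remainder.
    blocks = [block_num // batch_size] * batch_size
    blocks[0] += block_num % batch_size

    block_tables, slots, past_len = [], [], []
    start_idx = 0
    for n in blocks:
        block_array = list(range(start_idx, start_idx + n))
        block_tables.append(block_array)
        # the next token is written into the last slot of the last block
        slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
        # everything before that slot counts as already-cached context
        past_len.append(n * BLOCK_SIZE - 1)
        start_idx += n
    return block_tables, slots, past_len


# e.g. split_blocks(10, 4) -> block tables [[0, 1, 2, 3], [4, 5], [6, 7], [8, 9]]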
@@ -450,17 +529,75 @@ class FlashVlmCausalLM(FlashCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = batch.prefilling
 
-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
+        if self.bucketing_ctx is not None:
+            if batch.prefilling:
+                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
+                    input_lengths.shape[0]
+                )
+            else:
+                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    input_lengths.shape[0]
+                )
+        else:
+            padded_bs = input_lengths.shape[0]
+        if padded_bs != input_lengths.shape[0]:
+            orig_bs = input_lengths.shape[0]
+            padded_input_lengths = F.pad(
+                input_lengths,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            padded_cache_lengths_tensor = F.pad(
+                cache_lengths_tensor,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            if cu_seqlen_prefill is not None:
+                cu_seqlen_prefill = torch.zeros(
+                    padded_bs + 1, device=self.device, dtype=torch.int32
+                )
+                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
+            seqlen = Seqlen(
+                input_lengths=padded_input_lengths,
+                cache_lengths=padded_cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+            input_seq = input_ids.view(orig_bs, -1)
+            input_ids = F.pad(
+                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+            )
+            if position_ids.dim() == 2:
+                # qwen2_vl and qwen2_5_vl case
+                position_ids = F.pad(
+                    position_ids,
+                    (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
+                    value=1,
+                )
+            else:
+                position_ids = F.pad(
+                    position_ids,
+                    (0, (padded_bs - orig_bs) * input_seq.shape[-1]),
+                    value=1,
+                )
+            slots = F.pad(
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+            )
+            if lm_head_indices is not None:
+                lm_head_indices = F.pad(
+                    lm_head_indices, (0, padded_bs - orig_bs), value=0
+                )
+        else:
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
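Note: the forward changes above pad every per-sequence tensor up to the bucketed batch size so the shape matches a pre-compiled HPU graph; lengths are padded with 0, position ids with 1, and slots with -1 so the padded rows never touch a real cache slot. A minimal sketch of that padding step (the tensor values and padded_bs are illustrative, the real padded size comes from the bucketing context):

import torch
import torch.nn.functional as F

input_lengths = torch.tensor([5, 7, 3], dtype=torch.int32)
slots = torch.tensor([640, 896, 384], dtype=torch.int64)

padded_bs = 4  # illustrative; normally bucketing_ctx.get_padded_decode_batch_size(3)
pad = padded_bs - input_lengths.shape[0]

padded_input_lengths = F.pad(input_lengths, (0, pad), value=0)  # padded rows add no tokens
padded_slots = F.pad(slots, (0, pad), value=-1)                 # padded rows hit no real cache slot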
@@ -476,8 +613,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             image_grid_thw=batch.image_grid_thw,
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
         if batch.pixel_attention_mask is not None:
@@ -11,7 +11,9 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
+from text_generation_server.models.flash_causal_lm import (
+    prepare_for_decode,
+)
 from text_generation_server.models.flash_vlm_causal_lm import (
     FlashVlmCausalLMBatch,
     FlashVlmCausalLM,
@@ -19,6 +21,12 @@ from text_generation_server.models.flash_vlm_causal_lm import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
 import habana_frameworks.torch as htorch
+from loguru import logger
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.import_utils import (
+    synchronize,
+)
+import torch.nn.functional as F
 
 tracer = trace.get_tracer(__name__)
 
@@ -197,6 +205,131 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
 
 
 class FlashMllamaCausalLM(FlashVlmCausalLM):
+    def warmup_decode(
+        self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
+    ):
+        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
+        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
+        slots = []
+        start_idx = 0
+
+        # use the last block of each sequence to warm up block_num blocks
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
+        )
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            lm_head_indices=None,
+            adapter_data=None,
+            hpu_attention_meta=hpu_attention_meta,
+            cross_attention_states=batch.cross_attention_states,
+            image_indices=batch.image_indices[:],
+        )
+
+    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
+        logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
+        input_ids = torch.zeros(
+            prompt_len, dtype=torch.int64, device=self.device
+        ).repeat(bs)
+        position_ids = torch.arange(
+            prompt_len, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).reshape(bs, -1)
+        slot_acc = []
+        for i in range(bs):
+            slots = []
+            for b in block_tables[i]:
+                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+            slot_acc.extend(slots[:prompt_len])
+        slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+
+        input_lengths = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+        )
+        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+        lm_head_indices = input_lengths - 1
+
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=self.kv_cache,
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            lm_head_indices=lm_head_indices,
+            cross_attention_states=batch.cross_attention_states,
+            adapter_data=None,
+            hpu_attention_meta=None,
+            image_indices=batch.image_indices[:],
+        )
+
+    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        warmup_times = 3
+        self.bucketing_ctx.generate_prompt_buckets()
+        for i, (batch_size, seq_len) in enumerate(
+            reversed(self.bucketing_ctx.prompt_buckets)
+        ):
+            for index in range(warmup_times):
+                self.warmup_prefill(seq_len, batch_size, batch)
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+        for i, (batch_size, block_num) in enumerate(
+            reversed(self.bucketing_ctx.decode_buckets)
+        ):
+            for index in range(warmup_times):
+                self.warmup_decode(batch_size, block_num, batch)
+        synchronize(self.device)
+
     def forward(
         self,
         batch: FlashMllamaCausalLMBatch,
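Note: the Mllama warmup_prefill above gives each warmup sequence a contiguous stretch of blocks and flattens its block table into per-token slots, keeping only the first prompt_len slots per sequence. A standalone sketch of that slot construction, again with an illustrative BLOCK_SIZE:

BLOCK_SIZE = 128  # illustrative


def prefill_slots(prompt_len, bs):
    # One spare block per sequence, matching max_bt = (prompt_len // BLOCK_SIZE + 1) * bs.
    blocks_per_seq = prompt_len // BLOCK_SIZE + 1
    slot_acc = []
    for seq in range(bs):
        first_block = seq * blocks_per_seq
        slots = []
        for b in range(first_block, first_block + blocks_per_seq):
            slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
        slot_acc.extend(slots[:prompt_len])  # only prompt_len positions are real tokens
    return slot_acc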
@@ -263,12 +396,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
-
         if batch.pixel_values is not None:
             cross_attention_states = self.model.vision_forward(
                 pixel_values=batch.pixel_values,
@@ -286,6 +413,60 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
+
+        if self.bucketing_ctx is not None:
+            if batch.prefilling:
+                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
+                    input_lengths.shape[0]
+                )
+            else:
+                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    input_lengths.shape[0]
+                )
+        else:
+            padded_bs = input_lengths.shape[0]
+        if padded_bs != input_lengths.shape[0]:
+            orig_bs = input_lengths.shape[0]
+            padded_input_lengths = F.pad(
+                input_lengths,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            padded_cache_lengths_tensor = F.pad(
+                cache_lengths_tensor,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            if cu_seqlen_prefill is not None:
+                cu_seqlen_prefill = torch.zeros(
+                    padded_bs + 1, device=self.device, dtype=torch.int32
+                )
+                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
+            seqlen = Seqlen(
+                input_lengths=padded_input_lengths,
+                cache_lengths=padded_cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+            input_seq = input_ids.view(orig_bs, -1)
+            input_ids = F.pad(
+                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+            )
+            position_ids = F.pad(
+                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
+            )
+            slots = F.pad(
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+            )
+            if lm_head_indices is not None:
+                lm_head_indices = F.pad(
+                    lm_head_indices, (0, padded_bs - orig_bs), value=0
+                )
+        else:
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -301,8 +482,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             image_indices=batch.image_indices[:],
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
         return logits, speculative_logits