Refine the hpu warmup

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-08-02 04:36:59 +00:00
parent 05c13c89de
commit d34ffc4fe9
3 changed files with 269 additions and 68 deletions

View File

@@ -26,7 +26,6 @@ from transformers.models.llava_next.modeling_llava_next import (
 )
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution
-from loguru import logger
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """

View File

@ -10,6 +10,7 @@ import numpy
from opentelemetry import trace from opentelemetry import trace
from loguru import logger from loguru import logger
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
import itertools
import tempfile import tempfile
import copy import copy
from text_generation_server.models import Model from text_generation_server.models import Model
@@ -20,9 +21,10 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
     CausalLMBatch,
     CausalLMRequest,
-    round_up,
-    remove_kv_cache_from_output
+    remove_kv_cache_from_output,
+    biggest_single_chunk,
 )
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape,
 )
@@ -43,6 +45,7 @@ from text_generation_server.utils import (
 import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
+from optimum.habana.utils import get_hpu_memory_stats
 
 from transformers import (
     AutoTokenizer,
@@ -70,18 +73,20 @@ tracer = trace.get_tracer(__name__)
 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048))
 MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
-BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 1))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 1))
+MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
-PREFILL_GRAPH_NUM = int(os.environ.get('PREFILL_GRAPH_NUM', 16))
-os.environ['MAX_TOTAL_TOKENS'] = str(MAX_TOTAL_TOKENS)
-os.environ['BATCH_BUCKET_SIZE'] = str(BATCH_BUCKET_SIZE)
-os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF)
-os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE)
-os.environ['LAZY_MODE'] = str(LAZY_MODE)
+PREFILL_WARMUP_BATCH_SIZE_LIST = []
+PREFILL_WARMUP_SEQLEN_LIST = []
+DECODE_WARMUP_BATCH_SIZE_LIST = []
+def round_up(warmup_list:list, num) :
+    i = 0
+    for i in warmup_list:
+        if num <= i :
+            break
+    return i
 
 def split(string) -> List[Dict[str, str]]:
     parts = []
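
For readers skimming the diff, a minimal standalone sketch (illustrative only, with made-up bucket values) of how the new list-based round_up behaves: it scans the warmup bucket list in order and returns the first bucket that can hold num, falling back to the last entry when the value is larger than every bucket.

    # Illustrative bucket list; the real lists are built during warmup.
    PREFILL_WARMUP_SEQLEN_LIST = [2304, 2560, 3072, 4096]

    def round_up(warmup_list: list, num):
        # Same logic as the helper added in this diff: return the first bucket
        # that is >= num, otherwise the last bucket in the list.
        i = 0
        for i in warmup_list:
            if num <= i:
                break
        return i

    assert round_up(PREFILL_WARMUP_SEQLEN_LIST, 2400) == 2560
    assert round_up(PREFILL_WARMUP_SEQLEN_LIST, 5000) == 4096  # larger than every bucket
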
@@ -186,6 +191,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         batch_tokenized_inputs,
         dtype: torch.dtype,
         device: torch.device,
+        is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
@@ -200,7 +206,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
         parameters = [r.parameters for r in pb.requests]
         # append the dummy parameters for dummy request
         parameters = pad_next_token_chooser_parameters(parameters, new_bs)
@@ -217,9 +223,9 @@ class VlmCausalLMBatch(CausalLMBatch):
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
-        if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        if is_warmup is False:
+            if input_len < max_input_length :
+                rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
@@ -269,7 +275,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         )
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config, is_warmup):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
@@ -301,17 +307,19 @@ class VlmCausalLMBatch(CausalLMBatch):
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
 
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        if is_warmup is False:
+            new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
         missing_inputs = new_bs - len(requests)
         dummy_images = []
         dummy_inputs = []
         if len(batch_inputs) > 0 and len(image_inputs) > 0:
             dummy_inputs = [batch_inputs[0]] * missing_inputs
             dummy_images = [image_inputs[0]] * missing_inputs
         image_inputs += dummy_images
+        batch_inputs += dummy_inputs
 
         batch_tokenized_inputs = tokenizer(
-            batch_inputs + dummy_inputs,
+            batch_inputs,
             truncation=True,
             max_length=max_truncation,
             return_tensors="pt",
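
A hedged sketch of the dummy-request padding that batch_tokenized_inputs performs outside of warmup (pad_to_bucket is a hypothetical helper and the bucket values are made up, not code from this commit): the request count is rounded up to the nearest prefill warmup bucket and the first input is replicated to fill the empty slots, so the tokenizer always sees a bucket-sized batch.

    # Hypothetical helper; bucket values are illustrative.
    PREFILL_WARMUP_BATCH_SIZE_LIST = [1, 2, 4, 8]

    def pad_to_bucket(batch_inputs):
        # Round the request count up to the nearest warmup bucket, then
        # replicate the first input to fill the missing slots.
        new_bs = next((b for b in PREFILL_WARMUP_BATCH_SIZE_LIST if len(batch_inputs) <= b),
                      PREFILL_WARMUP_BATCH_SIZE_LIST[-1])
        missing_inputs = new_bs - len(batch_inputs)
        return batch_inputs + [batch_inputs[0]] * missing_inputs

    # Three real prompts land in the 4-slot bucket, so one dummy copy is appended.
    assert len(pad_to_bucket(["a", "b", "c"])) == 4
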
@@ -347,9 +355,10 @@
         config,
         dtype: torch.dtype,
         device: torch.device,
+        is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
         batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
+            pb.requests, tokenizer, processor, config, is_warmup
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
@@ -370,6 +379,114 @@
             batch.image_sizes = None
         return batch
 
+    @classmethod
+    @tracer.start_as_current_span("concatenate")
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id, is_warmup)
+
+    @classmethod
+    def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch":
+        if not all(b.past_key_values is not None for b in batches):
+            raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
+
+        total_requests = sum(len(b) for b in batches)
+        new_bs = total_requests
+        if is_warmup is False :
+            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+
+        batch_id = batches[0].batch_id
+        device = batches[0].input_ids.device
+
+        input_lengths = [b.input_length for b in batches]
+        max_input_length = max(input_lengths)
+        offsets = [max_input_length - b.input_length for b in batches]
+
+        cur_padding = [b.right_padding for b in batches]
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+
+        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
+        reshape = (batches[dst_batch_idx].batch_size < new_bs)
+
+        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
+        # FIXME: max_seq_len for non optimized code
+        if len(batches) > 1:
+            scenario = 'CONCAT'
+        elif reshape:
+            scenario = 'RESHAPE'
+        elif cur_padding[dst_batch_idx] <= 0:
+            scenario = 'SHIFT'
+            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            max_input_length = max_input_length + offsets[dst_batch_idx]
+        else:
+            # Nothing to do
+            return batches[0]
+
+        dbg_trace(
+            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
+                      f' reqs:{[len(b) for b in batches]}'
+                      f' offsets:{offsets}'
+                      f' input_lengths:{input_lengths}'
+                      f' cur_padding:{cur_padding}'
+                      f' dst_batch:{dst_batch_idx}')
+
+        grouped_requests = [[req for req in batch.requests] for batch in batches]
+        flat_requests = list(itertools.chain(*grouped_requests))
+
+        for i in range(len(batches)):
+            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
+            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
+            batches[i].realign(target_bs, offsets[i], pad_token_id)
+            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
+        batches[dst_batch_idx].expand_bs(new_bs)
+        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+
+        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+
+        parameters = [r.data.parameters for r in flat_requests]
+        # append the dummy parameters for dummy requests
+        batch_size = batches[dst_batch_idx].batch_size
+        parameters = pad_next_token_chooser_parameters(parameters, batch_size)
+
+        # update past grammar states
+        fsm_grammar_states = [0] * batch_size
+        for batch in batches:
+            for i, req in enumerate(batch.requests):
+                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            parameters,
+            batches[dst_batch_idx].next_token_chooser.dtype,
+            batches[dst_batch_idx].next_token_chooser.device,
+            batches[dst_batch_idx].next_token_chooser.tokenizer,
+            fsm_grammar_states,
+            quantization_enabled=hq_env.is_quantization_enabled,
+        )
+
+        input_ids = batches[dst_batch_idx].input_ids
+        attention_mask = batches[dst_batch_idx].attention_mask
+        position_ids = batches[dst_batch_idx].position_ids
+        past_key_values = batches[dst_batch_idx].past_key_values
+        input_length = max_input_length
+
+        htorch.core.mark_step()
+
+        return cls(
+            batch_id=batch_id,
+            requests=flat_requests,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            merged_kv_cache=False,
+            next_token_chooser=next_token_chooser,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            input_length=input_length,
+        )
+
 
 class VlmCausalLM(Model):
     def __init__(
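
A hedged standalone sketch of how the new recombine() picks its destination batch (pick_dst_batch is a hypothetical helper and the shapes below are made up): a batch whose batch_size already equals the target bucket only needs the other batches' requests moved into it, so it wins the moves_needed comparison.

    # Hypothetical helper mirroring the dst_batch_idx selection in recombine().
    def pick_dst_batch(batch_sizes, request_counts, new_bs):
        total_requests = sum(request_counts)
        # A batch already shaped to new_bs only needs the other batches' requests
        # moved into it; any other batch would need every request moved.
        moves_needed = [
            total_requests - n if bs == new_bs else total_requests
            for bs, n in zip(batch_sizes, request_counts)
        ]
        return min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]

    # Two batches holding 3 and 1 requests, decode bucket of 4: the first batch
    # is already padded to batch_size 4, so it becomes the destination (index 0).
    assert pick_dst_batch([4, 2], [3, 1], 4) == 0
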
@@ -672,7 +789,7 @@ class VlmCausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[VlmCausalLMBatch]
+        self, batches: List[VlmCausalLMBatch], is_warmup: bool = False
     ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # Results
@@ -755,7 +872,7 @@
 
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
         else:
             batch = batches[0]
@@ -763,12 +880,12 @@
 
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
+        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         #assert batch.right_padding > 0, 'No more room for next token!'
@@ -959,17 +1076,18 @@
         decode_ns = time.time_ns() - start_decode
         return generations, batch if not stopped else None, (forward_ns, decode_ns)
 
-    def batch_from_pb(self, batch):
+    def batch_from_pb(self, batch, is_warmup):
         return VlmCausalLMBatch.from_pb_processor(
             batch,
             self.tokenizer,
             self.processor,
             self.model.config,
             self.dtype,
-            self.device
+            self.device,
+            is_warmup
         )
 
-    def generate_warmup_batch(self, request, seq_len, batch_size):
+    def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
         batch = copy.deepcopy(request.batches[0])
         for req in batch.requests:
             req.truncate = seq_len
@@ -977,39 +1095,122 @@
 
         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()
-        return self.batch_from_pb(batch)
+        return self.batch_from_pb(batch, is_warmup)
 
     def warmup(self, request) -> None:
-        batches = [self.batch_from_pb(batch) for batch in request.batches]
+        is_warmup = True
+        batches = [self.batch_from_pb(batch, is_warmup) for batch in request.batches]
 
         try:
-            # prefill
-            _, prefill_batch, _ = self.generate_token([batches[0]])
-        except torch.cuda.OutOfMemoryError as e:
+            # max prefill batch size warmup
+            _, prefill_batch, _ = self.generate_token([batches[0]], is_warmup)
+        except:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batches[0].input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
-            ) from e
+            )
 
-        global BASE_IMAGE_TOKENS, PAD_SEQUENCE_TO_MULTIPLE_OF, PREFILL_BATCH_BUCKET_SIZE, PREFILL_GRAPH_NUM
+        self.model.clear_inputs()
+        global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
         max_input_length = batches[0].input_ids.shape[1]
-        max_batch_size = batches[0].input_ids.shape[0]
-        seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF
-        batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE
-        while batch_num > PREFILL_GRAPH_NUM :
-            PREFILL_BATCH_BUCKET_SIZE = PREFILL_BATCH_BUCKET_SIZE * 2
-            os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE)
-            batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE
-
-        while seq_num * batch_num >= PREFILL_GRAPH_NUM :
-            PAD_SEQUENCE_TO_MULTIPLE_OF = PAD_SEQUENCE_TO_MULTIPLE_OF * 2
-            os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF)
-            seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF
-
-        seq_lens_list = numpy.arange(BASE_IMAGE_TOKENS + PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length + 1, PAD_SEQUENCE_TO_MULTIPLE_OF).tolist()
-        batch_sizes_list = numpy.arange(PREFILL_BATCH_BUCKET_SIZE, max_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE).tolist()
-        for seq_len in seq_lens_list :
-            for batch_size in batch_sizes_list :
-                batch = self.generate_warmup_batch(request, seq_len, batch_size)
-                _, prefill_batch, _ = self.generate_token([batch])
-                _, decode_batch, _ = self.generate_token([prefill_batch])
+        max_prefill_batch_size = batches[0].input_ids.shape[0]
+        PREFILL_WARMUP_BATCH_SIZE_LIST = []
+        batch_size = 1
+        while batch_size <= max_prefill_batch_size:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+            batch_size = batch_size * 2
+        if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
+
+        seq_len = BASE_IMAGE_TOKENS
+        PREFILL_WARMUP_SEQLEN_LIST = []
+        i = 0
+        while seq_len <= max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
+            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
+            i += 1
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
+
+        #Prefill and decode warmup
+        DECODE_WARMUP_BATCH_SIZE_LIST = []
+        prefill_batch = None
+        decode_batch = None
+        try:
+            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
+                for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+        except:
+            raise RuntimeError(
+                f"Not enough memory to handle following prefill and decode warmup."
+                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
+                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
+            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats} "
+        )
+
+        self.model.clear_inputs()
+        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        batch_size = max_prefill_batch_size * 2
+        # Decode warmup with bigger batch_size
+        try:
+            if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
+                batches = []
+                for i in range(int(batch_size/max_prefill_batch_size)) :
+                    batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    batches.append(prefill_batch)
+                while batch_size <= max_decode_batch_size:
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+                    batch_size = batch_size * 2
+                    batches.clear()
+
+                    for i in range(int(batch_size/max_prefill_batch_size)) :
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+
+                batches.clear()
+                if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
+                    max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
+                    batch_size = max_decode_batch_size
+                    for i in range(int(max_decode_batch_size / 2)) :
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
+
+                max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
+                MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
+        except :
+            raise RuntimeError(
+                f"Not enough memory to handle batch_size({batch_size}) decode warmup."
+                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
+                f"max_decode_batch_size is {max_decode_batch_size}"
+                f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
+            )
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nFollowing decode warmup successfully.\n"
+            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats}"
+        )
+
+        self.model.clear_inputs()
+        return MAX_BATCH_TOTAL_TOKENS
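
A hedged standalone sketch of how this warmup builds its bucket lists (illustrative constants; build_warmup_lists is a made-up helper name, not part of the commit): prefill batch sizes double from 1 up to the maximum prefill batch size, and sequence lengths start at BASE_IMAGE_TOKENS and grow by PAD_SEQUENCE_TO_MULTIPLE_OF * 2**i, with the true maximum appended to the end of each list if it was not already reached.

    # Illustrative constants; the real values come from the environment.
    BASE_IMAGE_TOKENS = 2048
    PAD_SEQUENCE_TO_MULTIPLE_OF = 256

    def build_warmup_lists(max_prefill_batch_size, max_input_length):
        # Prefill batch-size buckets: powers of two, capped by the real maximum.
        batch_sizes, batch_size = [], 1
        while batch_size <= max_prefill_batch_size:
            batch_sizes.append(batch_size)
            batch_size *= 2
        if batch_sizes[-1] < max_prefill_batch_size:
            batch_sizes.append(max_prefill_batch_size)

        # Sequence-length buckets: exponentially growing steps above the image tokens.
        seq_lens, seq_len, i = [], BASE_IMAGE_TOKENS, 0
        while seq_len <= max_input_length:
            seq_lens.append(seq_len)
            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2 ** i)
            i += 1
        if seq_lens[-1] < max_input_length:
            seq_lens.append(max_input_length)
        return batch_sizes, seq_lens

    # e.g. max prefill batch 6 and max input length 4096 gives
    # ([1, 2, 4, 6], [2048, 2304, 2816, 3840, 4096])
    print(build_warmup_lists(6, 4096))

The exponentially growing steps keep the number of warmed-up shapes small even for long inputs, which presumably limits how many HPU graphs have to be compiled during warmup.
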

View File

@@ -97,13 +97,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )
 
         if self.model.batch_type in VLM_BATCH_TYPES :
-            self.model.warmup(request)
+            max_supported_total_tokens = self.model.warmup(request)
+            return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
         else:
             batches = [batch_from_pb(batch) for batch in request.batches]
             self.model.warmup(batches)
 
         return generate_pb2.WarmupResponse()
 
     async def Prefill(self, request, context):
         start = time.time_ns()
         if (
@@ -171,7 +172,7 @@ def serve(
     uds_path: Path,
 ):
     # Remove default handler
-    logger.remove()
+    #logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",