Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00
Refine the hpu warmup
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent 05c13c89de
commit d34ffc4fe9
@@ -26,7 +26,6 @@ from transformers.models.llava_next.modeling_llava_next import (
 )
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution
-from loguru import logger
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -10,6 +10,7 @@ import numpy
 from opentelemetry import trace
 from loguru import logger
 from typing import Optional, Tuple, List, Type, Dict
+import itertools
 import tempfile
 import copy
 from text_generation_server.models import Model
@@ -20,9 +21,10 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
     CausalLMBatch,
     CausalLMRequest,
-    round_up,
-    remove_kv_cache_from_output
+    remove_kv_cache_from_output,
+    biggest_single_chunk,
 )
 
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape,
 )
@@ -43,6 +45,7 @@ from text_generation_server.utils import (
 import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
+from optimum.habana.utils import get_hpu_memory_stats
 
 from transformers import (
     AutoTokenizer,
@@ -70,18 +73,20 @@ tracer = trace.get_tracer(__name__)
 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048))
 MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
-BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 1))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 1))
+MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
-PREFILL_GRAPH_NUM = int(os.environ.get('PREFILL_GRAPH_NUM', 16))
 
-os.environ['MAX_TOTAL_TOKENS'] = str(MAX_TOTAL_TOKENS)
-os.environ['BATCH_BUCKET_SIZE'] = str(BATCH_BUCKET_SIZE)
-os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF)
-os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE)
-os.environ['LAZY_MODE'] = str(LAZY_MODE)
+PREFILL_WARMUP_BATCH_SIZE_LIST = []
+PREFILL_WARMUP_SEQLEN_LIST = []
+DECODE_WARMUP_BATCH_SIZE_LIST = []
+def round_up(warmup_list:list, num) :
+    i = 0
+    for i in warmup_list:
+        if num <= i :
+            break
+    return i
 
 def split(string) -> List[Dict[str, str]]:
     parts = []
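Note on the new helper: `round_up` no longer rounds to a multiple of an env-configured bucket size; it walks a pre-computed warmup list and returns the first bucket that is at least `num`, falling back to the last (largest) bucket. A minimal standalone sketch of that behavior, using illustrative list values that stand in for the ones the server records during warmup:

# Standalone sketch of the list-based bucketing added in this hunk.
# Returns the first bucket in warmup_list that is >= num; if num exceeds
# every bucket, the last (largest) bucket is returned.
def round_up(warmup_list: list, num):
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i

# Illustrative buckets only; the real lists are filled in during warmup.
PREFILL_WARMUP_BATCH_SIZE_LIST = [1, 2, 4, 8]
PREFILL_WARMUP_SEQLEN_LIST = [2048, 2304, 2816, 3840]

print(round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, 3))   # 4
print(round_up(PREFILL_WARMUP_SEQLEN_LIST, 2400))    # 2816
print(round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, 16))  # 8 (falls back to the last bucket)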
@@ -186,6 +191,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         batch_tokenized_inputs,
         dtype: torch.dtype,
         device: torch.device,
+        is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
 
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
@@ -200,7 +206,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
         parameters = [r.parameters for r in pb.requests]
         # append the dummy parameters for dummy request
         parameters = pad_next_token_chooser_parameters(parameters, new_bs)
@@ -217,9 +223,9 @@ class VlmCausalLMBatch(CausalLMBatch):
 
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
-        if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        if is_warmup is False:
+            if input_len < max_input_length :
+                rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
@@ -269,7 +275,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         )
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config, is_warmup):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
@@ -301,17 +307,19 @@ class VlmCausalLMBatch(CausalLMBatch):
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
 
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        if is_warmup is False:
+            new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
         missing_inputs = new_bs - len(requests)
         dummy_images = []
         dummy_inputs = []
         if len(batch_inputs) > 0 and len(image_inputs) > 0:
             dummy_inputs = [batch_inputs[0]] * missing_inputs
             dummy_images = [image_inputs[0]] * missing_inputs
 
         image_inputs += dummy_images
+        batch_inputs += dummy_inputs
 
         batch_tokenized_inputs = tokenizer(
-            batch_inputs + dummy_inputs,
+            batch_inputs,
             truncation=True,
             max_length=max_truncation,
             return_tensors="pt",
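The padding of a prefill batch up to its bucket now extends `batch_inputs` itself with copies of the first request instead of concatenating the dummies at tokenization time, so the text and image lists stay aligned. A small sketch of the effect, with hypothetical prompts and a hypothetical bucket size:

# Sketch of the dummy-request padding above (hypothetical values): real inputs
# are topped up with copies of the first request so the tokenizer and the
# image pipeline both see a full, warmed-up batch bucket.
batch_inputs = ["prompt A", "prompt B", "prompt C"]
image_inputs = ["image A", "image B", "image C"]  # placeholders for processed images
new_bs = 4  # bucket picked from PREFILL_WARMUP_BATCH_SIZE_LIST for 3 requests

missing_inputs = new_bs - len(batch_inputs)
if len(batch_inputs) > 0 and len(image_inputs) > 0:
    dummy_inputs = [batch_inputs[0]] * missing_inputs
    dummy_images = [image_inputs[0]] * missing_inputs
    image_inputs += dummy_images
    batch_inputs += dummy_inputs

print(batch_inputs)  # ['prompt A', 'prompt B', 'prompt C', 'prompt A']
print(image_inputs)  # ['image A', 'image B', 'image C', 'image A']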
@@ -347,9 +355,10 @@ class VlmCausalLMBatch(CausalLMBatch):
         config,
         dtype: torch.dtype,
         device: torch.device,
+        is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
         batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
+            pb.requests, tokenizer, processor, config, is_warmup
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
@@ -370,6 +379,114 @@ class VlmCausalLMBatch(CausalLMBatch):
             batch.image_sizes = None
         return batch
 
+    @classmethod
+    @tracer.start_as_current_span("concatenate")
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id, is_warmup)
+
+
+    @classmethod
+    def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch":
+        if not all(b.past_key_values is not None for b in batches):
+            raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
+
+        total_requests = sum(len(b) for b in batches)
+        new_bs = total_requests
+        if is_warmup is False :
+            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+        batch_id = batches[0].batch_id
+        device = batches[0].input_ids.device
+
+        input_lengths = [b.input_length for b in batches]
+        max_input_length = max(input_lengths)
+        offsets = [max_input_length - b.input_length for b in batches]
+
+        cur_padding = [b.right_padding for b in batches]
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+
+        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
+        reshape = (batches[dst_batch_idx].batch_size < new_bs)
+
+        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
+        # FIXME: max_seq_len for non optimized code
+        if len(batches) > 1:
+            scenario = 'CONCAT'
+        elif reshape:
+            scenario = 'RESHAPE'
+        elif cur_padding[dst_batch_idx] <= 0:
+            scenario = 'SHIFT'
+            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            max_input_length = max_input_length + offsets[dst_batch_idx]
+        else:
+            # Nothing to do
+            return batches[0]
+
+        dbg_trace(
+            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
+                      f' reqs:{[len(b) for b in batches]}'
+                      f' offsets:{offsets}'
+                      f' input_lengths:{input_lengths}'
+                      f' cur_padding:{cur_padding}'
+                      f' dst_batch:{dst_batch_idx}')
+
+        grouped_requests = [[req for req in batch.requests] for batch in batches]
+        flat_requests = list(itertools.chain(*grouped_requests))
+
+        for i in range(len(batches)):
+            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
+            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
+            batches[i].realign(target_bs, offsets[i], pad_token_id)
+            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
+        batches[dst_batch_idx].expand_bs(new_bs)
+        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+
+        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+
+        parameters = [r.data.parameters for r in flat_requests]
+        # append the dummy parameters for dummy requests
+        batch_size = batches[dst_batch_idx].batch_size
+        parameters = pad_next_token_chooser_parameters(parameters, batch_size)
+
+        # update past grammar states
+        fsm_grammar_states = [0] * batch_size
+        for batch in batches:
+            for i, req in enumerate(batch.requests):
+                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            parameters,
+            batches[dst_batch_idx].next_token_chooser.dtype,
+            batches[dst_batch_idx].next_token_chooser.device,
+            batches[dst_batch_idx].next_token_chooser.tokenizer,
+            fsm_grammar_states,
+            quantization_enabled=hq_env.is_quantization_enabled,
+        )
+
+        input_ids = batches[dst_batch_idx].input_ids
+        attention_mask = batches[dst_batch_idx].attention_mask
+        position_ids = batches[dst_batch_idx].position_ids
+        past_key_values = batches[dst_batch_idx].past_key_values
+        input_length = max_input_length
+
+        htorch.core.mark_step()
+
+        return cls(
+            batch_id=batch_id,
+            requests=flat_requests,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            merged_kv_cache=False,
+            next_token_chooser=next_token_chooser,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            input_length=input_length,
+        )
+
 
 class VlmCausalLM(Model):
     def __init__(
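For the new `recombine`, the destination batch is the one that requires the fewest request moves: a batch whose allocated size already equals the target bucket only needs the other batches' requests copied into it, while any other batch would force every request to move. A tiny standalone sketch of that selection, with hypothetical batch shapes:

# Sketch of the dst-batch selection in recombine() (hypothetical shapes).
total_requests = 5        # live requests across all batches
new_bs = 8                # target decode bucket
batch_sizes = [8, 4]      # allocated size of each batch
live_requests = [3, 2]    # live requests in each batch

moves_needed = [
    total_requests - live if size == new_bs else total_requests
    for size, live in zip(batch_sizes, live_requests)
]
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]

print(moves_needed)   # [2, 5] -> batch 0 only needs batch 1's two requests moved in
print(dst_batch_idx)  # 0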
@@ -672,7 +789,7 @@ class VlmCausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[VlmCausalLMBatch]
+        self, batches: List[VlmCausalLMBatch], is_warmup: bool = False
     ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # Results
@@ -755,7 +872,7 @@ class VlmCausalLM(Model):
 
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
         else:
             batch = batches[0]
 
@@ -763,12 +880,12 @@ class VlmCausalLM(Model):
 
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
+        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         #assert batch.right_padding > 0, 'No more room for next token!'
@@ -959,17 +1076,18 @@ class VlmCausalLM(Model):
         decode_ns = time.time_ns() - start_decode
         return generations, batch if not stopped else None, (forward_ns, decode_ns)
 
-    def batch_from_pb(self, batch):
+    def batch_from_pb(self, batch, is_warmup):
         return VlmCausalLMBatch.from_pb_processor(
             batch,
             self.tokenizer,
             self.processor,
             self.model.config,
             self.dtype,
-            self.device
+            self.device,
+            is_warmup
         )
 
-    def generate_warmup_batch(self, request, seq_len, batch_size):
+    def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
         batch = copy.deepcopy(request.batches[0])
         for req in batch.requests:
             req.truncate = seq_len
@@ -977,39 +1095,122 @@ class VlmCausalLM(Model):
         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()
 
-        return self.batch_from_pb(batch)
+        return self.batch_from_pb(batch, is_warmup)
 
     def warmup(self, request) -> None:
-        batches = [self.batch_from_pb(batch) for batch in request.batches]
+        is_warmup = True
+        batches = [self.batch_from_pb(batch, is_warmup) for batch in request.batches]
 
         try:
-            # prefill
-            _, prefill_batch, _ = self.generate_token([batches[0]])
-        except torch.cuda.OutOfMemoryError as e:
+            # max prefill batch size warmup
+            _, prefill_batch, _ = self.generate_token([batches[0]], is_warmup)
+        except:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batches[0].input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
-            ) from e
+            )
 
-        global BASE_IMAGE_TOKENS, PAD_SEQUENCE_TO_MULTIPLE_OF, PREFILL_BATCH_BUCKET_SIZE, PREFILL_GRAPH_NUM
+        self.model.clear_inputs()
+        global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
         max_input_length = batches[0].input_ids.shape[1]
-        max_batch_size = batches[0].input_ids.shape[0]
-        seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF
-        batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE
-        while batch_num > PREFILL_GRAPH_NUM :
-            PREFILL_BATCH_BUCKET_SIZE = PREFILL_BATCH_BUCKET_SIZE * 2
-            os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE)
-            batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE
+        max_prefill_batch_size = batches[0].input_ids.shape[0]
+        PREFILL_WARMUP_BATCH_SIZE_LIST = []
+        batch_size = 1
+        while batch_size <= max_prefill_batch_size:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+            batch_size = batch_size * 2
+        if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
 
-        while seq_num * batch_num >= PREFILL_GRAPH_NUM :
-            PAD_SEQUENCE_TO_MULTIPLE_OF = PAD_SEQUENCE_TO_MULTIPLE_OF * 2
-            os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF)
-            seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF
+        seq_len = BASE_IMAGE_TOKENS
+        PREFILL_WARMUP_SEQLEN_LIST = []
+        i = 0
+        while seq_len <= max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
+            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
+            i += 1
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
 
-        seq_lens_list = numpy.arange(BASE_IMAGE_TOKENS + PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length + 1, PAD_SEQUENCE_TO_MULTIPLE_OF).tolist()
-        batch_sizes_list = numpy.arange(PREFILL_BATCH_BUCKET_SIZE, max_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE).tolist()
-        for seq_len in seq_lens_list :
-            for batch_size in batch_sizes_list :
-                batch = self.generate_warmup_batch(request, seq_len, batch_size)
-                _, prefill_batch, _ = self.generate_token([batch])
-                _, decode_batch, _ = self.generate_token([prefill_batch])
+        #Prefill and decode warmup
+        DECODE_WARMUP_BATCH_SIZE_LIST = []
+        prefill_batch = None
+        decode_batch = None
+        try:
+            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
+                for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
+
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+
+        except:
+            raise RuntimeError(
+                f"Not enough memory to handle following prefill and decode warmup."
+                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
+                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
+            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats} "
+        )
+
+        self.model.clear_inputs()
+        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        batch_size = max_prefill_batch_size * 2
+        # Decode warmup with bigger batch_size
+        try:
+            if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
+                batches = []
+                for i in range(int(batch_size/max_prefill_batch_size)) :
+                    batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    batches.append(prefill_batch)
+                while batch_size <= max_decode_batch_size:
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+                    batch_size = batch_size * 2
+                    batches.clear()
+
+                    for i in range(int(batch_size/max_prefill_batch_size)) :
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+
+                batches.clear()
+                if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
+                    max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
+                    batch_size = max_decode_batch_size
+                    for i in range(int(max_decode_batch_size / 2)) :
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
+                    max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
+                    MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
+        except :
+            raise RuntimeError(
+                f"Not enough memory to handle batch_size({batch_size}) decode warmup."
+                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
+                f"max_decode_batch_size is {max_decode_batch_size}"
+                f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
+            )
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nFollowing decode warmup successfully.\n"
+            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats}"
+        )
+
+        self.model.clear_inputs()
+        return MAX_BATCH_TOTAL_TOKENS
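The shape of the warmup lists can be reproduced without an HPU: batch-size buckets are powers of two up to the largest prefill batch, sequence-length buckets start at the image-token budget and grow by a doubling step, and the decode limit follows from the token budget. A standalone sketch with stand-in bounds (the real values come from the warmup request and the environment):

# Sketch of the warmup-list construction above; the bounds below are
# hypothetical stand-ins for the values the warmup request provides.
import math

BASE_IMAGE_TOKENS = 2048
PAD_SEQUENCE_TO_MULTIPLE_OF = 256
MAX_TOTAL_TOKENS = 8192
MAX_BATCH_TOTAL_TOKENS = 131072
max_input_length = 4096          # from the warmup batch
max_prefill_batch_size = 4       # from the warmup batch

# Prefill batch-size buckets: powers of two up to the max prefill batch size.
prefill_batch_sizes = []
batch_size = 1
while batch_size <= max_prefill_batch_size:
    prefill_batch_sizes.append(batch_size)
    batch_size *= 2
if prefill_batch_sizes[-1] < max_prefill_batch_size:
    prefill_batch_sizes.append(max_prefill_batch_size)

# Prefill sequence-length buckets: start at the image-token budget and grow by
# PAD_SEQUENCE_TO_MULTIPLE_OF * 2**i, so the step doubles each iteration.
prefill_seqlens, seq_len, i = [], BASE_IMAGE_TOKENS, 0
while seq_len <= max_input_length:
    prefill_seqlens.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2 ** i)
    i += 1
if prefill_seqlens[-1] < max_input_length:
    prefill_seqlens.append(max_input_length)

# The largest decode batch the token budget allows.
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)

print(prefill_batch_sizes)    # [1, 2, 4]
print(prefill_seqlens)        # [2048, 2304, 2816, 3840, 4096]
print(max_decode_batch_size)  # 16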
@@ -97,13 +97,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )
 
         if self.model.batch_type in VLM_BATCH_TYPES :
-            self.model.warmup(request)
+            max_supported_total_tokens = self.model.warmup(request)
+            return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
         else:
             batches = [batch_from_pb(batch) for batch in request.batches]
             self.model.warmup(batches)
 
         return generate_pb2.WarmupResponse()
 
 
     async def Prefill(self, request, context):
         start = time.time_ns()
         if (
@@ -171,7 +172,7 @@ def serve(
     uds_path: Path,
 ):
     # Remove default handler
-    logger.remove()
+    #logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",