Complete padding of CausalLMBatch when batch bucketing is enabled (#261)
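
This change completes the padding of a bucketed CausalLMBatch: when batch bucketing rounds a batch up to new_bs, per-request metadata such as top_n_tokens is now extended with dummy -1 entries (in recombine() and from_pb()) so the tensors built from it match the bucketed batch size. Below is a minimal standalone sketch of that idea; it is not the repository code, and pad_top_n_tokens is a hypothetical helper named only for illustration.

import torch

def round_up(number: int, k: int) -> int:
    # Same rounding helper as in the diff: smallest multiple of k >= number.
    return (number + k - 1) // k * k

def pad_top_n_tokens(top_n_tokens: list, batch_bucket_size: int, device: str = "cpu") -> torch.Tensor:
    # Pad the per-request list with -1 dummies up to the bucketed batch size,
    # then build the int64 tensor carried by the batch.
    new_bs = round_up(len(top_n_tokens), batch_bucket_size)
    padded = top_n_tokens + [-1] * (new_bs - len(top_n_tokens))
    return torch.tensor(padded, device=device, dtype=torch.int64)

# Three real requests bucketed up to a batch of 8:
print(pad_top_n_tokens([0, 2, 1], batch_bucket_size=8))
# tensor([ 0,  2,  1, -1, -1, -1, -1, -1])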

Signed-off-by: kaixuanliu <kaixuan.liu@intel.com>
Author: kaixuanliu <kaixuan.liu@intel.com>
Date: 2025-01-30 17:19:13 +08:00 (committed by GitHub)
Commit: b52164d38a (parent: fe7594e369)


@@ -53,21 +53,26 @@ from text_generation_server.utils.debug import dbg_trace
from text_generation_server.utils.speculate import get_speculate
tracer = trace.get_tracer(__name__)
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
def torch_compile_for_eager(func):
if LAZY_MODE == 1:
return func
return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
return torch.compile(
func, backend="hpu_backend", options={"keep_input_mutations": True}
)
def round_up(number, k):
return (number + k - 1) // k * k
def to_tensor_indices(indices, device):
return torch.tensor(indices, dtype=torch.long, device=device)
@@ -96,9 +101,11 @@ def grouped_pad(tensor_groups, dims, values):
for tensors, dim, value in zip(tensor_groups, dims, values):
padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
if padding > 0:
assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}"
pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
result = [
torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors
]
else:
result = [t for t in tensors]
grouped_result.append(result)
@@ -117,7 +124,10 @@ def roll(tensor, chunk, dim, merge_graphs):
def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
tensor_groups = [
[roll(t, chunk, dim, merge_graphs) for t in tensors]
for tensors, dim in zip(tensor_groups, dims)
]
if merge_graphs:
htorch.core.mark_step()
return tensor_groups
@@ -167,7 +177,10 @@ def extend_batch(tensors, target_bs, dim):
def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
tensor_groups = [
extend_batch(tensors, target_bs, dim)
for tensors, dim in zip(tensor_groups, bs_dims)
]
return tensor_groups
@@ -220,15 +233,20 @@ class CausalLMRequest:
all_input_ids: torch.Tensor
@classmethod
def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase):
def from_pb(
cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase
):
return cls(
idx=idx,
data=data,
input_length=None,
prefix_offset=None,
read_offset=None,
stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer),
all_input_ids=None,)
stopping_criteria=StoppingCriteria.from_pb(
data.stopping_parameters, tokenizer
),
all_input_ids=None,
)
def update_idx(self, new_idx):
prev = self.idx
@@ -289,7 +307,11 @@ class CausalLMBatch(Batch):
# Very simple heuristic to determine whether we should merge tensors
# this needs tuning for other models/scenarios
small_bs = len(self.past_key_values) > self.batch_size
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
if (
not self.merged_kv_cache
and small_bs
and (pad_needed or shift_needed or expand_needed)
):
past_keys, past_values = self.detach_kv_cache()
past_keys = merge(past_keys)
past_values = merge(past_values)
@@ -309,7 +331,13 @@ class CausalLMBatch(Batch):
seq_dim = -1
key_dim = -2 if self.keys_head_dim_last else -1
value_dim = -2
tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
tensors = [
[self.input_ids],
[self.attention_mask],
[self.position_ids],
past_keys,
past_values,
]
# We don't need to align position_ids
seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
@@ -350,13 +378,17 @@ class CausalLMBatch(Batch):
dst_tensors, _, dst_dims = self.get_tensor_groups()
free_indices_gen = self.free_indices_generator()
for src_b in src_batches:
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
dst_indices = to_tensor_indices(
src_b.update_indices(free_indices_gen), self.input_ids.device
)
src_tensors, _, src_dims = src_b.get_tensor_groups()
grouped_move(dst_tensors, dst_indices, src_tensors)
self.set_tensor_groups(dst_tensors)
@classmethod
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
def recombine(
cls, batches: List["CausalLMBatch"], pad_token_id: int
) -> "CausalLMBatch":
if not all(b.past_key_values is not None for b in batches):
raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
@@ -375,31 +407,39 @@ class CausalLMBatch(Batch):
# For prefill there is a space allocated only for first token
# Need to add padding to the max total tokens before first decode
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
moves_needed = [
total_requests - len(b) if b.batch_size == new_bs else total_requests
for b in batches
]
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
reshape = (batches[dst_batch_idx].batch_size < new_bs)
reshape = batches[dst_batch_idx].batch_size < new_bs
# TODO: Add support for changing max seq len, i.e. due to output length bucketing
# FIXME: max_seq_len for non optimized code
if len(batches) > 1:
scenario = 'CONCAT'
scenario = "CONCAT"
elif reshape:
scenario = 'RESHAPE'
scenario = "RESHAPE"
elif cur_padding[dst_batch_idx] <= 0:
scenario = 'SHIFT'
offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
scenario = "SHIFT"
offsets = [
biggest_single_chunk(b.max_input_length - max_input_length)
for b in batches
]
max_input_length = max_input_length + offsets[dst_batch_idx]
else:
# Nothing to do
return batches[0]
dbg_trace(
scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
f' reqs:{[len(b) for b in batches]}'
f' offsets:{offsets}'
f' input_lengths:{input_lengths}'
f' cur_padding:{cur_padding}'
f' dst_batch:{dst_batch_idx}')
scenario,
f"bs:{[b.batch_size for b in batches]}->{new_bs}"
f" reqs:{[len(b) for b in batches]}"
f" offsets:{offsets}"
f" input_lengths:{input_lengths}"
f" cur_padding:{cur_padding}"
f" dst_batch:{dst_batch_idx}",
)
grouped_requests = [[req for req in batch.requests] for batch in batches]
flat_requests = list(itertools.chain(*grouped_requests))
@@ -410,10 +450,15 @@ class CausalLMBatch(Batch):
batches[i].realign(target_bs, offsets[i], pad_token_id)
batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
batches[dst_batch_idx].expand_bs(new_bs)
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
batches[dst_batch_idx].move_data(
[batches[i] for i in range(len(batches)) if i != dst_batch_idx]
)
top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
top_n_tokens.extend([-1] * (new_bs - total_requests))
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
parameters = [r.data.parameters for r in flat_requests]
# append the dummy parameters for dummy requests
@@ -424,7 +469,9 @@ class CausalLMBatch(Batch):
fsm_grammar_states = [0] * batch_size
for batch in batches:
for i, req in enumerate(batch.requests):
fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
fsm_grammar_states[req.idx] = (
batch.next_token_chooser.fsm_grammar_states[i]
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
parameters,
@@ -465,8 +512,11 @@ class CausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
requests = [
CausalLMRequest.from_pb(idx, req, tokenizer)
for idx, req in enumerate(pb.requests)
]
inputs = []
top_n_tokens = []
@@ -476,10 +526,10 @@ class CausalLMBatch(Batch):
inputs.append(concat_text_chunks(r.input_chunks.chunks))
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_input_length = max_truncation
if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
# TODO: by tokenizing all inputs at once we loose information on actual input lengths
@@ -501,7 +551,7 @@ class CausalLMBatch(Batch):
)
tokenized_inputs = tokenizer(
inputs+dummy_inputs,
inputs + dummy_inputs,
return_tensors="pt",
padding="longest",
return_token_type_ids=False,
@@ -514,7 +564,9 @@ class CausalLMBatch(Batch):
bucket_size = max_input_length
left_padding = max_input_length - input_len
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
"PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
)
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
@@ -547,7 +599,8 @@ class CausalLMBatch(Batch):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
old_bs = len(requests)
top_n_tokens.extend([-1] * (new_bs - old_bs))
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
@@ -568,14 +621,16 @@ class CausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}")
request_ids = set(request_ids)
self.requests = [req for req in self.requests if req.data.id in request_ids]
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
def concatenate(
cls, batches: List["CausalLMBatch"], pad_token_id: int = 0
) -> "CausalLMBatch":
return cls.recombine(batches, pad_token_id)
def __len__(self):
@@ -618,9 +673,7 @@ class CausalLM(Model):
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
batch_class=CausalLMBatch,
):
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
@@ -646,18 +699,14 @@ class CausalLM(Model):
htorch.core.hpu_set_env()
if world_size > 1:
model = self.get_deepspeed_model(
model_id, dtype, revision
)
model = self.get_deepspeed_model(model_id, dtype, revision)
model = hq_env.prepare_model_for_quantization(model)
else:
get_repo_root(model_id)
# Check support for rope scaling
model_kwargs = {}
config = AutoConfig.from_pretrained(
model_id
)
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "rope_scaling"):
model_kwargs["rope_scaling"] = self.get_rope_scaling()
@@ -666,26 +715,34 @@ class CausalLM(Model):
revision=revision,
torch_dtype=dtype,
trust_remote_code=trust_remote_code,
**model_kwargs
**model_kwargs,
)
model = hq_env.prepare_model_for_quantization(model)
model = model.eval().to(device)
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
self.enable_hpu_graph = (
os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
)
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
if model.config.model_type not in [
"gpt_bigcode"
]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else:
if LAZY_MODE == 0:
# It is said that "keep_input_mutations" is safe for inference to be done
dbg_trace(
"TORCH COMPILE", f'Torch compiling of model')
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
dbg_trace("TORCH COMPILE", f"Torch compiling of model")
model.model = torch.compile(
model.model,
backend="hpu_backend",
options={"keep_input_mutations": True},
)
model = hq_env.setup_quantization(model)
@@ -714,8 +771,14 @@ class CausalLM(Model):
"return_dict": True,
}
if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gpt_bigcode"]:
if model.config.model_type in [
"llama",
"mistral",
"starcoder2",
"qwen2",
"falcon",
"gpt_bigcode",
]:
if model.config.model_type not in ["falcon", "gpt_bigcode"]:
self.kwargs["attn_softmax_bf16"] = True
@@ -740,11 +803,15 @@ class CausalLM(Model):
)
# Create profiler
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
self.profiling_warmup_steps = (
int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
)
self.profiling_steps = (
int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
)
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
if self.profiling_steps > 0:
self.hb_profiler = HabanaProfile(
@@ -752,7 +819,7 @@ class CausalLM(Model):
warmup=self.profiling_warmup_steps,
active=self.profiling_steps,
output_dir=output_dir,
record_shapes=record_shapes
record_shapes=record_shapes,
)
self.hb_profiler.start()
else:
@@ -760,23 +827,20 @@ class CausalLM(Model):
self.step = 0
def get_deepspeed_model(
self,
model_id: str,
dtype: torch.dtype,
revision: Optional[str] = None
self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None
) -> torch.nn.Module:
import deepspeed
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
model_kwargs = {
"revision": revision
}
model_kwargs = {"revision": revision}
# Initialize process(es) for DeepSpeed
deepspeed.init_distributed(dist_backend="hccl")
logger.info(
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
world_size, rank, local_rank
)
)
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
load_to_meta = model_on_meta(config)
@@ -794,7 +858,9 @@ class CausalLM(Model):
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
# TODO: revisit placement on CPU when auto-injection is possible
with deepspeed.OnDevice(dtype=dtype, device="cpu"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=dtype, **model_kwargs
)
model = model.eval()
# Initialize the model
@@ -817,16 +883,16 @@ class CausalLM(Model):
return None
rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
return {
'type': rope_scaling, 'factor': float(rope_factor)
}
return {"type": rope_scaling, "factor": float(rope_factor)}
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def decode_token(
self,
@@ -835,7 +901,9 @@ class CausalLM(Model):
read_offset: int = 0,
) -> Tuple[str, int, int]:
if is_tokenizer_transparent(self.tokenizer):
new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
new_text = self.tokenizer.decode(
all_input_ids[read_offset:], skip_special_tokens=False
)
return new_text, read_offset, len(all_input_ids)
else:
return super().decode_token(all_input_ids, prefix_offset, read_offset)
@@ -858,7 +926,7 @@ class CausalLM(Model):
}
# Optimum Habana got "lazy_mode" key-val only supported for llama type of models
if self.model.config.model_type == "llama" :
if self.model.config.model_type == "llama":
kwargs["lazy_mode"] = LAZY_MODE == 1
if self.has_position_ids:
@@ -869,7 +937,9 @@ class CausalLM(Model):
kwargs.update(self.kwargs)
if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
if past_key_values is not None and self.model.config.model_type not in [
"gpt_bigcode"
]:
return self.model.forward(**kwargs)
else:
outputs = self.model.forward(**kwargs)
@@ -896,18 +966,26 @@ class CausalLM(Model):
token_idx_scalar = batch.attention_mask.shape[-1] - 1
token_idx = torch.tensor(token_idx_scalar).to(self.device)
else:
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
token_idx_scalar = (
batch.attention_mask.shape[-1] - batch.right_padding
)
token_idx = torch.tensor(token_idx_scalar).to(self.device)
# Select next token
input_length = batch.input_length
if logits.shape[-2] > 1:
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
next_token_ids, next_token_logprobs, logprobs, _, _ = (
batch.next_token_chooser(
batch.input_ids,
logits[:, input_length - 1 : input_length, :].squeeze(-2),
self.speculate,
)
)
else:
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
batch.input_ids, logits.squeeze(-2), self.speculate
next_token_ids, next_token_logprobs, logprobs, _, _ = (
batch.next_token_chooser(
batch.input_ids, logits.squeeze(-2), self.speculate
)
)
# Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
@@ -918,24 +996,29 @@ class CausalLM(Model):
accepted_ids,
)
prev_batches.append({
'next_token_ids': next_token_ids,
'next_token_logprobs': next_token_logprobs,
})
prev_batches.append(
{
"next_token_ids": next_token_ids,
"next_token_logprobs": next_token_logprobs,
}
)
for req_idx, req in enumerate(batch.requests):
requests_to_generate.append({
'req': req,
'prev_req_idx': req.idx,
'batch_id': batch_id,
'seed': batch.next_token_chooser.seeds[req_idx],
'do_sample': batch.next_token_chooser.do_sample[req_idx],
'top_n_tokens': batch.top_n_tokens[req_idx],
'top_token_ids': batch_top_token_ids[req_idx],
'top_token_logprobs': batch_top_token_logprobs[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
})
requests_to_generate.append(
{
"req": req,
"prev_req_idx": req.idx,
"batch_id": batch_id,
"seed": batch.next_token_chooser.seeds[req_idx],
"do_sample": batch.next_token_chooser.do_sample[req_idx],
"top_n_tokens": batch.top_n_tokens[req_idx],
"top_token_ids": batch_top_token_ids[req_idx],
"top_token_logprobs": batch_top_token_logprobs[req_idx],
"grammar_state": batch.next_token_chooser.fsm_grammar_states[
req.idx
],
}
)
htorch.core.mark_step()
@@ -950,7 +1033,9 @@ class CausalLM(Model):
# Update position_ids
if prefill:
batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
batch.position_ids = (
torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
)
else:
batch.position_ids += 1
# Update past key values
@@ -971,13 +1056,19 @@ class CausalLM(Model):
if not prefill:
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
scenario = 'PREFILL' if prefill else 'GENERATE'
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
scenario = "PREFILL" if prefill else "GENERATE"
if (
self.enable_hpu_graph
and self.limit_hpu_graph
and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
):
self.model.clear_cache()
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
dbg_trace(
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
assert batch.right_padding > 0, 'No more room for next token!'
scenario,
f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
)
assert batch.right_padding > 0, "No more room for next token!"
# Execute batch
if prefill:
@@ -989,14 +1080,18 @@ class CausalLM(Model):
batch.position_ids,
token_idx,
batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
bypass_hpu_graph=prefill and self.limit_hpu_graph
if self.enable_hpu_graph
else None,
)
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
# Don't schedule next forward if max_new_tokens for all requests equals 1
# - we've already generated the first and only needed token in the prefill phase
pass
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
token_idx = torch.tensor(
batch.attention_mask.shape[-1] - batch.right_padding
).to(self.device)
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
logits = self.forward(
input_ids,
@@ -1004,7 +1099,9 @@ class CausalLM(Model):
batch.position_ids,
token_idx,
batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
bypass_hpu_graph=prefill and self.limit_hpu_graph
if self.enable_hpu_graph
else None,
)
if self.model.config.model_type in ["gpt_bigcode"]:
batch.logits, batch.past = logits
@@ -1018,40 +1115,45 @@ class CausalLM(Model):
# Stage 3. Finish and return previous generations
stopped = len(requests_to_generate) > 0
for prev_batch in prev_batches:
prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
prev_batch["next_token_logprobs"] = prev_batch[
"next_token_logprobs"
].tolist()
prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
htorch.core.mark_step()
for req_data in requests_to_generate:
req = req_data['req']
i = req_data['prev_req_idx']
prev_batch_id = req_data['batch_id']
req = req_data["req"]
i = req_data["prev_req_idx"]
prev_batch_id = req_data["batch_id"]
assert len(prev_batches) > prev_batch_id
next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
request = req.data
input_length = req.input_length
prefix_offset = req.prefix_offset
read_offset = req.read_offset
do_sample = req_data['do_sample']
seed = req_data['seed']
do_sample = req_data["do_sample"]
seed = req_data["seed"]
stopping_criteria = req.stopping_criteria
all_input_ids = req.all_input_ids
next_token_id = next_token_ids_cpu[i]
next_token_logprob = next_token_logprobs[i]
top_n_tokens = req_data['top_n_tokens']
top_token_ids = req_data['top_token_ids']
top_token_logprobs = req_data['top_token_logprobs']
grammar_state = req_data['grammar_state']
top_n_tokens = req_data["top_n_tokens"]
top_token_ids = req_data["top_token_ids"]
top_token_logprobs = req_data["top_token_logprobs"]
grammar_state = req_data["grammar_state"]
# Append next token to all tokens
all_input_ids[input_length] = next_token_id
new_input_length = input_length + 1
# Generated token
if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
next_token_text = ''
if (
is_tokenizer_transparent(self.tokenizer)
and len(stopping_criteria.stop_sequence_criterias) == 0
):
next_token_text = ""
else:
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
@@ -1075,7 +1177,11 @@ class CausalLM(Model):
output_text = None
else:
output_text = self.decode(
all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
all_input_ids[
new_input_length
- stopping_criteria.current_tokens : new_input_length,
0,
]
)
generated_text = GeneratedText(
output_text,
@@ -1090,7 +1196,7 @@ class CausalLM(Model):
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + next_token_logprobs
prefill_token_ids = all_input_ids[0: new_input_length - 1]
prefill_token_ids = all_input_ids[0 : new_input_length - 1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
@@ -1159,7 +1265,12 @@ class CausalLM(Model):
htorch.core.mark_step()
self.step = self.step + 1
if self.hb_profiler is not None:
if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
if (
self.step
> self.profiling_wait_steps
+ self.profiling_warmup_steps
+ self.profiling_steps
):
self.hb_profiler.stop()
else:
self.hb_profiler.step()
@@ -1178,11 +1289,12 @@ class CausalLM(Model):
return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
def warmup(self, request) -> None:
MAX_TOTAL_TOKENS = request.max_total_tokens
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
batch = self.batch_type.from_pb(
request.batch, self.tokenizer, self.dtype, self.device
)
max_prefill_batch_size = batch.input_ids.shape[0]
try:
# max prefill batch size warmup
@@ -1199,14 +1311,21 @@ class CausalLM(Model):
max_input_length = request.max_input_length
prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
prefill_batch_size_list.append(max_prefill_batch_size)
prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
prefill_seqlen_list = [
seq
for seq in range(
PAD_SEQUENCE_TO_MULTIPLE_OF,
max_input_length,
PAD_SEQUENCE_TO_MULTIPLE_OF,
)
]
prefill_seqlen_list.append(max_input_length)
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)
try:
for batch_size in prefill_batch_size_list:
for seq_len in prefill_seqlen_list:
batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch])
except:
prefill_batch_size_list.sort()
@@ -1227,24 +1346,33 @@ class CausalLM(Model):
f"Memory stats: {mem_stats} "
)
#warmup decode batch size
# warmup decode batch size
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
decode_batch_size_list = [
i
for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)
try:
for batch_size in decode_batch_size_list:
batches= []
iters = math.floor(batch_size/max_prefill_batch_size)
batches = []
iters = math.floor(batch_size / max_prefill_batch_size)
for i in range(iters):
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
batch = self.generate_warmup_batch(
request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size
)
_, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch)
if batch_size % max_prefill_batch_size != 0:
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
batch = self.generate_warmup_batch(
request,
PAD_SEQUENCE_TO_MULTIPLE_OF - 1,
batch_size % max_prefill_batch_size,
)
_, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch)
@@ -1254,10 +1382,10 @@ class CausalLM(Model):
batches.clear()
except:
raise RuntimeError(
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
f"You need to decrease `--max-batch-total-tokens`"
)
raise RuntimeError(
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
f"You need to decrease `--max-batch-total-tokens`"
)
decode_batch_size_list.sort()
MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
@@ -1268,4 +1396,4 @@ class CausalLM(Model):
f"Memory stats: {mem_stats} "
)
return MAX_BATCH_TOTAL_TOKENS
return MAX_BATCH_TOTAL_TOKENS
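
For context on the sequence-length side of the bucketing handled in from_pb above: prompts are padded up to the next multiple of PAD_SEQUENCE_TO_MULTIPLE_OF, reserving one position for the first generated token. A minimal standalone sketch follows, assuming the 256-token default from the diff; the left-padding arithmetic in the final comment is an illustrative assumption, not a copy of the repository code.

PAD_SEQUENCE_TO_MULTIPLE_OF = 256  # default taken from the diff

def round_up(number: int, k: int) -> int:
    return (number + k - 1) // k * k

def bucket_seq_len(input_len: int, max_input_length: int) -> int:
    # Mirrors the from_pb logic above: round the prompt length up to the next
    # bucket while reserving one slot for the first generated token, and never
    # exceed max_input_length.
    bucket_size = max_input_length
    if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
        rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
        if rounded_seq_len <= max_input_length:
            bucket_size = rounded_seq_len - 1
    return bucket_size

# A 100-token prompt lands in the first 256-token bucket (255 prompt slots plus
# one slot for the first generated token), leaving about 155 tokens of left padding.
print(bucket_seq_len(100, max_input_length=1024))        # 255
print(bucket_seq_len(100, max_input_length=1024) - 100)  # 155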