Complete padding of CausalLMBatch when there exists batch bucketing (#261)

Signed-off-by: kaixuanliu <kaixuan.liu@intel.com>
This commit is contained in:
kaixuanliu 2025-01-30 17:19:13 +08:00 committed by GitHub
parent fe7594e369
commit b52164d38a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -53,21 +53,26 @@ from text_generation_server.utils.debug import dbg_trace
from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils.speculate import get_speculate
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048)) MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8)) BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2)) PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
def torch_compile_for_eager(func): def torch_compile_for_eager(func):
if LAZY_MODE == 1: if LAZY_MODE == 1:
return func return func
return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True}) return torch.compile(
func, backend="hpu_backend", options={"keep_input_mutations": True}
)
def round_up(number, k): def round_up(number, k):
return (number + k - 1) // k * k return (number + k - 1) // k * k
def to_tensor_indices(indices, device): def to_tensor_indices(indices, device):
return torch.tensor(indices, dtype=torch.long, device=device) return torch.tensor(indices, dtype=torch.long, device=device)
@ -96,9 +101,11 @@ def grouped_pad(tensor_groups, dims, values):
for tensors, dim, value in zip(tensor_groups, dims, values): for tensors, dim, value in zip(tensor_groups, dims, values):
padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0 padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
if padding > 0: if padding > 0:
assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}' assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}"
pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors] result = [
torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors
]
else: else:
result = [t for t in tensors] result = [t for t in tensors]
grouped_result.append(result) grouped_result.append(result)
@ -117,7 +124,10 @@ def roll(tensor, chunk, dim, merge_graphs):
def grouped_roll(tensor_groups, chunk, dims, merge_graphs): def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)] tensor_groups = [
[roll(t, chunk, dim, merge_graphs) for t in tensors]
for tensors, dim in zip(tensor_groups, dims)
]
if merge_graphs: if merge_graphs:
htorch.core.mark_step() htorch.core.mark_step()
return tensor_groups return tensor_groups
@ -167,7 +177,10 @@ def extend_batch(tensors, target_bs, dim):
def grouped_extend_batch(tensor_groups, target_bs, bs_dims): def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)] tensor_groups = [
extend_batch(tensors, target_bs, dim)
for tensors, dim in zip(tensor_groups, bs_dims)
]
return tensor_groups return tensor_groups
@ -220,15 +233,20 @@ class CausalLMRequest:
all_input_ids: torch.Tensor all_input_ids: torch.Tensor
@classmethod @classmethod
def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase): def from_pb(
cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase
):
return cls( return cls(
idx=idx, idx=idx,
data=data, data=data,
input_length=None, input_length=None,
prefix_offset=None, prefix_offset=None,
read_offset=None, read_offset=None,
stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer), stopping_criteria=StoppingCriteria.from_pb(
all_input_ids=None,) data.stopping_parameters, tokenizer
),
all_input_ids=None,
)
def update_idx(self, new_idx): def update_idx(self, new_idx):
prev = self.idx prev = self.idx
@ -289,7 +307,11 @@ class CausalLMBatch(Batch):
# Very simple heuristic to determine whether we should merge tensors # Very simple heuristic to determine whether we should merge tensors
# this needs tuning for other models/scenarios # this needs tuning for other models/scenarios
small_bs = len(self.past_key_values) > self.batch_size small_bs = len(self.past_key_values) > self.batch_size
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed): if (
not self.merged_kv_cache
and small_bs
and (pad_needed or shift_needed or expand_needed)
):
past_keys, past_values = self.detach_kv_cache() past_keys, past_values = self.detach_kv_cache()
past_keys = merge(past_keys) past_keys = merge(past_keys)
past_values = merge(past_values) past_values = merge(past_values)
@ -309,7 +331,13 @@ class CausalLMBatch(Batch):
seq_dim = -1 seq_dim = -1
key_dim = -2 if self.keys_head_dim_last else -1 key_dim = -2 if self.keys_head_dim_last else -1
value_dim = -2 value_dim = -2
tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values] tensors = [
[self.input_ids],
[self.attention_mask],
[self.position_ids],
past_keys,
past_values,
]
# We don't need to align position_ids # We don't need to align position_ids
seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
@ -350,13 +378,17 @@ class CausalLMBatch(Batch):
dst_tensors, _, dst_dims = self.get_tensor_groups() dst_tensors, _, dst_dims = self.get_tensor_groups()
free_indices_gen = self.free_indices_generator() free_indices_gen = self.free_indices_generator()
for src_b in src_batches: for src_b in src_batches:
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device) dst_indices = to_tensor_indices(
src_b.update_indices(free_indices_gen), self.input_ids.device
)
src_tensors, _, src_dims = src_b.get_tensor_groups() src_tensors, _, src_dims = src_b.get_tensor_groups()
grouped_move(dst_tensors, dst_indices, src_tensors) grouped_move(dst_tensors, dst_indices, src_tensors)
self.set_tensor_groups(dst_tensors) self.set_tensor_groups(dst_tensors)
@classmethod @classmethod
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch": def recombine(
cls, batches: List["CausalLMBatch"], pad_token_id: int
) -> "CausalLMBatch":
if not all(b.past_key_values is not None for b in batches): if not all(b.past_key_values is not None for b in batches):
raise ValueError("KV cache not allocated! Cannot recombine before prefill!") raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
@ -375,31 +407,39 @@ class CausalLMBatch(Batch):
# For prefill there is a space allocated only for first token # For prefill there is a space allocated only for first token
# Need to add padding to the max total tokens before first decode # Need to add padding to the max total tokens before first decode
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] moves_needed = [
total_requests - len(b) if b.batch_size == new_bs else total_requests
for b in batches
]
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
reshape = (batches[dst_batch_idx].batch_size < new_bs) reshape = batches[dst_batch_idx].batch_size < new_bs
# TODO: Add support for changing max seq len, i.e. due to output length bucketing # TODO: Add support for changing max seq len, i.e. due to output length bucketing
# FIXME: max_seq_len for non optimized code # FIXME: max_seq_len for non optimized code
if len(batches) > 1: if len(batches) > 1:
scenario = 'CONCAT' scenario = "CONCAT"
elif reshape: elif reshape:
scenario = 'RESHAPE' scenario = "RESHAPE"
elif cur_padding[dst_batch_idx] <= 0: elif cur_padding[dst_batch_idx] <= 0:
scenario = 'SHIFT' scenario = "SHIFT"
offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] offsets = [
biggest_single_chunk(b.max_input_length - max_input_length)
for b in batches
]
max_input_length = max_input_length + offsets[dst_batch_idx] max_input_length = max_input_length + offsets[dst_batch_idx]
else: else:
# Nothing to do # Nothing to do
return batches[0] return batches[0]
dbg_trace( dbg_trace(
scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' scenario,
f' reqs:{[len(b) for b in batches]}' f"bs:{[b.batch_size for b in batches]}->{new_bs}"
f' offsets:{offsets}' f" reqs:{[len(b) for b in batches]}"
f' input_lengths:{input_lengths}' f" offsets:{offsets}"
f' cur_padding:{cur_padding}' f" input_lengths:{input_lengths}"
f' dst_batch:{dst_batch_idx}') f" cur_padding:{cur_padding}"
f" dst_batch:{dst_batch_idx}",
)
grouped_requests = [[req for req in batch.requests] for batch in batches] grouped_requests = [[req for req in batch.requests] for batch in batches]
flat_requests = list(itertools.chain(*grouped_requests)) flat_requests = list(itertools.chain(*grouped_requests))
@ -410,10 +450,15 @@ class CausalLMBatch(Batch):
batches[i].realign(target_bs, offsets[i], pad_token_id) batches[i].realign(target_bs, offsets[i], pad_token_id)
batches[i].split_kv_cache_if_needed(i == dst_batch_idx) batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
batches[dst_batch_idx].expand_bs(new_bs) batches[dst_batch_idx].expand_bs(new_bs)
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) batches[dst_batch_idx].move_data(
[batches[i] for i in range(len(batches)) if i != dst_batch_idx]
)
top_n_tokens = [r.data.top_n_tokens for r in flat_requests] top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) top_n_tokens.extend([-1] * (new_bs - total_requests))
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
parameters = [r.data.parameters for r in flat_requests] parameters = [r.data.parameters for r in flat_requests]
# append the dummy parameters for dummy requests # append the dummy parameters for dummy requests
@ -424,7 +469,9 @@ class CausalLMBatch(Batch):
fsm_grammar_states = [0] * batch_size fsm_grammar_states = [0] * batch_size
for batch in batches: for batch in batches:
for i, req in enumerate(batch.requests): for i, req in enumerate(batch.requests):
fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] fsm_grammar_states[req.idx] = (
batch.next_token_chooser.fsm_grammar_states[i]
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
parameters, parameters,
@ -465,8 +512,11 @@ class CausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "CausalLMBatch": ) -> "CausalLMBatch":
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] requests = [
CausalLMRequest.from_pb(idx, req, tokenizer)
for idx, req in enumerate(pb.requests)
]
inputs = [] inputs = []
top_n_tokens = [] top_n_tokens = []
@ -476,10 +526,10 @@ class CausalLMBatch(Batch):
inputs.append(concat_text_chunks(r.input_chunks.chunks)) inputs.append(concat_text_chunks(r.input_chunks.chunks))
top_n_tokens.append(r.top_n_tokens) top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)
max_input_length = max_truncation max_input_length = max_truncation
if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF: if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
# TODO: by tokenizing all inputs at once we loose information on actual input lengths # TODO: by tokenizing all inputs at once we loose information on actual input lengths
@ -501,7 +551,7 @@ class CausalLMBatch(Batch):
) )
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs+dummy_inputs, inputs + dummy_inputs,
return_tensors="pt", return_tensors="pt",
padding="longest", padding="longest",
return_token_type_ids=False, return_token_type_ids=False,
@ -514,7 +564,9 @@ class CausalLMBatch(Batch):
bucket_size = max_input_length bucket_size = max_input_length
left_padding = max_input_length - input_len left_padding = max_input_length - input_len
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
"PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
)
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
if rounded_seq_len <= max_input_length: if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1 bucket_size = rounded_seq_len - 1
@ -547,7 +599,8 @@ class CausalLMBatch(Batch):
position_ids = attention_mask.long().cumsum(-1) - 1 position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1) position_ids.masked_fill_(attention_mask == 0, 1)
old_bs = len(requests)
top_n_tokens.extend([-1] * (new_bs - old_bs))
top_n_tokens_tensor = torch.tensor( top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64 top_n_tokens, device=device, dtype=torch.int64
) )
@ -568,14 +621,16 @@ class CausalLMBatch(Batch):
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}') dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}")
request_ids = set(request_ids) request_ids = set(request_ids)
self.requests = [req for req in self.requests if req.data.id in request_ids] self.requests = [req for req in self.requests if req.data.id in request_ids]
return self return self
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch": def concatenate(
cls, batches: List["CausalLMBatch"], pad_token_id: int = 0
) -> "CausalLMBatch":
return cls.recombine(batches, pad_token_id) return cls.recombine(batches, pad_token_id)
def __len__(self): def __len__(self):
@ -618,9 +673,7 @@ class CausalLM(Model):
tokenizer_class=AutoTokenizer, tokenizer_class=AutoTokenizer,
config_class=AutoConfig, config_class=AutoConfig,
batch_class=CausalLMBatch, batch_class=CausalLMBatch,
): ):
if speculator: if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel") raise RuntimeError("Speculator decoding is not enabled for AutoModel")
@ -646,18 +699,14 @@ class CausalLM(Model):
htorch.core.hpu_set_env() htorch.core.hpu_set_env()
if world_size > 1: if world_size > 1:
model = self.get_deepspeed_model( model = self.get_deepspeed_model(model_id, dtype, revision)
model_id, dtype, revision
)
model = hq_env.prepare_model_for_quantization(model) model = hq_env.prepare_model_for_quantization(model)
else: else:
get_repo_root(model_id) get_repo_root(model_id)
# Check support for rope scaling # Check support for rope scaling
model_kwargs = {} model_kwargs = {}
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(model_id)
model_id
)
if hasattr(config, "rope_scaling"): if hasattr(config, "rope_scaling"):
model_kwargs["rope_scaling"] = self.get_rope_scaling() model_kwargs["rope_scaling"] = self.get_rope_scaling()
@ -666,26 +715,34 @@ class CausalLM(Model):
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**model_kwargs **model_kwargs,
) )
model = hq_env.prepare_model_for_quantization(model) model = hq_env.prepare_model_for_quantization(model)
model = model.eval().to(device) model = model.eval().to(device)
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 self.enable_hpu_graph = (
os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
)
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output() if model.config.model_type not in [
"gpt_bigcode"
]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
model = remove_kv_cache_from_output(model) model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph: if self.enable_hpu_graph:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True) model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else: else:
if LAZY_MODE == 0: if LAZY_MODE == 0:
# It is said that "keep_input_mutations" is safe for inference to be done # It is said that "keep_input_mutations" is safe for inference to be done
dbg_trace( dbg_trace("TORCH COMPILE", f"Torch compiling of model")
"TORCH COMPILE", f'Torch compiling of model') model.model = torch.compile(
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) model.model,
backend="hpu_backend",
options={"keep_input_mutations": True},
)
model = hq_env.setup_quantization(model) model = hq_env.setup_quantization(model)
@ -714,8 +771,14 @@ class CausalLM(Model):
"return_dict": True, "return_dict": True,
} }
if model.config.model_type in [
if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gpt_bigcode"]: "llama",
"mistral",
"starcoder2",
"qwen2",
"falcon",
"gpt_bigcode",
]:
if model.config.model_type not in ["falcon", "gpt_bigcode"]: if model.config.model_type not in ["falcon", "gpt_bigcode"]:
self.kwargs["attn_softmax_bf16"] = True self.kwargs["attn_softmax_bf16"] = True
@ -740,11 +803,15 @@ class CausalLM(Model):
) )
# Create profiler # Create profiler
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 self.profiling_warmup_steps = (
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
)
self.profiling_steps = (
int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
)
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
if self.profiling_steps > 0: if self.profiling_steps > 0:
self.hb_profiler = HabanaProfile( self.hb_profiler = HabanaProfile(
@ -752,7 +819,7 @@ class CausalLM(Model):
warmup=self.profiling_warmup_steps, warmup=self.profiling_warmup_steps,
active=self.profiling_steps, active=self.profiling_steps,
output_dir=output_dir, output_dir=output_dir,
record_shapes=record_shapes record_shapes=record_shapes,
) )
self.hb_profiler.start() self.hb_profiler.start()
else: else:
@ -760,23 +827,20 @@ class CausalLM(Model):
self.step = 0 self.step = 0
def get_deepspeed_model( def get_deepspeed_model(
self, self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None
model_id: str,
dtype: torch.dtype,
revision: Optional[str] = None
) -> torch.nn.Module: ) -> torch.nn.Module:
import deepspeed import deepspeed
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu() world_size, rank, local_rank = initialize_distributed_hpu()
model_kwargs = { model_kwargs = {"revision": revision}
"revision": revision
}
# Initialize process(es) for DeepSpeed # Initialize process(es) for DeepSpeed
deepspeed.init_distributed(dist_backend="hccl") deepspeed.init_distributed(dist_backend="hccl")
logger.info( logger.info(
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
world_size, rank, local_rank
)
) )
config = AutoConfig.from_pretrained(model_id, **model_kwargs) config = AutoConfig.from_pretrained(model_id, **model_kwargs)
load_to_meta = model_on_meta(config) load_to_meta = model_on_meta(config)
@ -794,7 +858,9 @@ class CausalLM(Model):
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
# TODO: revisit placement on CPU when auto-injection is possible # TODO: revisit placement on CPU when auto-injection is possible
with deepspeed.OnDevice(dtype=dtype, device="cpu"): with deepspeed.OnDevice(dtype=dtype, device="cpu"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=dtype, **model_kwargs
)
model = model.eval() model = model.eval()
# Initialize the model # Initialize the model
@ -817,16 +883,16 @@ class CausalLM(Model):
return None return None
rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
return { return {"type": rope_scaling, "factor": float(rope_factor)}
'type': rope_scaling, 'factor': float(rope_factor)
}
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch return CausalLMBatch
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def decode_token( def decode_token(
self, self,
@ -835,7 +901,9 @@ class CausalLM(Model):
read_offset: int = 0, read_offset: int = 0,
) -> Tuple[str, int, int]: ) -> Tuple[str, int, int]:
if is_tokenizer_transparent(self.tokenizer): if is_tokenizer_transparent(self.tokenizer):
new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) new_text = self.tokenizer.decode(
all_input_ids[read_offset:], skip_special_tokens=False
)
return new_text, read_offset, len(all_input_ids) return new_text, read_offset, len(all_input_ids)
else: else:
return super().decode_token(all_input_ids, prefix_offset, read_offset) return super().decode_token(all_input_ids, prefix_offset, read_offset)
@ -858,7 +926,7 @@ class CausalLM(Model):
} }
# Optimum Habana got "lazy_mode" key-val only supported for llama type of models # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
if self.model.config.model_type == "llama" : if self.model.config.model_type == "llama":
kwargs["lazy_mode"] = LAZY_MODE == 1 kwargs["lazy_mode"] = LAZY_MODE == 1
if self.has_position_ids: if self.has_position_ids:
@ -869,7 +937,9 @@ class CausalLM(Model):
kwargs.update(self.kwargs) kwargs.update(self.kwargs)
if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]: if past_key_values is not None and self.model.config.model_type not in [
"gpt_bigcode"
]:
return self.model.forward(**kwargs) return self.model.forward(**kwargs)
else: else:
outputs = self.model.forward(**kwargs) outputs = self.model.forward(**kwargs)
@ -896,18 +966,26 @@ class CausalLM(Model):
token_idx_scalar = batch.attention_mask.shape[-1] - 1 token_idx_scalar = batch.attention_mask.shape[-1] - 1
token_idx = torch.tensor(token_idx_scalar).to(self.device) token_idx = torch.tensor(token_idx_scalar).to(self.device)
else: else:
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding token_idx_scalar = (
batch.attention_mask.shape[-1] - batch.right_padding
)
token_idx = torch.tensor(token_idx_scalar).to(self.device) token_idx = torch.tensor(token_idx_scalar).to(self.device)
# Select next token # Select next token
input_length = batch.input_length input_length = batch.input_length
if logits.shape[-2] > 1: if logits.shape[-2] > 1:
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( next_token_ids, next_token_logprobs, logprobs, _, _ = (
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate batch.next_token_chooser(
batch.input_ids,
logits[:, input_length - 1 : input_length, :].squeeze(-2),
self.speculate,
)
) )
else: else:
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( next_token_ids, next_token_logprobs, logprobs, _, _ = (
batch.input_ids, logits.squeeze(-2), self.speculate batch.next_token_chooser(
batch.input_ids, logits.squeeze(-2), self.speculate
)
) )
# Speculation is not active for causal # Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0] accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
@ -918,24 +996,29 @@ class CausalLM(Model):
accepted_ids, accepted_ids,
) )
prev_batches.append({ prev_batches.append(
'next_token_ids': next_token_ids, {
'next_token_logprobs': next_token_logprobs, "next_token_ids": next_token_ids,
}) "next_token_logprobs": next_token_logprobs,
}
)
for req_idx, req in enumerate(batch.requests): for req_idx, req in enumerate(batch.requests):
requests_to_generate.append({ requests_to_generate.append(
'req': req, {
'prev_req_idx': req.idx, "req": req,
'batch_id': batch_id, "prev_req_idx": req.idx,
'seed': batch.next_token_chooser.seeds[req_idx], "batch_id": batch_id,
'do_sample': batch.next_token_chooser.do_sample[req_idx], "seed": batch.next_token_chooser.seeds[req_idx],
'top_n_tokens': batch.top_n_tokens[req_idx], "do_sample": batch.next_token_chooser.do_sample[req_idx],
'top_token_ids': batch_top_token_ids[req_idx], "top_n_tokens": batch.top_n_tokens[req_idx],
'top_token_logprobs': batch_top_token_logprobs[req_idx], "top_token_ids": batch_top_token_ids[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], "top_token_logprobs": batch_top_token_logprobs[req_idx],
"grammar_state": batch.next_token_chooser.fsm_grammar_states[
}) req.idx
],
}
)
htorch.core.mark_step() htorch.core.mark_step()
@ -950,7 +1033,9 @@ class CausalLM(Model):
# Update position_ids # Update position_ids
if prefill: if prefill:
batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 batch.position_ids = (
torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
)
else: else:
batch.position_ids += 1 batch.position_ids += 1
# Update past key values # Update past key values
@ -971,13 +1056,19 @@ class CausalLM(Model):
if not prefill: if not prefill:
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
scenario = 'PREFILL' if prefill else 'GENERATE' scenario = "PREFILL" if prefill else "GENERATE"
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs: if (
self.enable_hpu_graph
and self.limit_hpu_graph
and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
):
self.model.clear_cache() self.model.clear_cache()
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE) self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
dbg_trace( dbg_trace(
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') scenario,
assert batch.right_padding > 0, 'No more room for next token!' f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
)
assert batch.right_padding > 0, "No more room for next token!"
# Execute batch # Execute batch
if prefill: if prefill:
@ -989,14 +1080,18 @@ class CausalLM(Model):
batch.position_ids, batch.position_ids,
token_idx, token_idx,
batch.past_key_values, batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, bypass_hpu_graph=prefill and self.limit_hpu_graph
if self.enable_hpu_graph
else None,
) )
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
# Don't schedule next forward if max_new_tokens for all requests equals 1 # Don't schedule next forward if max_new_tokens for all requests equals 1
# - we've already generated the first and only needed token in the prefill phase # - we've already generated the first and only needed token in the prefill phase
pass pass
else: else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) token_idx = torch.tensor(
batch.attention_mask.shape[-1] - batch.right_padding
).to(self.device)
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
logits = self.forward( logits = self.forward(
input_ids, input_ids,
@ -1004,7 +1099,9 @@ class CausalLM(Model):
batch.position_ids, batch.position_ids,
token_idx, token_idx,
batch.past_key_values, batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, bypass_hpu_graph=prefill and self.limit_hpu_graph
if self.enable_hpu_graph
else None,
) )
if self.model.config.model_type in ["gpt_bigcode"]: if self.model.config.model_type in ["gpt_bigcode"]:
batch.logits, batch.past = logits batch.logits, batch.past = logits
@ -1018,40 +1115,45 @@ class CausalLM(Model):
# Stage 3. Finish and return previous generations # Stage 3. Finish and return previous generations
stopped = len(requests_to_generate) > 0 stopped = len(requests_to_generate) > 0
for prev_batch in prev_batches: for prev_batch in prev_batches:
prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() prev_batch["next_token_logprobs"] = prev_batch[
prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() "next_token_logprobs"
].tolist()
prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
htorch.core.mark_step() htorch.core.mark_step()
for req_data in requests_to_generate: for req_data in requests_to_generate:
req = req_data['req'] req = req_data["req"]
i = req_data['prev_req_idx'] i = req_data["prev_req_idx"]
prev_batch_id = req_data['batch_id'] prev_batch_id = req_data["batch_id"]
assert len(prev_batches) > prev_batch_id assert len(prev_batches) > prev_batch_id
next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
request = req.data request = req.data
input_length = req.input_length input_length = req.input_length
prefix_offset = req.prefix_offset prefix_offset = req.prefix_offset
read_offset = req.read_offset read_offset = req.read_offset
do_sample = req_data['do_sample'] do_sample = req_data["do_sample"]
seed = req_data['seed'] seed = req_data["seed"]
stopping_criteria = req.stopping_criteria stopping_criteria = req.stopping_criteria
all_input_ids = req.all_input_ids all_input_ids = req.all_input_ids
next_token_id = next_token_ids_cpu[i] next_token_id = next_token_ids_cpu[i]
next_token_logprob = next_token_logprobs[i] next_token_logprob = next_token_logprobs[i]
top_n_tokens = req_data['top_n_tokens'] top_n_tokens = req_data["top_n_tokens"]
top_token_ids = req_data['top_token_ids'] top_token_ids = req_data["top_token_ids"]
top_token_logprobs = req_data['top_token_logprobs'] top_token_logprobs = req_data["top_token_logprobs"]
grammar_state = req_data['grammar_state'] grammar_state = req_data["grammar_state"]
# Append next token to all tokens # Append next token to all tokens
all_input_ids[input_length] = next_token_id all_input_ids[input_length] = next_token_id
new_input_length = input_length + 1 new_input_length = input_length + 1
# Generated token # Generated token
if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: if (
next_token_text = '' is_tokenizer_transparent(self.tokenizer)
and len(stopping_criteria.stop_sequence_criterias) == 0
):
next_token_text = ""
else: else:
next_token_text, prefix_offset, read_offset = self.decode_token( next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
@ -1075,7 +1177,11 @@ class CausalLM(Model):
output_text = None output_text = None
else: else:
output_text = self.decode( output_text = self.decode(
all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] all_input_ids[
new_input_length
- stopping_criteria.current_tokens : new_input_length,
0,
]
) )
generated_text = GeneratedText( generated_text = GeneratedText(
output_text, output_text,
@ -1090,7 +1196,7 @@ class CausalLM(Model):
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + next_token_logprobs prefill_logprobs = [float("nan")] + next_token_logprobs
prefill_token_ids = all_input_ids[0: new_input_length - 1] prefill_token_ids = all_input_ids[0 : new_input_length - 1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids, prefill_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
@ -1159,7 +1265,12 @@ class CausalLM(Model):
htorch.core.mark_step() htorch.core.mark_step()
self.step = self.step + 1 self.step = self.step + 1
if self.hb_profiler is not None: if self.hb_profiler is not None:
if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: if (
self.step
> self.profiling_wait_steps
+ self.profiling_warmup_steps
+ self.profiling_steps
):
self.hb_profiler.stop() self.hb_profiler.stop()
else: else:
self.hb_profiler.step() self.hb_profiler.step()
@ -1178,11 +1289,12 @@ class CausalLM(Model):
return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
def warmup(self, request) -> None: def warmup(self, request) -> None:
MAX_TOTAL_TOKENS = request.max_total_tokens MAX_TOTAL_TOKENS = request.max_total_tokens
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device) batch = self.batch_type.from_pb(
request.batch, self.tokenizer, self.dtype, self.device
)
max_prefill_batch_size = batch.input_ids.shape[0] max_prefill_batch_size = batch.input_ids.shape[0]
try: try:
# max prefill batch size warmup # max prefill batch size warmup
@ -1199,14 +1311,21 @@ class CausalLM(Model):
max_input_length = request.max_input_length max_input_length = request.max_input_length
prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)] prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
prefill_batch_size_list.append(max_prefill_batch_size) prefill_batch_size_list.append(max_prefill_batch_size)
prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)] prefill_seqlen_list = [
seq
for seq in range(
PAD_SEQUENCE_TO_MULTIPLE_OF,
max_input_length,
PAD_SEQUENCE_TO_MULTIPLE_OF,
)
]
prefill_seqlen_list.append(max_input_length) prefill_seqlen_list.append(max_input_length)
prefill_batch_size_list.sort(reverse=True) prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True) prefill_seqlen_list.sort(reverse=True)
try: try:
for batch_size in prefill_batch_size_list: for batch_size in prefill_batch_size_list:
for seq_len in prefill_seqlen_list: for seq_len in prefill_seqlen_list:
batch = self.generate_warmup_batch(request, seq_len-1, batch_size) batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch]) _, prefill_batch, _ = self.generate_token([batch])
except: except:
prefill_batch_size_list.sort() prefill_batch_size_list.sort()
@ -1227,24 +1346,33 @@ class CausalLM(Model):
f"Memory stats: {mem_stats} " f"Memory stats: {mem_stats} "
) )
#warmup decode batch size # warmup decode batch size
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE) max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)] decode_batch_size_list = [
i
for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
]
decode_batch_size_list.append(max_decode_batch_size) decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True) decode_batch_size_list.sort(reverse=True)
try: try:
for batch_size in decode_batch_size_list: for batch_size in decode_batch_size_list:
batches= [] batches = []
iters = math.floor(batch_size/max_prefill_batch_size) iters = math.floor(batch_size / max_prefill_batch_size)
for i in range(iters): for i in range(iters):
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size) batch = self.generate_warmup_batch(
request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size
)
_, prefill_batch, _ = self.generate_token([batch]) _, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch) batches.append(prefill_batch)
if batch_size % max_prefill_batch_size != 0: if batch_size % max_prefill_batch_size != 0:
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size) batch = self.generate_warmup_batch(
request,
PAD_SEQUENCE_TO_MULTIPLE_OF - 1,
batch_size % max_prefill_batch_size,
)
_, prefill_batch, _ = self.generate_token([batch]) _, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch) batches.append(prefill_batch)
@ -1254,10 +1382,10 @@ class CausalLM(Model):
batches.clear() batches.clear()
except: except:
raise RuntimeError( raise RuntimeError(
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})." f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
f"You need to decrease `--max-batch-total-tokens`" f"You need to decrease `--max-batch-total-tokens`"
) )
decode_batch_size_list.sort() decode_batch_size_list.sort()
MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
@ -1268,4 +1396,4 @@ class CausalLM(Model):
f"Memory stats: {mem_stats} " f"Memory stats: {mem_stats} "
) )
return MAX_BATCH_TOTAL_TOKENS return MAX_BATCH_TOTAL_TOKENS