mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 04:52:07 +00:00
Complete padding of CausalLMBatch
when there exists batch bucketing (#261)
Signed-off-by: kaixuanliu <kaixuan.liu@intel.com>
This commit is contained in:
parent
fe7594e369
commit
b52164d38a
@ -53,21 +53,26 @@ from text_generation_server.utils.debug import dbg_trace
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
|
||||
MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
|
||||
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
|
||||
BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
|
||||
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
|
||||
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
|
||||
BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
|
||||
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
|
||||
|
||||
|
||||
def torch_compile_for_eager(func):
|
||||
if LAZY_MODE == 1:
|
||||
return func
|
||||
return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
|
||||
return torch.compile(
|
||||
func, backend="hpu_backend", options={"keep_input_mutations": True}
|
||||
)
|
||||
|
||||
|
||||
def round_up(number, k):
|
||||
return (number + k - 1) // k * k
|
||||
|
||||
|
||||
def to_tensor_indices(indices, device):
|
||||
return torch.tensor(indices, dtype=torch.long, device=device)
|
||||
|
||||
@ -96,9 +101,11 @@ def grouped_pad(tensor_groups, dims, values):
|
||||
for tensors, dim, value in zip(tensor_groups, dims, values):
|
||||
padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
|
||||
if padding > 0:
|
||||
assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
|
||||
assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}"
|
||||
pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
|
||||
result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
|
||||
result = [
|
||||
torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors
|
||||
]
|
||||
else:
|
||||
result = [t for t in tensors]
|
||||
grouped_result.append(result)
|
||||
@ -117,7 +124,10 @@ def roll(tensor, chunk, dim, merge_graphs):
|
||||
|
||||
|
||||
def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
|
||||
tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
|
||||
tensor_groups = [
|
||||
[roll(t, chunk, dim, merge_graphs) for t in tensors]
|
||||
for tensors, dim in zip(tensor_groups, dims)
|
||||
]
|
||||
if merge_graphs:
|
||||
htorch.core.mark_step()
|
||||
return tensor_groups
|
||||
@ -167,7 +177,10 @@ def extend_batch(tensors, target_bs, dim):
|
||||
|
||||
|
||||
def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
|
||||
tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
|
||||
tensor_groups = [
|
||||
extend_batch(tensors, target_bs, dim)
|
||||
for tensors, dim in zip(tensor_groups, bs_dims)
|
||||
]
|
||||
return tensor_groups
|
||||
|
||||
|
||||
@ -220,15 +233,20 @@ class CausalLMRequest:
|
||||
all_input_ids: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase):
|
||||
def from_pb(
|
||||
cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase
|
||||
):
|
||||
return cls(
|
||||
idx=idx,
|
||||
data=data,
|
||||
input_length=None,
|
||||
prefix_offset=None,
|
||||
read_offset=None,
|
||||
stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer),
|
||||
all_input_ids=None,)
|
||||
stopping_criteria=StoppingCriteria.from_pb(
|
||||
data.stopping_parameters, tokenizer
|
||||
),
|
||||
all_input_ids=None,
|
||||
)
|
||||
|
||||
def update_idx(self, new_idx):
|
||||
prev = self.idx
|
||||
@ -289,7 +307,11 @@ class CausalLMBatch(Batch):
|
||||
# Very simple heuristic to determine whether we should merge tensors
|
||||
# this needs tuning for other models/scenarios
|
||||
small_bs = len(self.past_key_values) > self.batch_size
|
||||
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
|
||||
if (
|
||||
not self.merged_kv_cache
|
||||
and small_bs
|
||||
and (pad_needed or shift_needed or expand_needed)
|
||||
):
|
||||
past_keys, past_values = self.detach_kv_cache()
|
||||
past_keys = merge(past_keys)
|
||||
past_values = merge(past_values)
|
||||
@ -309,7 +331,13 @@ class CausalLMBatch(Batch):
|
||||
seq_dim = -1
|
||||
key_dim = -2 if self.keys_head_dim_last else -1
|
||||
value_dim = -2
|
||||
tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
|
||||
tensors = [
|
||||
[self.input_ids],
|
||||
[self.attention_mask],
|
||||
[self.position_ids],
|
||||
past_keys,
|
||||
past_values,
|
||||
]
|
||||
# We don't need to align position_ids
|
||||
seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
|
||||
bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
|
||||
@ -350,13 +378,17 @@ class CausalLMBatch(Batch):
|
||||
dst_tensors, _, dst_dims = self.get_tensor_groups()
|
||||
free_indices_gen = self.free_indices_generator()
|
||||
for src_b in src_batches:
|
||||
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
|
||||
dst_indices = to_tensor_indices(
|
||||
src_b.update_indices(free_indices_gen), self.input_ids.device
|
||||
)
|
||||
src_tensors, _, src_dims = src_b.get_tensor_groups()
|
||||
grouped_move(dst_tensors, dst_indices, src_tensors)
|
||||
self.set_tensor_groups(dst_tensors)
|
||||
|
||||
@classmethod
|
||||
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
|
||||
def recombine(
|
||||
cls, batches: List["CausalLMBatch"], pad_token_id: int
|
||||
) -> "CausalLMBatch":
|
||||
if not all(b.past_key_values is not None for b in batches):
|
||||
raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
|
||||
|
||||
@ -375,31 +407,39 @@ class CausalLMBatch(Batch):
|
||||
# For prefill there is a space allocated only for first token
|
||||
# Need to add padding to the max total tokens before first decode
|
||||
|
||||
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
|
||||
moves_needed = [
|
||||
total_requests - len(b) if b.batch_size == new_bs else total_requests
|
||||
for b in batches
|
||||
]
|
||||
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
|
||||
reshape = (batches[dst_batch_idx].batch_size < new_bs)
|
||||
reshape = batches[dst_batch_idx].batch_size < new_bs
|
||||
|
||||
# TODO: Add support for changing max seq len, i.e. due to output length bucketing
|
||||
# FIXME: max_seq_len for non optimized code
|
||||
if len(batches) > 1:
|
||||
scenario = 'CONCAT'
|
||||
scenario = "CONCAT"
|
||||
elif reshape:
|
||||
scenario = 'RESHAPE'
|
||||
scenario = "RESHAPE"
|
||||
elif cur_padding[dst_batch_idx] <= 0:
|
||||
scenario = 'SHIFT'
|
||||
offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
|
||||
scenario = "SHIFT"
|
||||
offsets = [
|
||||
biggest_single_chunk(b.max_input_length - max_input_length)
|
||||
for b in batches
|
||||
]
|
||||
max_input_length = max_input_length + offsets[dst_batch_idx]
|
||||
else:
|
||||
# Nothing to do
|
||||
return batches[0]
|
||||
|
||||
dbg_trace(
|
||||
scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
|
||||
f' reqs:{[len(b) for b in batches]}'
|
||||
f' offsets:{offsets}'
|
||||
f' input_lengths:{input_lengths}'
|
||||
f' cur_padding:{cur_padding}'
|
||||
f' dst_batch:{dst_batch_idx}')
|
||||
scenario,
|
||||
f"bs:{[b.batch_size for b in batches]}->{new_bs}"
|
||||
f" reqs:{[len(b) for b in batches]}"
|
||||
f" offsets:{offsets}"
|
||||
f" input_lengths:{input_lengths}"
|
||||
f" cur_padding:{cur_padding}"
|
||||
f" dst_batch:{dst_batch_idx}",
|
||||
)
|
||||
|
||||
grouped_requests = [[req for req in batch.requests] for batch in batches]
|
||||
flat_requests = list(itertools.chain(*grouped_requests))
|
||||
@ -410,10 +450,15 @@ class CausalLMBatch(Batch):
|
||||
batches[i].realign(target_bs, offsets[i], pad_token_id)
|
||||
batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
|
||||
batches[dst_batch_idx].expand_bs(new_bs)
|
||||
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
|
||||
batches[dst_batch_idx].move_data(
|
||||
[batches[i] for i in range(len(batches)) if i != dst_batch_idx]
|
||||
)
|
||||
|
||||
top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
|
||||
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
|
||||
top_n_tokens.extend([-1] * (new_bs - total_requests))
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
parameters = [r.data.parameters for r in flat_requests]
|
||||
# append the dummy parameters for dummy requests
|
||||
@ -424,7 +469,9 @@ class CausalLMBatch(Batch):
|
||||
fsm_grammar_states = [0] * batch_size
|
||||
for batch in batches:
|
||||
for i, req in enumerate(batch.requests):
|
||||
fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
|
||||
fsm_grammar_states[req.idx] = (
|
||||
batch.next_token_chooser.fsm_grammar_states[i]
|
||||
)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
parameters,
|
||||
@ -465,8 +512,11 @@ class CausalLMBatch(Batch):
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
|
||||
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
|
||||
dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
|
||||
requests = [
|
||||
CausalLMRequest.from_pb(idx, req, tokenizer)
|
||||
for idx, req in enumerate(pb.requests)
|
||||
]
|
||||
inputs = []
|
||||
top_n_tokens = []
|
||||
|
||||
@ -476,10 +526,10 @@ class CausalLMBatch(Batch):
|
||||
inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
|
||||
max_input_length = max_truncation
|
||||
if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
|
||||
max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
|
||||
max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
|
||||
max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
|
||||
|
||||
# TODO: by tokenizing all inputs at once we loose information on actual input lengths
|
||||
@ -501,7 +551,7 @@ class CausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
tokenized_inputs = tokenizer(
|
||||
inputs+dummy_inputs,
|
||||
inputs + dummy_inputs,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
return_token_type_ids=False,
|
||||
@ -514,7 +564,9 @@ class CausalLMBatch(Batch):
|
||||
bucket_size = max_input_length
|
||||
left_padding = max_input_length - input_len
|
||||
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
|
||||
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
|
||||
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
|
||||
"PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
|
||||
)
|
||||
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
|
||||
if rounded_seq_len <= max_input_length:
|
||||
bucket_size = rounded_seq_len - 1
|
||||
@ -547,7 +599,8 @@ class CausalLMBatch(Batch):
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
|
||||
|
||||
old_bs = len(requests)
|
||||
top_n_tokens.extend([-1] * (new_bs - old_bs))
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
@ -568,14 +621,16 @@ class CausalLMBatch(Batch):
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
|
||||
dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
|
||||
dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}")
|
||||
request_ids = set(request_ids)
|
||||
self.requests = [req for req in self.requests if req.data.id in request_ids]
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
|
||||
def concatenate(
|
||||
cls, batches: List["CausalLMBatch"], pad_token_id: int = 0
|
||||
) -> "CausalLMBatch":
|
||||
return cls.recombine(batches, pad_token_id)
|
||||
|
||||
def __len__(self):
|
||||
@ -618,9 +673,7 @@ class CausalLM(Model):
|
||||
tokenizer_class=AutoTokenizer,
|
||||
config_class=AutoConfig,
|
||||
batch_class=CausalLMBatch,
|
||||
|
||||
):
|
||||
|
||||
if speculator:
|
||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||
|
||||
@ -646,18 +699,14 @@ class CausalLM(Model):
|
||||
htorch.core.hpu_set_env()
|
||||
|
||||
if world_size > 1:
|
||||
model = self.get_deepspeed_model(
|
||||
model_id, dtype, revision
|
||||
)
|
||||
model = self.get_deepspeed_model(model_id, dtype, revision)
|
||||
model = hq_env.prepare_model_for_quantization(model)
|
||||
else:
|
||||
get_repo_root(model_id)
|
||||
|
||||
# Check support for rope scaling
|
||||
model_kwargs = {}
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id
|
||||
)
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
if hasattr(config, "rope_scaling"):
|
||||
model_kwargs["rope_scaling"] = self.get_rope_scaling()
|
||||
|
||||
@ -666,26 +715,34 @@ class CausalLM(Model):
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**model_kwargs
|
||||
**model_kwargs,
|
||||
)
|
||||
model = hq_env.prepare_model_for_quantization(model)
|
||||
model = model.eval().to(device)
|
||||
|
||||
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
||||
self.enable_hpu_graph = (
|
||||
os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
||||
)
|
||||
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
|
||||
|
||||
if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
|
||||
if model.config.model_type not in [
|
||||
"gpt_bigcode"
|
||||
]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
|
||||
model = remove_kv_cache_from_output(model)
|
||||
|
||||
if self.enable_hpu_graph:
|
||||
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||
|
||||
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
||||
else:
|
||||
if LAZY_MODE == 0:
|
||||
# It is said that "keep_input_mutations" is safe for inference to be done
|
||||
dbg_trace(
|
||||
"TORCH COMPILE", f'Torch compiling of model')
|
||||
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
|
||||
dbg_trace("TORCH COMPILE", f"Torch compiling of model")
|
||||
model.model = torch.compile(
|
||||
model.model,
|
||||
backend="hpu_backend",
|
||||
options={"keep_input_mutations": True},
|
||||
)
|
||||
|
||||
model = hq_env.setup_quantization(model)
|
||||
|
||||
@ -714,8 +771,14 @@ class CausalLM(Model):
|
||||
"return_dict": True,
|
||||
}
|
||||
|
||||
|
||||
if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gpt_bigcode"]:
|
||||
if model.config.model_type in [
|
||||
"llama",
|
||||
"mistral",
|
||||
"starcoder2",
|
||||
"qwen2",
|
||||
"falcon",
|
||||
"gpt_bigcode",
|
||||
]:
|
||||
if model.config.model_type not in ["falcon", "gpt_bigcode"]:
|
||||
self.kwargs["attn_softmax_bf16"] = True
|
||||
|
||||
@ -740,11 +803,15 @@ class CausalLM(Model):
|
||||
)
|
||||
|
||||
# Create profiler
|
||||
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
|
||||
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
|
||||
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
|
||||
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
|
||||
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
|
||||
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
|
||||
self.profiling_warmup_steps = (
|
||||
int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
|
||||
)
|
||||
self.profiling_steps = (
|
||||
int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
|
||||
)
|
||||
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
|
||||
if self.profiling_steps > 0:
|
||||
self.hb_profiler = HabanaProfile(
|
||||
@ -752,7 +819,7 @@ class CausalLM(Model):
|
||||
warmup=self.profiling_warmup_steps,
|
||||
active=self.profiling_steps,
|
||||
output_dir=output_dir,
|
||||
record_shapes=record_shapes
|
||||
record_shapes=record_shapes,
|
||||
)
|
||||
self.hb_profiler.start()
|
||||
else:
|
||||
@ -760,23 +827,20 @@ class CausalLM(Model):
|
||||
self.step = 0
|
||||
|
||||
def get_deepspeed_model(
|
||||
self,
|
||||
model_id: str,
|
||||
dtype: torch.dtype,
|
||||
revision: Optional[str] = None
|
||||
self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None
|
||||
) -> torch.nn.Module:
|
||||
import deepspeed
|
||||
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
|
||||
|
||||
world_size, rank, local_rank = initialize_distributed_hpu()
|
||||
model_kwargs = {
|
||||
"revision": revision
|
||||
}
|
||||
model_kwargs = {"revision": revision}
|
||||
|
||||
# Initialize process(es) for DeepSpeed
|
||||
deepspeed.init_distributed(dist_backend="hccl")
|
||||
logger.info(
|
||||
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
|
||||
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
|
||||
world_size, rank, local_rank
|
||||
)
|
||||
)
|
||||
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
|
||||
load_to_meta = model_on_meta(config)
|
||||
@ -794,7 +858,9 @@ class CausalLM(Model):
|
||||
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
|
||||
# TODO: revisit placement on CPU when auto-injection is possible
|
||||
with deepspeed.OnDevice(dtype=dtype, device="cpu"):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, torch_dtype=dtype, **model_kwargs
|
||||
)
|
||||
model = model.eval()
|
||||
|
||||
# Initialize the model
|
||||
@ -817,16 +883,16 @@ class CausalLM(Model):
|
||||
return None
|
||||
|
||||
rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
|
||||
return {
|
||||
'type': rope_scaling, 'factor': float(rope_factor)
|
||||
}
|
||||
return {"type": rope_scaling, "factor": float(rope_factor)}
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return CausalLMBatch
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
def decode_token(
|
||||
self,
|
||||
@ -835,7 +901,9 @@ class CausalLM(Model):
|
||||
read_offset: int = 0,
|
||||
) -> Tuple[str, int, int]:
|
||||
if is_tokenizer_transparent(self.tokenizer):
|
||||
new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
|
||||
new_text = self.tokenizer.decode(
|
||||
all_input_ids[read_offset:], skip_special_tokens=False
|
||||
)
|
||||
return new_text, read_offset, len(all_input_ids)
|
||||
else:
|
||||
return super().decode_token(all_input_ids, prefix_offset, read_offset)
|
||||
@ -858,7 +926,7 @@ class CausalLM(Model):
|
||||
}
|
||||
|
||||
# Optimum Habana got "lazy_mode" key-val only supported for llama type of models
|
||||
if self.model.config.model_type == "llama" :
|
||||
if self.model.config.model_type == "llama":
|
||||
kwargs["lazy_mode"] = LAZY_MODE == 1
|
||||
|
||||
if self.has_position_ids:
|
||||
@ -869,7 +937,9 @@ class CausalLM(Model):
|
||||
|
||||
kwargs.update(self.kwargs)
|
||||
|
||||
if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
|
||||
if past_key_values is not None and self.model.config.model_type not in [
|
||||
"gpt_bigcode"
|
||||
]:
|
||||
return self.model.forward(**kwargs)
|
||||
else:
|
||||
outputs = self.model.forward(**kwargs)
|
||||
@ -896,18 +966,26 @@ class CausalLM(Model):
|
||||
token_idx_scalar = batch.attention_mask.shape[-1] - 1
|
||||
token_idx = torch.tensor(token_idx_scalar).to(self.device)
|
||||
else:
|
||||
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
|
||||
token_idx_scalar = (
|
||||
batch.attention_mask.shape[-1] - batch.right_padding
|
||||
)
|
||||
token_idx = torch.tensor(token_idx_scalar).to(self.device)
|
||||
|
||||
# Select next token
|
||||
input_length = batch.input_length
|
||||
if logits.shape[-2] > 1:
|
||||
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
|
||||
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
|
||||
next_token_ids, next_token_logprobs, logprobs, _, _ = (
|
||||
batch.next_token_chooser(
|
||||
batch.input_ids,
|
||||
logits[:, input_length - 1 : input_length, :].squeeze(-2),
|
||||
self.speculate,
|
||||
)
|
||||
)
|
||||
else:
|
||||
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
|
||||
batch.input_ids, logits.squeeze(-2), self.speculate
|
||||
next_token_ids, next_token_logprobs, logprobs, _, _ = (
|
||||
batch.next_token_chooser(
|
||||
batch.input_ids, logits.squeeze(-2), self.speculate
|
||||
)
|
||||
)
|
||||
# Speculation is not active for causal
|
||||
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
|
||||
@ -918,24 +996,29 @@ class CausalLM(Model):
|
||||
accepted_ids,
|
||||
)
|
||||
|
||||
prev_batches.append({
|
||||
'next_token_ids': next_token_ids,
|
||||
'next_token_logprobs': next_token_logprobs,
|
||||
})
|
||||
prev_batches.append(
|
||||
{
|
||||
"next_token_ids": next_token_ids,
|
||||
"next_token_logprobs": next_token_logprobs,
|
||||
}
|
||||
)
|
||||
|
||||
for req_idx, req in enumerate(batch.requests):
|
||||
requests_to_generate.append({
|
||||
'req': req,
|
||||
'prev_req_idx': req.idx,
|
||||
'batch_id': batch_id,
|
||||
'seed': batch.next_token_chooser.seeds[req_idx],
|
||||
'do_sample': batch.next_token_chooser.do_sample[req_idx],
|
||||
'top_n_tokens': batch.top_n_tokens[req_idx],
|
||||
'top_token_ids': batch_top_token_ids[req_idx],
|
||||
'top_token_logprobs': batch_top_token_logprobs[req_idx],
|
||||
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
|
||||
|
||||
})
|
||||
requests_to_generate.append(
|
||||
{
|
||||
"req": req,
|
||||
"prev_req_idx": req.idx,
|
||||
"batch_id": batch_id,
|
||||
"seed": batch.next_token_chooser.seeds[req_idx],
|
||||
"do_sample": batch.next_token_chooser.do_sample[req_idx],
|
||||
"top_n_tokens": batch.top_n_tokens[req_idx],
|
||||
"top_token_ids": batch_top_token_ids[req_idx],
|
||||
"top_token_logprobs": batch_top_token_logprobs[req_idx],
|
||||
"grammar_state": batch.next_token_chooser.fsm_grammar_states[
|
||||
req.idx
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
htorch.core.mark_step()
|
||||
|
||||
@ -950,7 +1033,9 @@ class CausalLM(Model):
|
||||
|
||||
# Update position_ids
|
||||
if prefill:
|
||||
batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
|
||||
batch.position_ids = (
|
||||
torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
|
||||
)
|
||||
else:
|
||||
batch.position_ids += 1
|
||||
# Update past key values
|
||||
@ -971,13 +1056,19 @@ class CausalLM(Model):
|
||||
if not prefill:
|
||||
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
|
||||
|
||||
scenario = 'PREFILL' if prefill else 'GENERATE'
|
||||
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
|
||||
scenario = "PREFILL" if prefill else "GENERATE"
|
||||
if (
|
||||
self.enable_hpu_graph
|
||||
and self.limit_hpu_graph
|
||||
and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
|
||||
):
|
||||
self.model.clear_cache()
|
||||
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
|
||||
dbg_trace(
|
||||
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
|
||||
assert batch.right_padding > 0, 'No more room for next token!'
|
||||
scenario,
|
||||
f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
|
||||
)
|
||||
assert batch.right_padding > 0, "No more room for next token!"
|
||||
|
||||
# Execute batch
|
||||
if prefill:
|
||||
@ -989,14 +1080,18 @@ class CausalLM(Model):
|
||||
batch.position_ids,
|
||||
token_idx,
|
||||
batch.past_key_values,
|
||||
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
|
||||
bypass_hpu_graph=prefill and self.limit_hpu_graph
|
||||
if self.enable_hpu_graph
|
||||
else None,
|
||||
)
|
||||
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
|
||||
# Don't schedule next forward if max_new_tokens for all requests equals 1
|
||||
# - we've already generated the first and only needed token in the prefill phase
|
||||
pass
|
||||
else:
|
||||
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
|
||||
token_idx = torch.tensor(
|
||||
batch.attention_mask.shape[-1] - batch.right_padding
|
||||
).to(self.device)
|
||||
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
|
||||
logits = self.forward(
|
||||
input_ids,
|
||||
@ -1004,7 +1099,9 @@ class CausalLM(Model):
|
||||
batch.position_ids,
|
||||
token_idx,
|
||||
batch.past_key_values,
|
||||
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
|
||||
bypass_hpu_graph=prefill and self.limit_hpu_graph
|
||||
if self.enable_hpu_graph
|
||||
else None,
|
||||
)
|
||||
if self.model.config.model_type in ["gpt_bigcode"]:
|
||||
batch.logits, batch.past = logits
|
||||
@ -1018,40 +1115,45 @@ class CausalLM(Model):
|
||||
# Stage 3. Finish and return previous generations
|
||||
stopped = len(requests_to_generate) > 0
|
||||
for prev_batch in prev_batches:
|
||||
prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
|
||||
prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
|
||||
prev_batch["next_token_logprobs"] = prev_batch[
|
||||
"next_token_logprobs"
|
||||
].tolist()
|
||||
prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
|
||||
htorch.core.mark_step()
|
||||
|
||||
for req_data in requests_to_generate:
|
||||
req = req_data['req']
|
||||
i = req_data['prev_req_idx']
|
||||
prev_batch_id = req_data['batch_id']
|
||||
req = req_data["req"]
|
||||
i = req_data["prev_req_idx"]
|
||||
prev_batch_id = req_data["batch_id"]
|
||||
assert len(prev_batches) > prev_batch_id
|
||||
next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
|
||||
next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
|
||||
next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
|
||||
next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
|
||||
|
||||
request = req.data
|
||||
input_length = req.input_length
|
||||
prefix_offset = req.prefix_offset
|
||||
read_offset = req.read_offset
|
||||
do_sample = req_data['do_sample']
|
||||
seed = req_data['seed']
|
||||
do_sample = req_data["do_sample"]
|
||||
seed = req_data["seed"]
|
||||
stopping_criteria = req.stopping_criteria
|
||||
all_input_ids = req.all_input_ids
|
||||
next_token_id = next_token_ids_cpu[i]
|
||||
next_token_logprob = next_token_logprobs[i]
|
||||
top_n_tokens = req_data['top_n_tokens']
|
||||
top_token_ids = req_data['top_token_ids']
|
||||
top_token_logprobs = req_data['top_token_logprobs']
|
||||
grammar_state = req_data['grammar_state']
|
||||
top_n_tokens = req_data["top_n_tokens"]
|
||||
top_token_ids = req_data["top_token_ids"]
|
||||
top_token_logprobs = req_data["top_token_logprobs"]
|
||||
grammar_state = req_data["grammar_state"]
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids[input_length] = next_token_id
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Generated token
|
||||
if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
|
||||
next_token_text = ''
|
||||
if (
|
||||
is_tokenizer_transparent(self.tokenizer)
|
||||
and len(stopping_criteria.stop_sequence_criterias) == 0
|
||||
):
|
||||
next_token_text = ""
|
||||
else:
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
|
||||
@ -1075,7 +1177,11 @@ class CausalLM(Model):
|
||||
output_text = None
|
||||
else:
|
||||
output_text = self.decode(
|
||||
all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
|
||||
all_input_ids[
|
||||
new_input_length
|
||||
- stopping_criteria.current_tokens : new_input_length,
|
||||
0,
|
||||
]
|
||||
)
|
||||
generated_text = GeneratedText(
|
||||
output_text,
|
||||
@ -1090,7 +1196,7 @@ class CausalLM(Model):
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + next_token_logprobs
|
||||
prefill_token_ids = all_input_ids[0: new_input_length - 1]
|
||||
prefill_token_ids = all_input_ids[0 : new_input_length - 1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
@ -1159,7 +1265,12 @@ class CausalLM(Model):
|
||||
htorch.core.mark_step()
|
||||
self.step = self.step + 1
|
||||
if self.hb_profiler is not None:
|
||||
if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
|
||||
if (
|
||||
self.step
|
||||
> self.profiling_wait_steps
|
||||
+ self.profiling_warmup_steps
|
||||
+ self.profiling_steps
|
||||
):
|
||||
self.hb_profiler.stop()
|
||||
else:
|
||||
self.hb_profiler.step()
|
||||
@ -1178,11 +1289,12 @@ class CausalLM(Model):
|
||||
|
||||
return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
|
||||
|
||||
|
||||
def warmup(self, request) -> None:
|
||||
MAX_TOTAL_TOKENS = request.max_total_tokens
|
||||
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
|
||||
batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
|
||||
batch = self.batch_type.from_pb(
|
||||
request.batch, self.tokenizer, self.dtype, self.device
|
||||
)
|
||||
max_prefill_batch_size = batch.input_ids.shape[0]
|
||||
try:
|
||||
# max prefill batch size warmup
|
||||
@ -1199,14 +1311,21 @@ class CausalLM(Model):
|
||||
max_input_length = request.max_input_length
|
||||
prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
|
||||
prefill_batch_size_list.append(max_prefill_batch_size)
|
||||
prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
|
||||
prefill_seqlen_list = [
|
||||
seq
|
||||
for seq in range(
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF,
|
||||
max_input_length,
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF,
|
||||
)
|
||||
]
|
||||
prefill_seqlen_list.append(max_input_length)
|
||||
prefill_batch_size_list.sort(reverse=True)
|
||||
prefill_seqlen_list.sort(reverse=True)
|
||||
try:
|
||||
for batch_size in prefill_batch_size_list:
|
||||
for seq_len in prefill_seqlen_list:
|
||||
batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
|
||||
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
|
||||
_, prefill_batch, _ = self.generate_token([batch])
|
||||
except:
|
||||
prefill_batch_size_list.sort()
|
||||
@ -1227,24 +1346,33 @@ class CausalLM(Model):
|
||||
f"Memory stats: {mem_stats} "
|
||||
)
|
||||
|
||||
#warmup decode batch size
|
||||
# warmup decode batch size
|
||||
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
|
||||
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
|
||||
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
|
||||
decode_batch_size_list = [
|
||||
i
|
||||
for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
|
||||
]
|
||||
decode_batch_size_list.append(max_decode_batch_size)
|
||||
decode_batch_size_list.sort(reverse=True)
|
||||
|
||||
try:
|
||||
for batch_size in decode_batch_size_list:
|
||||
batches= []
|
||||
iters = math.floor(batch_size/max_prefill_batch_size)
|
||||
batches = []
|
||||
iters = math.floor(batch_size / max_prefill_batch_size)
|
||||
for i in range(iters):
|
||||
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
|
||||
batch = self.generate_warmup_batch(
|
||||
request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size
|
||||
)
|
||||
_, prefill_batch, _ = self.generate_token([batch])
|
||||
batches.append(prefill_batch)
|
||||
|
||||
if batch_size % max_prefill_batch_size != 0:
|
||||
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
|
||||
batch = self.generate_warmup_batch(
|
||||
request,
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF - 1,
|
||||
batch_size % max_prefill_batch_size,
|
||||
)
|
||||
_, prefill_batch, _ = self.generate_token([batch])
|
||||
batches.append(prefill_batch)
|
||||
|
||||
@ -1254,10 +1382,10 @@ class CausalLM(Model):
|
||||
batches.clear()
|
||||
|
||||
except:
|
||||
raise RuntimeError(
|
||||
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
|
||||
f"You need to decrease `--max-batch-total-tokens`"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
|
||||
f"You need to decrease `--max-batch-total-tokens`"
|
||||
)
|
||||
|
||||
decode_batch_size_list.sort()
|
||||
MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
|
||||
@ -1268,4 +1396,4 @@ class CausalLM(Model):
|
||||
f"Memory stats: {mem_stats} "
|
||||
)
|
||||
|
||||
return MAX_BATCH_TOTAL_TOKENS
|
||||
return MAX_BATCH_TOTAL_TOKENS
|
||||
|
Loading…
Reference in New Issue
Block a user