Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 13:02:12 +00:00)
Complete padding of CausalLMBatch when batch bucketing is used (#261)
Signed-off-by: kaixuanliu <kaixuan.liu@intel.com>
This commit is contained in:
parent fe7594e369
commit b52164d38a
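Beyond the Black-style reformatting, the functional change in this patch is that per-request metadata such as top_n_tokens is now padded out to the bucketed batch size with -1 entries for the dummy slots (see the new `top_n_tokens.extend([-1] * ...)` lines in `recombine` and `from_pb` below). A minimal sketch of that bucketing arithmetic follows, assuming the same `round_up`/`BATCH_BUCKET_SIZE` semantics as in causal_lm.py; the helper `pad_top_n_tokens` is illustrative, not part of the patch.

# Sketch only: how batch bucketing pads per-request metadata such as top_n_tokens.
# Assumes the round_up/BATCH_BUCKET_SIZE semantics used in causal_lm.py below;
# pad_top_n_tokens is a hypothetical helper, not part of the patch.
import os
from typing import List

BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))


def round_up(number, k):
    # Same helper as in the file: round `number` up to the nearest multiple of `k`.
    return (number + k - 1) // k * k


def pad_top_n_tokens(top_n_tokens: List[int], bucket_size: int = BATCH_BUCKET_SIZE) -> List[int]:
    # Pad the per-request list to the bucketed batch size; dummy slots get -1,
    # mirroring `top_n_tokens.extend([-1] * (new_bs - total_requests))` in the diff.
    new_bs = round_up(len(top_n_tokens), bucket_size)
    return top_n_tokens + [-1] * (new_bs - len(top_n_tokens))


if __name__ == "__main__":
    # Three real requests in a bucket of 8 -> five dummy slots padded with -1.
    print(pad_top_n_tokens([3, 0, 5]))  # [3, 0, 5, -1, -1, -1, -1, -1]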
@@ -53,21 +53,26 @@ from text_generation_server.utils.debug import dbg_trace
 from text_generation_server.utils.speculate import get_speculate

 tracer = trace.get_tracer(__name__)
-MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
+MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
-LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
-BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
+LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
+BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
+PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))


 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
         return func
-    return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
+    return torch.compile(
+        func, backend="hpu_backend", options={"keep_input_mutations": True}
+    )


 def round_up(number, k):
     return (number + k - 1) // k * k


 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)

@@ -96,9 +101,11 @@ def grouped_pad(tensor_groups, dims, values):
     for tensors, dim, value in zip(tensor_groups, dims, values):
         padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
         if padding > 0:
-            assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
+            assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}"
             pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
-            result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
+            result = [
+                torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors
+            ]
         else:
             result = [t for t in tensors]
         grouped_result.append(result)
@@ -117,7 +124,10 @@ def roll(tensor, chunk, dim, merge_graphs):


 def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
-    tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
+    tensor_groups = [
+        [roll(t, chunk, dim, merge_graphs) for t in tensors]
+        for tensors, dim in zip(tensor_groups, dims)
+    ]
     if merge_graphs:
         htorch.core.mark_step()
     return tensor_groups
@@ -167,7 +177,10 @@ def extend_batch(tensors, target_bs, dim):


 def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
-    tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
+    tensor_groups = [
+        extend_batch(tensors, target_bs, dim)
+        for tensors, dim in zip(tensor_groups, bs_dims)
+    ]
     return tensor_groups


@@ -220,15 +233,20 @@ class CausalLMRequest:
     all_input_ids: torch.Tensor

     @classmethod
-    def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase):
+    def from_pb(
+        cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase
+    ):
         return cls(
             idx=idx,
             data=data,
             input_length=None,
             prefix_offset=None,
             read_offset=None,
-            stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer),
-            all_input_ids=None,)
+            stopping_criteria=StoppingCriteria.from_pb(
+                data.stopping_parameters, tokenizer
+            ),
+            all_input_ids=None,
+        )

     def update_idx(self, new_idx):
         prev = self.idx
@@ -289,7 +307,11 @@ class CausalLMBatch(Batch):
         # Very simple heuristic to determine whether we should merge tensors
         # this needs tuning for other models/scenarios
         small_bs = len(self.past_key_values) > self.batch_size
-        if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
+        if (
+            not self.merged_kv_cache
+            and small_bs
+            and (pad_needed or shift_needed or expand_needed)
+        ):
             past_keys, past_values = self.detach_kv_cache()
             past_keys = merge(past_keys)
             past_values = merge(past_values)
@@ -309,7 +331,13 @@ class CausalLMBatch(Batch):
         seq_dim = -1
         key_dim = -2 if self.keys_head_dim_last else -1
         value_dim = -2
-        tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
+        tensors = [
+            [self.input_ids],
+            [self.attention_mask],
+            [self.position_ids],
+            past_keys,
+            past_values,
+        ]
         # We don't need to align position_ids
         seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
         bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
@@ -350,13 +378,17 @@ class CausalLMBatch(Batch):
         dst_tensors, _, dst_dims = self.get_tensor_groups()
         free_indices_gen = self.free_indices_generator()
         for src_b in src_batches:
-            dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
+            dst_indices = to_tensor_indices(
+                src_b.update_indices(free_indices_gen), self.input_ids.device
+            )
             src_tensors, _, src_dims = src_b.get_tensor_groups()
             grouped_move(dst_tensors, dst_indices, src_tensors)
         self.set_tensor_groups(dst_tensors)

     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
+    def recombine(
+        cls, batches: List["CausalLMBatch"], pad_token_id: int
+    ) -> "CausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")

@@ -375,31 +407,39 @@ class CausalLMBatch(Batch):
         # For prefill there is a space allocated only for first token
         # Need to add padding to the max total tokens before first decode

-        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        moves_needed = [
+            total_requests - len(b) if b.batch_size == new_bs else total_requests
+            for b in batches
+        ]
         dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
-        reshape = (batches[dst_batch_idx].batch_size < new_bs)
+        reshape = batches[dst_batch_idx].batch_size < new_bs

         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
-            scenario = 'CONCAT'
+            scenario = "CONCAT"
         elif reshape:
-            scenario = 'RESHAPE'
+            scenario = "RESHAPE"
         elif cur_padding[dst_batch_idx] <= 0:
-            scenario = 'SHIFT'
-            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            scenario = "SHIFT"
+            offsets = [
+                biggest_single_chunk(b.max_input_length - max_input_length)
+                for b in batches
+            ]
             max_input_length = max_input_length + offsets[dst_batch_idx]
         else:
             # Nothing to do
             return batches[0]

         dbg_trace(
-            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
-            f' reqs:{[len(b) for b in batches]}'
-            f' offsets:{offsets}'
-            f' input_lengths:{input_lengths}'
-            f' cur_padding:{cur_padding}'
-            f' dst_batch:{dst_batch_idx}')
+            scenario,
+            f"bs:{[b.batch_size for b in batches]}->{new_bs}"
+            f" reqs:{[len(b) for b in batches]}"
+            f" offsets:{offsets}"
+            f" input_lengths:{input_lengths}"
+            f" cur_padding:{cur_padding}"
+            f" dst_batch:{dst_batch_idx}",
+        )

         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
@@ -410,10 +450,15 @@ class CausalLMBatch(Batch):
             batches[i].realign(target_bs, offsets[i], pad_token_id)
             batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
         batches[dst_batch_idx].expand_bs(new_bs)
-        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+        batches[dst_batch_idx].move_data(
+            [batches[i] for i in range(len(batches)) if i != dst_batch_idx]
+        )

         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens.extend([-1] * (new_bs - total_requests))
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         parameters = [r.data.parameters for r in flat_requests]
         # append the dummy parameters for dummy requests
@@ -424,7 +469,9 @@ class CausalLMBatch(Batch):
         fsm_grammar_states = [0] * batch_size
         for batch in batches:
             for i, req in enumerate(batch.requests):
-                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+                fsm_grammar_states[req.idx] = (
+                    batch.next_token_chooser.fsm_grammar_states[i]
+                )

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             parameters,
@@ -465,8 +512,11 @@ class CausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
-        requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
+        dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
+        requests = [
+            CausalLMRequest.from_pb(idx, req, tokenizer)
+            for idx, req in enumerate(pb.requests)
+        ]
         inputs = []
         top_n_tokens = []

@@ -476,10 +526,10 @@ class CausalLMBatch(Batch):
             inputs.append(concat_text_chunks(r.input_chunks.chunks))
             top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)

         max_input_length = max_truncation
         if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
             max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
         max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)

         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
@@ -501,7 +551,7 @@ class CausalLMBatch(Batch):
         )

         tokenized_inputs = tokenizer(
-            inputs+dummy_inputs,
+            inputs + dummy_inputs,
             return_tensors="pt",
             padding="longest",
             return_token_type_ids=False,
@@ -514,7 +564,9 @@ class CausalLMBatch(Batch):
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, (
+                "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+            )
             rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
@@ -547,7 +599,8 @@ class CausalLMBatch(Batch):
         position_ids = attention_mask.long().cumsum(-1) - 1
         position_ids.masked_fill_(attention_mask == 0, 1)

+        old_bs = len(requests)
+        top_n_tokens.extend([-1] * (new_bs - old_bs))
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
         )
@@ -568,14 +621,16 @@ class CausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
-        dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
+        dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}")
         request_ids = set(request_ids)
         self.requests = [req for req in self.requests if req.data.id in request_ids]
         return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+    def concatenate(
+        cls, batches: List["CausalLMBatch"], pad_token_id: int = 0
+    ) -> "CausalLMBatch":
         return cls.recombine(batches, pad_token_id)

     def __len__(self):
@@ -618,9 +673,7 @@ class CausalLM(Model):
         tokenizer_class=AutoTokenizer,
         config_class=AutoConfig,
         batch_class=CausalLMBatch,

     ):

         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")

@@ -646,18 +699,14 @@ class CausalLM(Model):
             htorch.core.hpu_set_env()

         if world_size > 1:
-            model = self.get_deepspeed_model(
-                model_id, dtype, revision
-            )
+            model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)

             # Check support for rope scaling
             model_kwargs = {}
-            config = AutoConfig.from_pretrained(
-                model_id
-            )
+            config = AutoConfig.from_pretrained(model_id)
             if hasattr(config, "rope_scaling"):
                 model_kwargs["rope_scaling"] = self.get_rope_scaling()

@@ -666,26 +715,34 @@ class CausalLM(Model):
                 revision=revision,
                 torch_dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                **model_kwargs
+                **model_kwargs,
             )
             model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)

-        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+        self.enable_hpu_graph = (
+            os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+        )
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"

-        if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+        if model.config.model_type not in [
+            "gpt_bigcode"
+        ]:  # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
             model = remove_kv_cache_from_output(model)

         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph

             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         else:
             if LAZY_MODE == 0:
                 # It is said that "keep_input_mutations" is safe for inference to be done
-                dbg_trace(
-                    "TORCH COMPILE", f'Torch compiling of model')
-                model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
+                dbg_trace("TORCH COMPILE", f"Torch compiling of model")
+                model.model = torch.compile(
+                    model.model,
+                    backend="hpu_backend",
+                    options={"keep_input_mutations": True},
+                )

         model = hq_env.setup_quantization(model)

@@ -714,8 +771,14 @@ class CausalLM(Model):
             "return_dict": True,
         }

-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gpt_bigcode"]:
+        if model.config.model_type in [
+            "llama",
+            "mistral",
+            "starcoder2",
+            "qwen2",
+            "falcon",
+            "gpt_bigcode",
+        ]:
             if model.config.model_type not in ["falcon", "gpt_bigcode"]:
                 self.kwargs["attn_softmax_bf16"] = True

@@ -740,11 +803,15 @@ class CausalLM(Model):
         )

         # Create profiler
-        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
+        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
         record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
         output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
-        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
-        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_warmup_steps = (
+            int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+        )
+        self.profiling_steps = (
+            int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        )
         self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
         if self.profiling_steps > 0:
             self.hb_profiler = HabanaProfile(
@@ -752,7 +819,7 @@ class CausalLM(Model):
                 warmup=self.profiling_warmup_steps,
                 active=self.profiling_steps,
                 output_dir=output_dir,
-                record_shapes=record_shapes
+                record_shapes=record_shapes,
             )
             self.hb_profiler.start()
         else:
@@ -760,23 +827,20 @@ class CausalLM(Model):
         self.step = 0

     def get_deepspeed_model(
-        self,
-        model_id: str,
-        dtype: torch.dtype,
-        revision: Optional[str] = None
+        self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None
     ) -> torch.nn.Module:
         import deepspeed
         from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

         world_size, rank, local_rank = initialize_distributed_hpu()
-        model_kwargs = {
-            "revision": revision
-        }
+        model_kwargs = {"revision": revision}

         # Initialize process(es) for DeepSpeed
         deepspeed.init_distributed(dist_backend="hccl")
         logger.info(
-            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
+            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
+                world_size, rank, local_rank
+            )
         )
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
@@ -794,7 +858,9 @@ class CausalLM(Model):
         get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
         # TODO: revisit placement on CPU when auto-injection is possible
         with deepspeed.OnDevice(dtype=dtype, device="cpu"):
-            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id, torch_dtype=dtype, **model_kwargs
+            )
         model = model.eval()

         # Initialize the model
@@ -817,16 +883,16 @@ class CausalLM(Model):
             return None

         rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
-        return {
-            'type': rope_scaling, 'factor': float(rope_factor)
-        }
+        return {"type": rope_scaling, "factor": float(rope_factor)}

     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch

     def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def decode_token(
         self,
@@ -835,7 +901,9 @@ class CausalLM(Model):
         read_offset: int = 0,
     ) -> Tuple[str, int, int]:
         if is_tokenizer_transparent(self.tokenizer):
-            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
+            new_text = self.tokenizer.decode(
+                all_input_ids[read_offset:], skip_special_tokens=False
+            )
             return new_text, read_offset, len(all_input_ids)
         else:
             return super().decode_token(all_input_ids, prefix_offset, read_offset)
@@ -858,7 +926,7 @@ class CausalLM(Model):
         }

         # Optimum Habana got "lazy_mode" key-val only supported for llama type of models
-        if self.model.config.model_type == "llama" :
+        if self.model.config.model_type == "llama":
             kwargs["lazy_mode"] = LAZY_MODE == 1

         if self.has_position_ids:
@@ -869,7 +937,9 @@ class CausalLM(Model):

         kwargs.update(self.kwargs)

-        if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
+        if past_key_values is not None and self.model.config.model_type not in [
+            "gpt_bigcode"
+        ]:
             return self.model.forward(**kwargs)
         else:
             outputs = self.model.forward(**kwargs)
@@ -896,18 +966,26 @@ class CausalLM(Model):
                 token_idx_scalar = batch.attention_mask.shape[-1] - 1
                 token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx_scalar = (
+                    batch.attention_mask.shape[-1] - batch.right_padding
+                )
                 token_idx = torch.tensor(token_idx_scalar).to(self.device)

             # Select next token
             input_length = batch.input_length
             if logits.shape[-2] > 1:
-                next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
-                    batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
+                next_token_ids, next_token_logprobs, logprobs, _, _ = (
+                    batch.next_token_chooser(
+                        batch.input_ids,
+                        logits[:, input_length - 1 : input_length, :].squeeze(-2),
+                        self.speculate,
+                    )
                 )
             else:
-                next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
-                    batch.input_ids, logits.squeeze(-2), self.speculate
+                next_token_ids, next_token_logprobs, logprobs, _, _ = (
+                    batch.next_token_chooser(
+                        batch.input_ids, logits.squeeze(-2), self.speculate
+                    )
                 )
             # Speculation is not active for causal
             accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
@@ -918,24 +996,29 @@ class CausalLM(Model):
                 accepted_ids,
             )

-            prev_batches.append({
-                'next_token_ids': next_token_ids,
-                'next_token_logprobs': next_token_logprobs,
-            })
+            prev_batches.append(
+                {
+                    "next_token_ids": next_token_ids,
+                    "next_token_logprobs": next_token_logprobs,
+                }
+            )

             for req_idx, req in enumerate(batch.requests):
-                requests_to_generate.append({
-                    'req': req,
-                    'prev_req_idx': req.idx,
-                    'batch_id': batch_id,
-                    'seed': batch.next_token_chooser.seeds[req_idx],
-                    'do_sample': batch.next_token_chooser.do_sample[req_idx],
-                    'top_n_tokens': batch.top_n_tokens[req_idx],
-                    'top_token_ids': batch_top_token_ids[req_idx],
-                    'top_token_logprobs': batch_top_token_logprobs[req_idx],
-                    'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
-                })
+                requests_to_generate.append(
+                    {
+                        "req": req,
+                        "prev_req_idx": req.idx,
+                        "batch_id": batch_id,
+                        "seed": batch.next_token_chooser.seeds[req_idx],
+                        "do_sample": batch.next_token_chooser.do_sample[req_idx],
+                        "top_n_tokens": batch.top_n_tokens[req_idx],
+                        "top_token_ids": batch_top_token_ids[req_idx],
+                        "top_token_logprobs": batch_top_token_logprobs[req_idx],
+                        "grammar_state": batch.next_token_chooser.fsm_grammar_states[
+                            req.idx
+                        ],
+                    }
+                )

             htorch.core.mark_step()

@@ -950,7 +1033,9 @@ class CausalLM(Model):

         # Update position_ids
         if prefill:
-            batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+            batch.position_ids = (
+                torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+            )
         else:
             batch.position_ids += 1
         # Update past key values
@@ -971,13 +1056,19 @@ class CausalLM(Model):
         if not prefill:
             batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)

-        scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
+        scenario = "PREFILL" if prefill else "GENERATE"
+        if (
+            self.enable_hpu_graph
+            and self.limit_hpu_graph
+            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+        ):
             self.model.clear_cache()
             self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
         dbg_trace(
-            scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
-        assert batch.right_padding > 0, 'No more room for next token!'
+            scenario,
+            f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
+        )
+        assert batch.right_padding > 0, "No more room for next token!"

         # Execute batch
         if prefill:
@@ -989,14 +1080,18 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
+                bypass_hpu_graph=prefill and self.limit_hpu_graph
+                if self.enable_hpu_graph
+                else None,
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
             # Don't schedule next forward if max_new_tokens for all requests equals 1
             # - we've already generated the first and only needed token in the prefill phase
             pass
         else:
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            token_idx = torch.tensor(
+                batch.attention_mask.shape[-1] - batch.right_padding
+            ).to(self.device)
             input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
             logits = self.forward(
                 input_ids,
@@ -1004,7 +1099,9 @@ class CausalLM(Model):
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
+                bypass_hpu_graph=prefill and self.limit_hpu_graph
+                if self.enable_hpu_graph
+                else None,
             )
             if self.model.config.model_type in ["gpt_bigcode"]:
                 batch.logits, batch.past = logits
@@ -1018,40 +1115,45 @@ class CausalLM(Model):
         # Stage 3. Finish and return previous generations
         stopped = len(requests_to_generate) > 0
         for prev_batch in prev_batches:
-            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
-            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
+            prev_batch["next_token_logprobs"] = prev_batch[
+                "next_token_logprobs"
+            ].tolist()
+            prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
         htorch.core.mark_step()

         for req_data in requests_to_generate:
-            req = req_data['req']
-            i = req_data['prev_req_idx']
-            prev_batch_id = req_data['batch_id']
+            req = req_data["req"]
+            i = req_data["prev_req_idx"]
+            prev_batch_id = req_data["batch_id"]
             assert len(prev_batches) > prev_batch_id
-            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
-            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
+            next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
+            next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]

             request = req.data
             input_length = req.input_length
             prefix_offset = req.prefix_offset
             read_offset = req.read_offset
-            do_sample = req_data['do_sample']
-            seed = req_data['seed']
+            do_sample = req_data["do_sample"]
+            seed = req_data["seed"]
             stopping_criteria = req.stopping_criteria
             all_input_ids = req.all_input_ids
             next_token_id = next_token_ids_cpu[i]
             next_token_logprob = next_token_logprobs[i]
-            top_n_tokens = req_data['top_n_tokens']
-            top_token_ids = req_data['top_token_ids']
-            top_token_logprobs = req_data['top_token_logprobs']
-            grammar_state = req_data['grammar_state']
+            top_n_tokens = req_data["top_n_tokens"]
+            top_token_ids = req_data["top_token_ids"]
+            top_token_logprobs = req_data["top_token_logprobs"]
+            grammar_state = req_data["grammar_state"]

             # Append next token to all tokens
             all_input_ids[input_length] = next_token_id
             new_input_length = input_length + 1

             # Generated token
-            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
-                next_token_text = ''
+            if (
+                is_tokenizer_transparent(self.tokenizer)
+                and len(stopping_criteria.stop_sequence_criterias) == 0
+            ):
+                next_token_text = ""
             else:
                 next_token_text, prefix_offset, read_offset = self.decode_token(
                     all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
@@ -1075,7 +1177,11 @@ class CausalLM(Model):
                     output_text = None
                 else:
                     output_text = self.decode(
-                        all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
+                        all_input_ids[
+                            new_input_length
+                            - stopping_criteria.current_tokens : new_input_length,
+                            0,
+                        ]
                     )
                     generated_text = GeneratedText(
                         output_text,
@@ -1090,7 +1196,7 @@ class CausalLM(Model):
                 if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                     # Remove generated token to only have prefill and add nan for first prompt token
                     prefill_logprobs = [float("nan")] + next_token_logprobs
-                    prefill_token_ids = all_input_ids[0: new_input_length - 1]
+                    prefill_token_ids = all_input_ids[0 : new_input_length - 1]
                     prefill_texts = self.tokenizer.batch_decode(
                         prefill_token_ids,
                         clean_up_tokenization_spaces=False,
@@ -1159,7 +1265,12 @@ class CausalLM(Model):
         htorch.core.mark_step()
         self.step = self.step + 1
         if self.hb_profiler is not None:
-            if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
+            if (
+                self.step
+                > self.profiling_wait_steps
+                + self.profiling_warmup_steps
+                + self.profiling_steps
+            ):
                 self.hb_profiler.stop()
             else:
                 self.hb_profiler.step()
@@ -1178,11 +1289,12 @@ class CausalLM(Model):

         return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)


     def warmup(self, request) -> None:
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
-        batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
+        batch = self.batch_type.from_pb(
+            request.batch, self.tokenizer, self.dtype, self.device
+        )
         max_prefill_batch_size = batch.input_ids.shape[0]
         try:
             # max prefill batch size warmup
@@ -1199,14 +1311,21 @@ class CausalLM(Model):
         max_input_length = request.max_input_length
         prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
         prefill_batch_size_list.append(max_prefill_batch_size)
-        prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
+        prefill_seqlen_list = [
+            seq
+            for seq in range(
+                PAD_SEQUENCE_TO_MULTIPLE_OF,
+                max_input_length,
+                PAD_SEQUENCE_TO_MULTIPLE_OF,
+            )
+        ]
         prefill_seqlen_list.append(max_input_length)
         prefill_batch_size_list.sort(reverse=True)
         prefill_seqlen_list.sort(reverse=True)
         try:
             for batch_size in prefill_batch_size_list:
                 for seq_len in prefill_seqlen_list:
-                    batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
+                    batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                     _, prefill_batch, _ = self.generate_token([batch])
         except:
             prefill_batch_size_list.sort()
@@ -1227,24 +1346,33 @@ class CausalLM(Model):
                 f"Memory stats: {mem_stats} "
             )

-        #warmup decode batch size
+        # warmup decode batch size
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
-        decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
+        decode_batch_size_list = [
+            i
+            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+        ]
         decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)

         try:
             for batch_size in decode_batch_size_list:
-                batches= []
-                iters = math.floor(batch_size/max_prefill_batch_size)
+                batches = []
+                iters = math.floor(batch_size / max_prefill_batch_size)
                 for i in range(iters):
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
+                    batch = self.generate_warmup_batch(
+                        request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size
+                    )
                     _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)

                 if batch_size % max_prefill_batch_size != 0:
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
+                    batch = self.generate_warmup_batch(
+                        request,
+                        PAD_SEQUENCE_TO_MULTIPLE_OF - 1,
+                        batch_size % max_prefill_batch_size,
+                    )
                     _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)

@@ -1254,10 +1382,10 @@ class CausalLM(Model):
                 batches.clear()

         except:
             raise RuntimeError(
                 f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
                 f"You need to decrease `--max-batch-total-tokens`"
             )

         decode_batch_size_list.sort()
         MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
@@ -1268,4 +1396,4 @@ class CausalLM(Model):
             f"Memory stats: {mem_stats} "
         )

         return MAX_BATCH_TOTAL_TOKENS
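For reference, a small sketch of how the warmup pass above enumerates bucketed shapes (prefill batch sizes, prefill sequence lengths, and decode batch sizes). It assumes the default bucket sizes from the top of the file and simplifies away the reverse sorting and the actual warmup forwards; `warmup_shapes` is an illustrative helper, not part of the patch.

# Sketch only: bucketed shape enumeration, mirroring the warmup logic in the diff above.
import math

PREFILL_BATCH_BUCKET_SIZE = 2
BATCH_BUCKET_SIZE = 8
PAD_SEQUENCE_TO_MULTIPLE_OF = 256


def round_up(number, k):
    return (number + k - 1) // k * k


def warmup_shapes(max_prefill_batch_size, max_input_length, max_batch_total_tokens, max_total_tokens):
    # Prefill warmup covers every bucketed batch size and sequence length up to the maxima.
    prefill_bs = list(range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE))
    prefill_bs.append(max_prefill_batch_size)
    prefill_seq = list(range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF))
    prefill_seq.append(max_input_length)
    # Decode warmup covers every bucketed batch size that fits the token budget.
    max_decode_bs = round_up(math.floor(max_batch_total_tokens / max_total_tokens), BATCH_BUCKET_SIZE)
    decode_bs = list(range(BATCH_BUCKET_SIZE, max_decode_bs, BATCH_BUCKET_SIZE))
    decode_bs.append(max_decode_bs)
    return prefill_bs, prefill_seq, decode_bs


print(warmup_shapes(4, 1024, 65536, 2048))
# -> ([2, 4], [256, 512, 768, 1024], [8, 16, 24, 32])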