Mirror of https://github.com/huggingface/text-generation-inference.git
Batch bucketing improvements (#15)
parent 8523f7ef64
commit 381ec38cad
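For context, the batch bucketing referred to in the title rounds the number of live requests up to the next multiple of BATCH_BUCKET_SIZE, so the device only ever sees a small, fixed set of batch shapes. A minimal sketch of that sizing rule, reusing the round_up helper from this diff (the BATCH_BUCKET_SIZE value below is an assumed example, not the configured default):

    # Illustrative sketch only; BATCH_BUCKET_SIZE = 8 is an assumed example value.
    BATCH_BUCKET_SIZE = 8

    def round_up(number, k):
        # Same helper as in the diff: smallest multiple of k that is >= number.
        return (number + k - 1) // k * k

    total_requests = 11  # e.g. requests remaining after a filter/concatenate
    new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
    print(new_bs)        # -> 16: the recombined batch is padded up to the next bucket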
@@ -68,7 +68,10 @@ def round_up(number, k):
     return (number + k - 1) // k * k


-def batch_alloc(new_bs, tensor):
-    return tensor.new_empty((new_bs,) + tensor.shape[1:])
+def prepare_memory(new_bs, tensor, inplace):
+    if inplace:
+        return tensor
+    else:
+        return tensor.new_empty((new_bs,) + tensor.shape[1:])


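A rough usage sketch of the new prepare_memory helper: when the target batch already has the right bucketed size it hands back the existing tensor, otherwise it allocates a fresh buffer for the new batch size (the shapes below are made up for illustration):

    import torch

    def prepare_memory(new_bs, tensor, inplace):
        # Same helper as added above.
        if inplace:
            return tensor
        return tensor.new_empty((new_bs,) + tensor.shape[1:])

    src = torch.zeros(8, 128)                      # hypothetical [batch, seq] tensor
    reused = prepare_memory(8, src, inplace=True)
    grown = prepare_memory(16, src, inplace=False)
    assert reused.data_ptr() == src.data_ptr()     # in-place: existing storage is reused
    assert grown.shape == (16, 128)                # otherwise: a new, larger (uninitialized) buffer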
@@ -154,9 +157,6 @@ class CausalLMBatch(Batch):
     top_n_tokens: List[int]
     top_n_tokens_tensor: torch.Tensor

-    # Maximum number of tokens this batch will grow to
-    max_tokens: int
-
     input_length: int
     right_padding: int

@@ -169,32 +169,53 @@ class CausalLMBatch(Batch):
         )

     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], req_ids: List[List[int]], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        new_bs = round_up(sum([len(reqs) for reqs in req_ids]), BATCH_BUCKET_SIZE)
+    def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
+        total_requests = sum(len(b) for b in batches)
+        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device

-        # TODO: for now use consecutive indices. This could be optimized to reuse existing batch memory and only overwrite
-        # indices that are no longer used instead of allocating new memory
-        free_indices = itertools.count(0)
-        to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
-        requests = [[req for req in batch.requests if req.data.id in ids] for batch, ids in zip(batches, req_ids)]
-        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in requests]
-        requests = list(itertools.chain(*requests))
+        max_input_length = max(b.input_length for b in batches)
+        offsets = [max_input_length - b.input_length for b in batches]
+        padding = [b.right_padding for b in batches]
+
+        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]

         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
-        max_input_length = max(req.input_length for req in requests)
-        offsets = [(max_input_length - b.input_length) for b in batches]
-        scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
-        dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')
+        if len(batches) > 1:
+            scenario = 'CONCAT'
+        elif batches[0].batch_size != new_bs:
+            scenario = 'RESHAPE'
+        elif padding[0] <= 1:
+            scenario = 'SHIFT'
+            offsets = [b.max_input_length - max_input_length for b in batches]
+            max_input_length = max(b.max_input_length for b in batches)
+        else:
+            # Nothing to do
+            return batches[0]
+
+        inplace = batches[target_batch_idx].batch_size == new_bs
+        dbg_trace(scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
+
+        grouped_requests = [[req for req in batch.requests] for batch in batches]
+        flat_requests = list(itertools.chain(*grouped_requests))
+        if inplace and scenario != 'SHIFT':
+            # The data is already present in the batch. No need to move it
+            grouped_requests[target_batch_idx] = []
+            free_indices = batches[target_batch_idx].free_indices()
+        else:
+            free_indices = itertools.count(0)
+
+        to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
+        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]

         max_seq_len = batches[0].attention_mask.size(1)
-        input_length = max(r.input_length for r in requests)
+        input_length = max_input_length
         right_padding = max_seq_len - input_length
-        max_tokens = len(requests) * max_seq_len

-        chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].input_ids.size(0)
+        chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)

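For readability, the scenario selection added to recombine() above can be summarized as a standalone sketch; this is a paraphrase of the diff, not code from the module (padding stands for each batch's right_padding):

    def pick_scenario(batches, new_bs, padding):
        # Paraphrase of the branch added in recombine():
        if len(batches) > 1:
            return 'CONCAT'      # merge several batches into one bucketed batch
        elif batches[0].batch_size != new_bs:
            return 'RESHAPE'     # request count moved to a different bucket size
        elif padding[0] <= 1:
            return 'SHIFT'       # almost out of right padding: shift data to reclaim room
        else:
            return None          # nothing to do, keep the batch as-is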
@@ -213,33 +234,33 @@ class CausalLMBatch(Batch):
         for b in batches:
             del b.input_ids
         src = shift_all(src, seq_dim, offsets)
-        input_ids = batch_alloc(new_bs, src[0])
+        input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)

         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
         src = shift_all(src, seq_dim, offsets)
-        attention_mask = batch_alloc(new_bs, src[0])
+        attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)

         src = [b.position_ids for b in batches]
         for b in batches:
             del b.position_ids
         src = shift_all(src, seq_dim, offsets)
-        position_ids = batch_alloc(new_bs, src[0])
+        position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         position_ids = move_data(position_ids, 1, indices, src)

         past_key_values = []
         for layer_num in range(num_layers):
             src = [b.past_key_values[layer_num][0] for b in batches]
             src = shift_all(src, key_dim, offsets)
-            updated_key = batch_alloc(new_bs * chunk_size, src[0])
+            updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_key = move_data(updated_key, chunk_size, indices, src)

             src = [b.past_key_values[layer_num][1] for b in batches]
             src = shift_all(src, value_dim, offsets)
-            updated_value = batch_alloc(new_bs * chunk_size, src[0])
+            updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_value = move_data(updated_value, chunk_size, indices, src)

             past_key_values.append((updated_key, updated_value))
@@ -248,10 +269,10 @@ class CausalLMBatch(Batch):

         past_key_values = past_key_values_type(past_key_values)

-        top_n_tokens = [r.data.top_n_tokens for r in requests]
+        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            [r.data.parameters for r in requests],
+            [r.data.parameters for r in flat_requests],
             batches[0].next_token_chooser.device,
             batches[0].next_token_chooser.dtype
         )
@@ -260,7 +281,7 @@ class CausalLMBatch(Batch):

         return cls(
             batch_id=batch_id,
-            requests=requests,
+            requests=flat_requests,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -268,7 +289,6 @@ class CausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            max_tokens=max_tokens,
             input_length=input_length,
             right_padding=right_padding
         )
@@ -327,9 +347,6 @@ class CausalLMBatch(Batch):
             r.prefix_offset = input_len - 5
             r.read_offset = input_len

-        #max_tokens = new_bs * max_total_tokens
-        max_tokens = len(requests) * max_total_tokens
-
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]

@@ -363,23 +380,50 @@ class CausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            max_tokens=max_tokens,
             input_length=max_input_length,
             right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
         )

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
-        return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)
+        dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
+        request_ids = set(request_ids)
+        self.requests = [req for req in self.requests if req.data.id in request_ids]
+        return self.__class__.recombine([self], is_optimized_for_gaudi)

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)
+        return cls.recombine(batches, is_optimized_for_gaudi)

     def __len__(self):
         return len(self.requests)

+    @property
+    def max_input_length(self):
+        return max(req.input_length for req in self.requests)
+
+    @property
+    def batch_size(self):
+        return self.attention_mask.size(0)
+
+    @property
+    def seq_length(self):
+        return self.attention_mask.size(1)
+
+    # Maximum number of tokens this batch will grow to
+    @property
+    def max_tokens(self):
+        max_total_tokens = self.attention_mask.size(1)
+        return len(self.requests) * max_total_tokens
+
+    def free_indices(self):
+        used = set(req.idx for req in self.requests)
+        for i in range(self.batch_size):
+            if i in used:
+                continue
+            yield i
+

 class CausalLM(Model):
     def __init__(
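The new free_indices() generator is what lets an in-place recombine place incoming requests into slots freed by finished ones, instead of repacking everything from index 0. A self-contained sketch of the idea (MiniBatch and SimpleNamespace are stand-ins for the real batch and request objects):

    from types import SimpleNamespace

    class MiniBatch:
        # Minimal stand-in for CausalLMBatch: just enough to show free_indices().
        def __init__(self, batch_size, used_slots):
            self.batch_size = batch_size
            self.requests = [SimpleNamespace(idx=i) for i in used_slots]

        def free_indices(self):
            # Same logic as the method added above.
            used = set(req.idx for req in self.requests)
            for i in range(self.batch_size):
                if i in used:
                    continue
                yield i

    batch = MiniBatch(batch_size=8, used_slots=[0, 2, 3, 7])
    print(list(batch.free_indices()))   # -> [1, 4, 5, 6]: slots that new requests can take over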
@@ -550,8 +594,12 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         prefill = batch.past_key_values is None
+        # Check if we need to do any bookkeeping first
+        if not prefill:
+            batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
+
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
+        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
@@ -605,21 +653,21 @@ class CausalLM(Model):
         next_token_ids_cpu = next_token_ids.cpu()
         htorch.core.mark_step()

-        for req in batch.requests:
+        for req_idx, req in enumerate(batch.requests):
             i = req.idx
             request = req.data
             input_length = req.input_length
             prefix_offset = req.prefix_offset
             read_offset = req.read_offset
-            do_sample = batch.next_token_chooser.do_sample[i]
-            seed = batch.next_token_chooser.seeds[i]
+            do_sample = batch.next_token_chooser.do_sample[req_idx]
+            seed = batch.next_token_chooser.seeds[req_idx]
             stopping_criteria = req.stopping_criteria
             all_input_ids = req.all_input_ids
-            top_n_tokens = batch.top_n_tokens[i]
+            top_n_tokens = batch.top_n_tokens[req_idx]
             next_token_id = next_token_ids_cpu[i]
             next_token_logprob = next_token_logprobs[i]
-            top_token_ids = batch_top_token_ids[i]
-            top_token_logprobs = batch_top_token_logprobs[i]
+            top_token_ids = batch_top_token_ids[req_idx]
+            top_token_logprobs = batch_top_token_logprobs[req_idx]

             # Append next token to all tokens
             if self.is_optimized_for_gaudi:
@@ -717,6 +765,7 @@ class CausalLM(Model):
         if stopped:
             if self.hb_profer_started == True:
                 self.hb_profer.step()
+            htorch.core.mark_step()
             return generations, None

         # Slice unused values from prefill, use it to store next token