Batch bucketing improvements (#15)

madamczykhabana authored 2024-01-17 10:09:27 +01:00 · committed by GitHub
parent 8523f7ef64
commit 381ec38cad


@@ -68,7 +68,10 @@ def round_up(number, k):
return (number + k - 1) // k * k
def batch_alloc(new_bs, tensor):
def prepare_memory(new_bs, tensor, inplace):
if inplace:
return tensor
else:
return tensor.new_empty((new_bs,) + tensor.shape[1:])
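
For context, a minimal sketch of how the rounding helper and prepare_memory combine when a batch is bucketed (BATCH_BUCKET_SIZE and the shapes below are illustrative assumptions, not values taken from this diff):

import torch

BATCH_BUCKET_SIZE = 8  # assumed example value; the real constant is defined elsewhere in the repo

def round_up(number, k):
    return (number + k - 1) // k * k

def prepare_memory(new_bs, tensor, inplace):
    # Reuse the existing buffer when the target batch already has the right size,
    # otherwise allocate a fresh buffer with the bucketed batch dimension.
    if inplace:
        return tensor
    return tensor.new_empty((new_bs,) + tensor.shape[1:])

# Five live requests get bucketed up to a batch of eight.
src = torch.zeros(5, 128, dtype=torch.long)
new_bs = round_up(5, BATCH_BUCKET_SIZE)            # -> 8
dst = prepare_memory(new_bs, src, inplace=False)   # new (8, 128) buffer, contents still uninitialized
assert dst.shape == (8, 128)
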
@@ -154,9 +157,6 @@ class CausalLMBatch(Batch):
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Maximum number of tokens this batch will grow to
max_tokens: int
input_length: int
right_padding: int
@@ -169,32 +169,53 @@ class CausalLMBatch(Batch):
)
@classmethod
def recombine(cls, batches: List["CausalLMBatch"], req_ids: List[List[int]], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
new_bs = round_up(sum([len(reqs) for reqs in req_ids]), BATCH_BUCKET_SIZE)
def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
total_requests = sum(len(b) for b in batches)
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
batch_id = batches[0].batch_id
device = batches[0].input_ids.device
# TODO: for now use consecutive indices. This could be optimized to reuse existing batch memory and only overwrite
# indices that are no longer used instead of allocating new memory
free_indices = itertools.count(0)
to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
requests = [[req for req in batch.requests if req.data.id in ids] for batch, ids in zip(batches, req_ids)]
indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in requests]
requests = list(itertools.chain(*requests))
max_input_length = max(b.input_length for b in batches)
offsets = [max_input_length - b.input_length for b in batches]
padding = [b.right_padding for b in batches]
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
# TODO: Add support for changing max seq len, i.e. due to output length bucketing
# FIXME: max_seq_len for non optimized code
max_input_length = max(req.input_length for req in requests)
offsets = [(max_input_length - b.input_length) for b in batches]
scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')
if len(batches) > 1:
scenario = 'CONCAT'
elif batches[0].batch_size != new_bs:
scenario = 'RESHAPE'
elif padding[0] <= 1:
scenario = 'SHIFT'
offsets = [b.max_input_length - max_input_length for b in batches]
max_input_length = max(b.max_input_length for b in batches)
else:
# Nothing to do
return batches[0]
inplace = batches[target_batch_idx].batch_size == new_bs
dbg_trace(scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
grouped_requests = [[req for req in batch.requests] for batch in batches]
flat_requests = list(itertools.chain(*grouped_requests))
if inplace and scenario != 'SHIFT':
# The data is already present in the batch. No need to move it
grouped_requests[target_batch_idx] = []
free_indices = batches[target_batch_idx].free_indices()
else:
free_indices = itertools.count(0)
to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
max_seq_len = batches[0].attention_mask.size(1)
input_length = max(r.input_length for r in requests)
input_length = max_input_length
right_padding = max_seq_len - input_length
max_tokens = len(requests) * max_seq_len
chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].input_ids.size(0)
chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
num_layers = len(batches[0].past_key_values)
past_key_values_type = type(batches[0].past_key_values)
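
Read on its own, the scenario selection above reduces to roughly the following decision (a hedged sketch; the helper itself is hypothetical, but the branch ordering mirrors the diff):

def pick_scenario(batches, new_bs, padding):
    # Merging several batches is always a CONCAT; a single batch whose bucketed
    # size changed needs a RESHAPE; a single batch that ran out of right padding
    # needs a SHIFT; anything else is left untouched by recombine().
    if len(batches) > 1:
        return 'CONCAT'
    if batches[0].batch_size != new_bs:
        return 'RESHAPE'
    if padding[0] <= 1:
        return 'SHIFT'
    return None

# The target batch is the one requiring the fewest row moves; if it already has the
# bucketed size (and no shift is needed), its tensors can be filled in place and its
# own requests stay put, with incoming requests scattered into free_indices().
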
@@ -213,33 +234,33 @@ class CausalLMBatch(Batch):
for b in batches:
del b.input_ids
src = shift_all(src, seq_dim, offsets)
input_ids = batch_alloc(new_bs, src[0])
input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
input_ids = move_data(input_ids, 1, indices, src)
src = [b.attention_mask for b in batches]
for b in batches:
del b.attention_mask
src = shift_all(src, seq_dim, offsets)
attention_mask = batch_alloc(new_bs, src[0])
attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
attention_mask = move_data(attention_mask, 1, indices, src)
src = [b.position_ids for b in batches]
for b in batches:
del b.position_ids
src = shift_all(src, seq_dim, offsets)
position_ids = batch_alloc(new_bs, src[0])
position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
position_ids = move_data(position_ids, 1, indices, src)
past_key_values = []
for layer_num in range(num_layers):
src = [b.past_key_values[layer_num][0] for b in batches]
src = shift_all(src, key_dim, offsets)
updated_key = batch_alloc(new_bs * chunk_size, src[0])
updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
updated_key = move_data(updated_key, chunk_size, indices, src)
src = [b.past_key_values[layer_num][1] for b in batches]
src = shift_all(src, value_dim, offsets)
updated_value = batch_alloc(new_bs * chunk_size, src[0])
updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
updated_value = move_data(updated_value, chunk_size, indices, src)
past_key_values.append((updated_key, updated_value))
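
Each tensor above goes through the same three steps: align the sources on the sequence dimension (shift_all), prepare a destination of the bucketed size (prepare_memory), then copy rows into their new slots (move_data). Those helpers are defined elsewhere in the file, so the following is only a generic, hypothetical analogue of the final scatter step, to show the shape of the operation:

import torch

def scatter_rows(dest, src, src_rows, dst_rows, chunk_size=1):
    # Hypothetical stand-in for move_data(): copy chunk_size consecutive rows per
    # request from src into the slot reserved for it in dest. For KV caches the
    # chunk size is past_key_values[0][0].size(0) // batch_size, as in the diff.
    for s, d in zip(src_rows, dst_rows):
        dest[d * chunk_size:(d + 1) * chunk_size] = src[s * chunk_size:(s + 1) * chunk_size]
    return dest

# Keep requests 0 and 2 of an old batch of four, packing them into slots 0 and 1
# of a freshly bucketed batch of eight.
old = torch.arange(4 * 3, dtype=torch.float32).reshape(4, 3)
new = old.new_empty((8, 3))
scatter_rows(new, old, src_rows=[0, 2], dst_rows=[0, 1])
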
@@ -248,10 +269,10 @@ class CausalLMBatch(Batch):
past_key_values = past_key_values_type(past_key_values)
top_n_tokens = [r.data.top_n_tokens for r in requests]
top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
[r.data.parameters for r in requests],
[r.data.parameters for r in flat_requests],
batches[0].next_token_chooser.device,
batches[0].next_token_chooser.dtype
)
@@ -260,7 +281,7 @@ class CausalLMBatch(Batch):
return cls(
batch_id=batch_id,
requests=requests,
requests=flat_requests,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -268,7 +289,6 @@ class CausalLMBatch(Batch):
next_token_chooser=next_token_chooser,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_tokens=max_tokens,
input_length=input_length,
right_padding=right_padding
)
@@ -327,9 +347,6 @@ class CausalLMBatch(Batch):
r.prefix_offset = input_len - 5
r.read_offset = input_len
#max_tokens = new_bs * max_total_tokens
max_tokens = len(requests) * max_total_tokens
input_ids = tokenized_inputs["input_ids"]
attention_mask = tokenized_inputs["attention_mask"]
@@ -363,23 +380,50 @@ class CausalLMBatch(Batch):
next_token_chooser=next_token_chooser,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_tokens=max_tokens,
input_length=max_input_length,
right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)
dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
request_ids = set(request_ids)
self.requests = [req for req in self.requests if req.data.id in request_ids]
return self.__class__.recombine([self], is_optimized_for_gaudi)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)
return cls.recombine(batches, is_optimized_for_gaudi)
def __len__(self):
return len(self.requests)
@property
def max_input_length(self):
return max(req.input_length for req in self.requests)
@property
def batch_size(self):
return self.attention_mask.size(0)
@property
def seq_length(self):
return self.attention_mask.size(1)
# Maximum number of tokens this batch will grow to
@property
def max_tokens(self):
max_total_tokens = self.attention_mask.size(1)
return len(self.requests) * max_total_tokens
def free_indices(self):
used = set(req.idx for req in self.requests)
for i in range(self.batch_size):
if i in used:
continue
yield i
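
A small usage sketch of the bookkeeping added above; the toy request objects are stand-ins for the real request wrappers, and free_indices is re-implemented here only to keep the example self-contained:

from types import SimpleNamespace

# A bucketed batch of four slots where only slots 0 and 2 are still occupied.
requests = [SimpleNamespace(idx=0, input_length=7),
            SimpleNamespace(idx=2, input_length=5)]

def free_indices(requests, batch_size):
    used = {req.idx for req in requests}
    for i in range(batch_size):
        if i not in used:
            yield i

print(list(free_indices(requests, batch_size=4)))  # [1, 3] -> slots reusable during filter/concat
print(max(req.input_length for req in requests))   # 7     -> what the max_input_length property returns
# max_tokens is now derived on the fly as len(requests) * attention_mask.size(1)
# instead of being stored on the batch.
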
class CausalLM(Model):
def __init__(
@@ -550,8 +594,12 @@ class CausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
prefill = batch.past_key_values is None
# Check if we need to do any bookkeeping first
if not prefill:
batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
scenario = 'PREFILL' if prefill else 'GENERATE'
dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
self.step = self.step + 1
if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
self.hb_profer.stop()
@@ -605,21 +653,21 @@ class CausalLM(Model):
next_token_ids_cpu = next_token_ids.cpu()
htorch.core.mark_step()
for req in batch.requests:
for req_idx, req in enumerate(batch.requests):
i = req.idx
request = req.data
input_length = req.input_length
prefix_offset = req.prefix_offset
read_offset = req.read_offset
do_sample = batch.next_token_chooser.do_sample[i]
seed = batch.next_token_chooser.seeds[i]
do_sample = batch.next_token_chooser.do_sample[req_idx]
seed = batch.next_token_chooser.seeds[req_idx]
stopping_criteria = req.stopping_criteria
all_input_ids = req.all_input_ids
top_n_tokens = batch.top_n_tokens[i]
top_n_tokens = batch.top_n_tokens[req_idx]
next_token_id = next_token_ids_cpu[i]
next_token_logprob = next_token_logprobs[i]
top_token_ids = batch_top_token_ids[i]
top_token_logprobs = batch_top_token_logprobs[i]
top_token_ids = batch_top_token_ids[req_idx]
top_token_logprobs = batch_top_token_logprobs[req_idx]
# Append next token to all tokens
if self.is_optimized_for_gaudi:
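
The loop above now tracks two index spaces: req.idx addresses a request's slot inside the padded, bucketed tensors, while req_idx (the enumerate counter) addresses the per-request lists that recombine() rebuilds compactly in request order. A tiny, hypothetical illustration with made-up values:

import torch

next_token_ids_cpu = torch.tensor([11, 0, 22, 0, 0, 33, 0, 0])  # indexed by batch slot (bucketed to 8)
top_n_tokens = [1, 0, 2]                                        # indexed by request order
batch_slots = [0, 2, 5]                                         # req.idx for the three live requests

for req_idx, slot in enumerate(batch_slots):
    token = int(next_token_ids_cpu[slot])  # per-slot tensors keep the bucketed layout
    top_n = top_n_tokens[req_idx]          # per-request lists stay compact
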
@@ -717,6 +765,7 @@ class CausalLM(Model):
if stopped:
if self.hb_profer_started == True:
self.hb_profer.step()
htorch.core.mark_step()
return generations, None
# Slice unused values from prefill, use it to store next token