Grouped pad/shift/move operations (#57) (#82)

Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
Karol Damaszke 2024-02-29 04:16:44 +01:00 committed by GitHub
parent 2122acc60f
commit c7ccfb87ff


@@ -3,6 +3,8 @@ import tempfile
import itertools
import time
import glob
import bisect
import math
import torch
@@ -51,6 +53,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF',
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
START_TS = None
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
def count_hpu_graphs():
@@ -74,61 +77,111 @@ def round_up(number, k):
return (number + k - 1) // k * k
def prepare_memory(new_bs, tensor, inplace):
if inplace:
return tensor
else:
return tensor.new_empty((new_bs,) + tensor.shape[1:])
def to_tensor_indices(indices, device):
return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]
def move_data(dst_tensor, chunk_size, indices, src_tensors):
batch_dim = 0
bs = dst_tensor.size(batch_dim)
assert bs % chunk_size == 0, 'Batch dim must be divisible by chunk size!'
result = dst_tensor
if chunk_size > 1:
dst_tensor = dst_tensor.view(bs // chunk_size, chunk_size, *dst_tensor.shape[1:])
htorch.core.mark_step()
for ind, src_t in zip(indices, src_tensors):
if chunk_size > 1:
src_t = src_t.view(bs // chunk_size, chunk_size, *src_t.shape[1:])
for dst_idx, src_idx in ind:
src_data = torch.index_select(src_t, batch_dim, src_idx)
dst_tensor.index_copy_(batch_dim, dst_idx, src_data)
htorch.core.mark_step()
return result
def generate_shift_chunks(offset):
chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
def calculate_chunks(offset):
result = []
while offset != 0:
sign = 1 if offset > 0 else -1
best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
result.append(best_chunk)
offset = offset - best_chunk
return result
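A standalone, runnable sketch mirroring the greedy loop above (the helper name decompose and the sample offsets are purely illustrative): each step picks the signed bucket closest to the remaining offset, so any shift amount reduces to a handful of fixed roll sizes.

CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

def decompose(offset):
    # mirrors calculate_chunks: greedily take the signed bucket closest to the remainder
    chunks = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        chunks.append(min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1])
        offset -= chunks[-1]
    return chunks

print(decompose(300))   # [256, 32, 8, 4]
print(decompose(-5))    # [-4, -1]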
def roll(tensor, dim, chunks):
dbg_trace('ROLL', f'shape:{list(tensor.shape)} dim:{dim} chunks:{chunks}')
for c in chunks:
tensor = torch.roll(tensor, c, dim)
def biggest_single_chunk(offset):
if offset != 0:
idx = bisect.bisect(CHUNK_SIZES, abs(offset))
return int(math.copysign(CHUNK_SIZES[idx-1], offset))
else:
return 0
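A small usage sketch of the bisect lookup above, with an illustrative offset: bisect.bisect returns the insertion point to the right, so index - 1 is the largest bucket not exceeding abs(offset), and copysign restores the direction.

import bisect
import math

CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

offset = -300
idx = bisect.bisect(CHUNK_SIZES, abs(offset))
print(int(math.copysign(CHUNK_SIZES[idx - 1], offset)))   # -256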
def grouped_pad(tensor_groups, dims, values):
grouped_result = []
for tensors, dim, value in zip(tensor_groups, dims, values):
padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
if padding > 0:
assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
else:
result = [t for t in tensors]
grouped_result.append(result)
htorch.core.mark_step()
return grouped_result
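A torch-only sketch of the padding scheme, with hypothetical shapes and a hypothetical MAX_TOTAL_TOKENS of 8: dim -1 is used for token-shaped tensors such as input_ids and attention_mask, dim -2 for KV-cache tensors whose sequence axis is second to last.

import torch
import torch.nn.functional as F

MAX_TOTAL_TOKENS = 8                        # illustrative bucket size

ids = torch.ones(2, 5, dtype=torch.long)    # (batch, seq), padded along dim -1
keys = torch.randn(2, 4, 5, 16)             # (batch, heads, seq, head_dim), padded along dim -2

ids = F.pad(ids, (0, MAX_TOTAL_TOKENS - ids.size(-1)), value=0)
keys = F.pad(keys, (0, 0, 0, MAX_TOTAL_TOKENS - keys.size(-2)), value=0)
print(ids.shape, keys.shape)                # torch.Size([2, 8]) torch.Size([2, 4, 8, 16])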
def roll(tensor, chunk, dim, merge_graphs):
if dim is None:
return tensor
tensor = torch.roll(tensor, chunk, dim)
if not merge_graphs:
htorch.core.mark_step()
return tensor
def shift(tensor, dim, offset):
assert dim < 0, 'Only negative dims are supported'
if offset == 0:
return tensor
chunks = generate_shift_chunks(offset)
tensor = roll(tensor, dim, chunks)
return tensor
def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
if merge_graphs:
htorch.core.mark_step()
return tensor_groups
def shift_all(srcs, dim, offsets):
return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
def grouped_shift(tensor_groups, dims, offset, merge_graphs):
chunks = calculate_chunks(offset)
for c in chunks:
tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs)
return tensor_groups
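Putting calculate_chunks and roll together: shifting by an arbitrary offset becomes a short sequence of fixed-size torch.roll calls. A toy example with a right shift of 3, which calculate_chunks splits into 2 + 1; keeping the set of roll sizes small presumably keeps the number of distinct compiled graph shapes bounded.

import torch

mask = torch.tensor([[1, 1, 1, 0, 0, 0]])
for chunk in (2, 1):                   # calculate_chunks(3) == [2, 1]
    mask = torch.roll(mask, chunk, -1)
print(mask)                            # tensor([[0, 0, 0, 1, 1, 1]])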
def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
if dst_dim == 1 and src_dim == 0:
# Case 1: Only destination is merged
dst_tensors = dst_tensors[0]
dst_dim = 0
elif dst_dim == 0 and src_dim == 1:
# Case 2: Only source is merged
src_tensors = src_tensors[0]
src_dim = 0
else:
# Other cases don't need special support
pass
for dst_t, src_t in zip(dst_tensors, src_tensors):
for dst_idx, src_idx in zip(dst_indices, src_indices):
dst_t.index_copy_(dst_dim, dst_idx, torch.index_select(src_t, src_dim, src_idx))
def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
for dst_tensors, dst_dim, src_tensors, src_dim in zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims):
move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices)
htorch.core.mark_step()
return dst_tensor_groups
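A torch-only sketch of the index_select/index_copy_ pattern that move relies on, with made-up slot indices: rows of a source batch are gathered and written in place into the free slots of the destination batch.

import torch

dst = torch.zeros(4, 6)                  # destination batch; slots 2 and 3 are assumed free
src = torch.arange(12.0).reshape(2, 6)   # incoming batch of two requests

src_idx = torch.tensor([0, 1])
dst_idx = torch.tensor([2, 3])
dst.index_copy_(0, dst_idx, torch.index_select(src, 0, src_idx))
print(dst[2:])                           # the two source rows now live in slots 2 and 3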
def extend_tensor(tensor, padding, dim):
result = torch.cat([tensor, padding], dim=dim)
htorch.core.mark_step()
return result
def extend_batch(tensors, target_bs, dim):
diff = target_bs - tensors[0].size(dim)
#TODO: add support for shrinking bs
if diff <= 0:
return tensors
shape = list(tensors[0].shape)
shape[dim] = diff
padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
tensors = [extend_tensor(t, padding, dim) for t in tensors]
return tensors
def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
return tensor_groups
def remove_kv_cache_from_output(module):
@@ -152,13 +205,11 @@ def remove_kv_cache_from_output(module):
return module
def pad_tensors(tensors, paddings, dim, value):
for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
if padding > 0:
pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
htorch.core.mark_step()
return tensors
def hpu_graph_fn(fn):
class FnModule(torch.nn.Module):
def forward(self, *args, **kwargs):
return fn(*args, **kwargs)
return wrap_in_hpu_graph(FnModule(), disable_tensor_cache=False)
@dataclass
@@ -199,6 +250,7 @@ class CausalLMBatch(Batch):
attention_mask: torch.Tensor
position_ids: torch.Tensor
past_key_values: Optional[List[Tuple]]
merged_kv_cache: bool
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
@@ -218,6 +270,103 @@ class CausalLMBatch(Batch):
max_tokens=self.max_tokens,
)
def detach_kv_cache(self):
past_keys = [past[0] for past in self.past_key_values]
past_values = [past[1] for past in self.past_key_values]
del self.past_key_values
return past_keys, past_values
def attach_kv_cache(self, past_keys, past_values):
# TODO: Add support for models that don't store kv_cache in a list
self.past_key_values = list(zip(past_keys, past_values))
def merge_kv_cache_if_needed(self, target_bs, offset):
pad_needed = self.seq_length < MAX_TOTAL_TOKENS
shift_needed = offset != 0
expand_needed = target_bs > self.batch_size
# Very simple heuristic to determine whether we should merge tensors
# this needs tuning for other models/scenarios
small_bs = len(self.past_key_values) > self.batch_size
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
past_keys, past_values = self.detach_kv_cache()
past_keys = [torch.stack(past_keys)]
past_values = [torch.stack(past_values)]
self.attach_kv_cache(past_keys, past_values)
self.merged_kv_cache = True
htorch.core.mark_step()
def split_kv_cache_if_needed(self):
if self.merged_kv_cache:
past_keys, past_values = self.detach_kv_cache()
past_keys = [t.clone() for t in past_keys[0]]
past_values = [t.clone() for t in past_values[0]]
self.attach_kv_cache(past_keys, past_values)
self.merged_kv_cache = False
htorch.core.mark_step()
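The merge/split pair trades one torch.stack (and later per-layer clone calls) for the ability to pad, shift, and expand the whole KV cache as a single tensor instead of once per layer. A shape-only sketch with hypothetical dimensions; note that once merged, the cache's batch axis sits at dim 1, which is why get_tensor_groups below reports bs_dims of 1 for a merged cache.

import torch

num_layers, bs, heads, seq, head_dim = 3, 2, 4, 5, 8
past_keys = [torch.randn(bs, heads, seq, head_dim) for _ in range(num_layers)]

merged = torch.stack(past_keys)               # (num_layers, bs, heads, seq, head_dim)
# ... grouped_pad / grouped_shift / grouped_extend_batch can now run once on `merged` ...
split = [layer.clone() for layer in merged]   # back to contiguous per-layer tensors
print(merged.shape, len(split), split[0].shape)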
def get_tensor_groups(self):
past_keys, past_values = self.detach_kv_cache()
seq_dim = -1
key_dim = -2 # TODO: Add case for Bloom and other models
value_dim = -2
tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
# We don't need to align position_ids
seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
return tensors, seq_dims, bs_dims
def set_tensor_groups(self, tensors):
self.input_ids = tensors.pop(0)[0]
self.attention_mask = tensors.pop(0)[0]
self.position_ids = tensors.pop(0)[0]
past_keys = tensors.pop(0)
past_values = tensors.pop(0)
self.attach_kv_cache(past_keys, past_values)
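A toy illustration of the group contract shared by realign, expand_bs, and move_data, using made-up shapes (batch 2, seq 5, 4 heads, head dim 8, 3 layers) and a merged cache:

import torch

input_ids = torch.ones(2, 5, dtype=torch.long)
attention_mask = torch.ones(2, 5, dtype=torch.long)
position_ids = torch.arange(5).repeat(2, 1)
past_keys = [torch.stack([torch.randn(2, 4, 5, 8) for _ in range(3)])]    # merged: layer axis first
past_values = [torch.stack([torch.randn(2, 4, 5, 8) for _ in range(3)])]

tensors = [[input_ids], [attention_mask], [position_ids], past_keys, past_values]
seq_dims = [-1, -1, None, -2, -2]   # position_ids is never re-aligned
bs_dims = [0, 0, 0, 1, 1]           # a merged cache keeps its batch axis on dim 1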
def realign(self, target_bs, offset, pad_token_id):
tensors, seq_dims, _ = self.get_tensor_groups()
tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0])
tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache)
self.set_tensor_groups(tensors)
def expand_bs(self, target_bs):
tensors, _, bs_dims = self.get_tensor_groups()
tensors = grouped_extend_batch(tensors, target_bs, bs_dims)
self.set_tensor_groups(tensors)
def used_indices(self):
return [req.idx for req in self.requests]
def update_indices(self, new_indices):
for req, new_idx in zip(self.requests, new_indices):
req.idx = new_idx
return self.used_indices()
def free_indices_generator(self):
used = set(req.idx for req in self.requests)
return (i for i in range(self.batch_size) if i not in used)
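A pure-Python sketch of the slot bookkeeping, with made-up request names: free slots of the destination batch are handed out lazily to requests arriving from other batches, while used_indices/update_indices keep each req.idx in sync.

batch_size = 4
used = {0, 2}                                      # slots still held by live requests
free = (i for i in range(batch_size) if i not in used)

incoming = ["req_a", "req_b"]                      # hypothetical requests from a source batch
assignment = {name: next(free) for name in incoming}
print(assignment)                                  # {'req_a': 1, 'req_b': 3}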
def move_data(self, src_batches):
dst_tensors, _, dst_dims = self.get_tensor_groups()
free_indices_gen = self.free_indices_generator()
for src_b in src_batches:
src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
src_tensors, _, src_dims = src_b.get_tensor_groups()
# Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
# Both dst_tensors and src_tensors are lists of lists which follow this pattern:
# [[input_ids], [attention_mask], [position_ids], past_keys, past_values]
# move only past_keys
dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
# move only past_values
dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
# move only input_ids, attention_mask and position_ids
dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
self.set_tensor_groups(dst_tensors)
@classmethod
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
total_requests = sum(len(b) for b in batches)
@@ -228,127 +377,59 @@ class CausalLMBatch(Batch):
input_lengths = [b.input_length for b in batches]
max_input_length = max(input_lengths)
offsets = [max_input_length - b.input_length for b in batches]
padding = [b.right_padding for b in batches]
cur_padding = [b.right_padding for b in batches]
# For prefill, space is allocated only for the first token
# Padding up to the max total tokens has to be added before the first decode
extra_padding = [MAX_TOTAL_TOKENS - b.seq_length for b in batches]
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
reshape = (batches[dst_batch_idx].batch_size != new_bs)
# TODO: Add support for changing max seq len, e.g. due to output length bucketing
# FIXME: max_seq_len for non-optimized code
if len(batches) > 1:
scenario = 'CONCAT'
elif batches[target_batch_idx].batch_size != new_bs:
elif reshape:
scenario = 'RESHAPE'
elif padding[target_batch_idx] <= 0:
elif cur_padding[dst_batch_idx] <= 0:
scenario = 'SHIFT'
offsets = [b.max_input_length - max_input_length for b in batches]
max_input_length = max(b.max_input_length for b in batches)
offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
max_input_length = max_input_length + offsets[dst_batch_idx]
else:
# Nothing to do
return batches[0]
inplace = (batches[target_batch_idx].batch_size == new_bs)
dbg_trace(
scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
f' reqs:{[len(b) for b in batches]}'
f' offsets:{offsets}'
f' input_lengths:{input_lengths}'
f' cur_padding:{padding}'
f' inplace:{inplace}')
f' cur_padding:{cur_padding}'
f' dst_batch:{dst_batch_idx}')
grouped_requests = [[req for req in batch.requests] for batch in batches]
flat_requests = list(itertools.chain(*grouped_requests))
if inplace:
# The data is already present in the batch. No need to move it
grouped_requests[target_batch_idx] = []
free_indices = batches[target_batch_idx].free_indices()
else:
free_indices = itertools.count(0)
def to_tensors(ind): return (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs]
for batch_reqs in grouped_requests]
chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
num_layers = len(batches[0].past_key_values)
past_key_values_type = type(batches[0].past_key_values)
seq_dim = -1
if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1):
# Case for Bloom
key_dim = -1
else:
key_dim = -2
value_dim = -2
for b in batches:
b.past_key_values = list(b.past_key_values)
src = [b.input_ids for b in batches]
for b in batches:
del b.input_ids
src = pad_tensors(src, extra_padding, seq_dim, pad_token_id)
src = shift_all(src, seq_dim, offsets)
input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
input_ids = move_data(input_ids, 1, indices, src)
src = [b.attention_mask for b in batches]
for b in batches:
del b.attention_mask
src = pad_tensors(src, extra_padding, seq_dim, 0)
src = shift_all(src, seq_dim, offsets)
attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
attention_mask = move_data(attention_mask, 1, indices, src)
src = [b.position_ids for b in batches]
for b in batches:
del b.position_ids
position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
position_ids = move_data(position_ids, 1, indices, src)
src = None
src_keys = [[b.past_key_values[layer_num][0] for layer_num in range(num_layers)] for b in batches]
src_values = [[b.past_key_values[layer_num][1] for layer_num in range(num_layers)] for b in batches]
for b in batches:
del b.past_key_values
src_keys = [torch.stack(src) for src in src_keys]
htorch.core.mark_step()
src_keys = pad_tensors(src_keys, extra_padding, key_dim, 0)
src_keys = shift_all(src_keys, key_dim, offsets)
src_keys = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_keys]
htorch.core.mark_step()
dst_keys = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_keys[target_batch_idx]]
dst_keys = [move_data(dst_keys[layer_num], chunk_size, indices, [src[layer_num]
for src in src_keys]) for layer_num in range(num_layers)]
src_values = [torch.stack(src) for src in src_values]
htorch.core.mark_step()
src_values = pad_tensors(src_values, extra_padding, value_dim, 0)
src_values = shift_all(src_values, value_dim, offsets)
src_values = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_values]
htorch.core.mark_step()
dst_values = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_values[target_batch_idx]]
dst_values = [move_data(dst_values[layer_num], chunk_size, indices, [src[layer_num]
for src in src_values]) for layer_num in range(num_layers)]
past_key_values = past_key_values_type(zip(dst_keys, dst_values))
for i in range(len(batches)):
target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
batches[i].realign(target_bs, offsets[i], pad_token_id)
batches[dst_batch_idx].split_kv_cache_if_needed()
batches[dst_batch_idx].expand_bs(new_bs)
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
[r.data.parameters for r in flat_requests],
batches[0].next_token_chooser.dtype,
batches[0].next_token_chooser.device
batches[dst_batch_idx].next_token_chooser.dtype,
batches[dst_batch_idx].next_token_chooser.device
)
max_seq_len = attention_mask.size(1)
input_ids = batches[dst_batch_idx].input_ids
attention_mask = batches[dst_batch_idx].attention_mask
position_ids = batches[dst_batch_idx].position_ids
past_key_values = batches[dst_batch_idx].past_key_values
input_length = max_input_length
htorch.core.mark_step()
@@ -360,6 +441,7 @@ class CausalLMBatch(Batch):
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
merged_kv_cache=False,
next_token_chooser=next_token_chooser,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
@@ -448,6 +530,7 @@ class CausalLMBatch(Batch):
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
merged_kv_cache=False,
next_token_chooser=next_token_chooser,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
@@ -491,13 +574,6 @@ class CausalLMBatch(Batch):
max_total_tokens = self.attention_mask.size(1)
return len(self.requests) * max_total_tokens
def free_indices(self):
used = set(req.idx for req in self.requests)
for i in range(self.batch_size):
if i in used:
continue
yield i
class CausalLM(Model):
def __init__(