Mirror of https://github.com/huggingface/text-generation-inference.git

Commit c7ccfb87ff (parent 2122acc60f)
Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
@@ -3,6 +3,8 @@ import tempfile
import itertools
import time
import glob
import bisect
import math

import torch
@@ -51,6 +53,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF',
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
START_TS = None
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]


def count_hpu_graphs():
@@ -74,61 +77,111 @@ def round_up(number, k):
    return (number + k - 1) // k * k

def prepare_memory(new_bs, tensor, inplace):
    if inplace:
        return tensor
    else:
        return tensor.new_empty((new_bs,) + tensor.shape[1:])

def to_tensor_indices(indices, device):
    return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]

def move_data(dst_tensor, chunk_size, indices, src_tensors):
    batch_dim = 0
    bs = dst_tensor.size(batch_dim)
    assert bs % chunk_size == 0, 'Batch dim must be divisible by chunk size!'
    result = dst_tensor
    if chunk_size > 1:
        dst_tensor = dst_tensor.view(bs // chunk_size, chunk_size, *dst_tensor.shape[1:])
    htorch.core.mark_step()
    for ind, src_t in zip(indices, src_tensors):
        if chunk_size > 1:
            src_t = src_t.view(bs // chunk_size, chunk_size, *src_t.shape[1:])
        for dst_idx, src_idx in ind:
            src_data = torch.index_select(src_t, batch_dim, src_idx)
            dst_tensor.index_copy_(batch_dim, dst_idx, src_data)
            htorch.core.mark_step()
    return result

def generate_shift_chunks(offset):
    chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
def calculate_chunks(offset):
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
        best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
        result.append(best_chunk)
        offset = offset - best_chunk
    return result

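Editor's note: the snippet below is an illustrative, self-contained sketch (not part of this commit) of how calculate_chunks greedily decomposes an arbitrary sequence offset into the roll steps listed in CHUNK_SIZES; the example values are assumptions.

# Greedy decomposition of an offset into CHUNK_SIZES-sized roll steps (standalone sketch).
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

def calculate_chunks(offset):
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        # Pick the signed chunk that brings the remaining offset closest to zero.
        best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
        result.append(best_chunk)
        offset = offset - best_chunk
    return result

# A shift by 100 positions becomes three rolls: 128 - 32 + 4.
assert calculate_chunks(100) == [128, -32, 4]
# Negative offsets decompose into negative chunks.
assert calculate_chunks(-5) == [-4, -1]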
def roll(tensor, dim, chunks):
    dbg_trace('ROLL', f'shape:{list(tensor.shape)} dim:{dim} chunks:{chunks}')
    for c in chunks:
        tensor = torch.roll(tensor, c, dim)

def biggest_single_chunk(offset):
    if offset != 0:
        idx = bisect.bisect(CHUNK_SIZES, abs(offset))
        return int(math.copysign(CHUNK_SIZES[idx-1], offset))
    else:
        return 0

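Editor's note: a small hedged example (not part of this commit) of biggest_single_chunk, which uses bisect to pick the largest single step from CHUNK_SIZES that does not overshoot the requested offset; the values below are assumptions.

import bisect
import math

CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

def biggest_single_chunk(offset):
    # Largest chunk <= |offset|, carrying the sign of the offset; 0 stays 0.
    if offset != 0:
        idx = bisect.bisect(CHUNK_SIZES, abs(offset))
        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
    else:
        return 0

assert biggest_single_chunk(100) == 64    # 64 <= 100 < 128
assert biggest_single_chunk(-100) == -64  # sign is preserved
assert biggest_single_chunk(2) == 2       # exact matches are kept
assert biggest_single_chunk(0) == 0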
def grouped_pad(tensor_groups, dims, values):
    grouped_result = []
    for tensors, dim, value in zip(tensor_groups, dims, values):
        padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
        if padding > 0:
            assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
            result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
        else:
            result = [t for t in tensors]
        grouped_result.append(result)
    htorch.core.mark_step()
    return grouped_result

def roll(tensor, chunk, dim, merge_graphs):
    if dim is None:
        return tensor
    tensor = torch.roll(tensor, chunk, dim)
    if not merge_graphs:
        htorch.core.mark_step()
    return tensor

def shift(tensor, dim, offset):
    assert dim < 0, 'Only negative dims are supported'
    if offset == 0:
        return tensor
    chunks = generate_shift_chunks(offset)
    tensor = roll(tensor, dim, chunks)
    return tensor

def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
    tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
    if merge_graphs:
        htorch.core.mark_step()
    return tensor_groups

def shift_all(srcs, dim, offsets):
    return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]

def grouped_shift(tensor_groups, dims, offset, merge_graphs):
    chunks = calculate_chunks(offset)
    for c in chunks:
        tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs)
    return tensor_groups

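Editor's note: a minimal CPU-only sketch (no htorch, not part of this commit) of the property grouped_shift relies on: rolling a tensor by each chunk in turn is equivalent to a single roll by the summed offset, while giving the accelerator one small graph per chunk; shapes and values are assumptions.

import torch

x = torch.arange(8).unsqueeze(0)      # assumed shape [batch, seq] = [1, 8]
chunks = [2, 1]                       # what calculate_chunks(3) would produce
y = x
for c in chunks:
    y = torch.roll(y, c, -1)          # one roll per chunk
assert torch.equal(y, torch.roll(x, sum(chunks), -1))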
def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
    if dst_dim == 1 and src_dim == 0:
        # Case 1: Only destination is merged
        dst_tensors = dst_tensors[0]
        dst_dim = 0
    elif dst_dim == 0 and src_dim == 1:
        # Case 2: Only source is merged
        src_tensors = src_tensors[0]
        src_dim = 0
    else:
        # Other cases don't need special support
        pass
    for dst_t, src_t in zip(dst_tensors, src_tensors):
        for dst_idx, src_idx in zip(dst_indices, src_indices):
            dst_t.index_copy_(dst_dim, dst_idx, torch.index_select(src_t, src_dim, src_idx))

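Editor's note: a standalone sketch (not part of this commit) of the index_select / index_copy_ pattern that move() applies along the batch dimension: rows of a source tensor are copied into chosen slots of a destination tensor; the shapes and indices are assumptions.

import torch

dst = torch.zeros(4, 3)                                   # destination batch with 4 slots
src = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # two source rows
dst_idx = torch.tensor([2, 3])                            # free slots in dst
src_idx = torch.tensor([0, 1])                            # rows to take from src
dst.index_copy_(0, dst_idx, torch.index_select(src, 0, src_idx))
assert torch.equal(dst[2:], src)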
def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
    for dst_tensors, dst_dim, src_tensors, src_dim in zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims):
        move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices)
    htorch.core.mark_step()
    return dst_tensor_groups

def extend_tensor(tensor, padding, dim):
    result = torch.cat([tensor, padding], dim=dim)
    htorch.core.mark_step()
    return result

def extend_batch(tensors, target_bs, dim):
    diff = target_bs - tensors[0].size(dim)
    # TODO: add support for shrinking bs
    if diff <= 0:
        return tensors
    shape = list(tensors[0].shape)
    shape[dim] = diff
    padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
    tensors = [extend_tensor(t, padding, dim) for t in tensors]
    return tensors

def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
    tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
    return tensor_groups

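Editor's note: a CPU-only sketch (not part of this commit) of what extend_batch does: grow the batch dimension by concatenating an uninitialized block, on the assumption that the caller later overwrites or masks the new slots (as the move_data method below does); names and sizes are assumptions.

import torch

def extend_batch_cpu(tensors, target_bs, dim):
    # Grow dim from its current size to target_bs with uninitialized padding.
    diff = target_bs - tensors[0].size(dim)
    if diff <= 0:
        return tensors
    shape = list(tensors[0].shape)
    shape[dim] = diff
    padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
    return [torch.cat([t, padding], dim=dim) for t in tensors]

attention_mask = torch.ones(2, 16, dtype=torch.int64)     # assumed [batch, seq]
(attention_mask,) = extend_batch_cpu([attention_mask], target_bs=4, dim=0)
assert attention_mask.shape == (4, 16)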
def remove_kv_cache_from_output(module):

@@ -152,13 +205,11 @@ def remove_kv_cache_from_output(module):
    return module

def pad_tensors(tensors, paddings, dim, value):
    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
        if padding > 0:
            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
            tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
            htorch.core.mark_step()
    return tensors

def hpu_graph_fn(fn):
    class FnModule(torch.nn.Module):
        def forward(self, *args, **kwargs):
            return fn(*args, **kwargs)
    return wrap_in_hpu_graph(FnModule(), disable_tensor_cache=False)

@dataclass

@@ -199,6 +250,7 @@ class CausalLMBatch(Batch):
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]
    merged_kv_cache: bool

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
@@ -218,6 +270,103 @@ class CausalLMBatch(Batch):
            max_tokens=self.max_tokens,
        )

    def detach_kv_cache(self):
        past_keys = [past[0] for past in self.past_key_values]
        past_values = [past[1] for past in self.past_key_values]
        del self.past_key_values
        return past_keys, past_values

    def attach_kv_cache(self, past_keys, past_values):
        # TODO: Add support for models that don't store kv_cache in a list
        self.past_key_values = list(zip(past_keys, past_values))

    def merge_kv_cache_if_needed(self, target_bs, offset):
        pad_needed = self.seq_length < MAX_TOTAL_TOKENS
        shift_needed = offset != 0
        expand_needed = target_bs > self.batch_size
        # Very simple heuristic to determine whether we should merge tensors;
        # this needs tuning for other models/scenarios
        small_bs = len(self.past_key_values) > self.batch_size
        if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
            past_keys, past_values = self.detach_kv_cache()
            past_keys = [torch.stack(past_keys)]
            past_values = [torch.stack(past_values)]
            self.attach_kv_cache(past_keys, past_values)
            self.merged_kv_cache = True
            htorch.core.mark_step()

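Editor's note: a rough, hypothetical illustration (numbers are assumptions, not from this commit) of the small_bs heuristic above: when the per-layer KV list is longer than the batch, stacking keys and values into one merged tensor lets the later pad/shift/expand run as a couple of large ops instead of one op per layer.

# Hypothetical numbers for illustration only.
num_layers = 32                       # len(self.past_key_values) for a 32-layer decoder
batch_size = 4
small_bs = num_layers > batch_size    # True -> merging is judged worthwhile
# Merged: pad/shift/expand touch 2 stacked tensors (keys, values)
# instead of 2 * num_layers = 64 separate per-layer tensors.
assert small_bs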
    def split_kv_cache_if_needed(self):
        if self.merged_kv_cache:
            past_keys, past_values = self.detach_kv_cache()
            past_keys = [t.clone() for t in past_keys[0]]
            past_values = [t.clone() for t in past_values[0]]
            self.attach_kv_cache(past_keys, past_values)
            self.merged_kv_cache = False
            htorch.core.mark_step()

    def get_tensor_groups(self):
        past_keys, past_values = self.detach_kv_cache()
        seq_dim = -1
        key_dim = -2  # TODO: Add case for Bloom and other models
        value_dim = -2
        tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
        # We don't need to align position_ids
        seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
        bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
        return tensors, seq_dims, bs_dims

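Editor's note: a hedged sketch (assumed GPT-style cache shapes, not from this commit) of what the groups returned above look like and why the batch dim changes once the cache is merged.

import torch

# Assumed shapes: ids / mask / position_ids are [batch, seq]; each per-layer key or
# value is [batch, heads, seq, head_dim]; the merged cache stacks layers in front,
# giving [num_layers, batch, heads, seq, head_dim]. Hence seq dims of -1 / -2 and a
# batch dim that moves from 0 to 1 after merging.
key = torch.zeros(4, 8, 128, 64)       # [batch, heads, seq, head_dim]
merged = torch.stack([key, key])       # [num_layers, batch, heads, seq, head_dim]
assert key.size(-2) == 128             # sequence dim used for padding/shifting
assert merged.size(1) == 4             # batch dim is 1 in the merged layout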
    def set_tensor_groups(self, tensors):
        self.input_ids = tensors.pop(0)[0]
        self.attention_mask = tensors.pop(0)[0]
        self.position_ids = tensors.pop(0)[0]
        past_keys = tensors.pop(0)
        past_values = tensors.pop(0)
        self.attach_kv_cache(past_keys, past_values)

    def realign(self, target_bs, offset, pad_token_id):
        tensors, seq_dims, _ = self.get_tensor_groups()
        tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0])
        tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache)
        self.set_tensor_groups(tensors)

    def expand_bs(self, target_bs):
        tensors, _, bs_dims = self.get_tensor_groups()
        tensors = grouped_extend_batch(tensors, target_bs, bs_dims)
        self.set_tensor_groups(tensors)

    def used_indices(self):
        return [req.idx for req in self.requests]

    def update_indices(self, new_indices):
        for req, new_idx in zip(self.requests, new_indices):
            req.idx = new_idx
        return self.used_indices()

    def free_indices_generator(self):
        used = set(req.idx for req in self.requests)
        return (i for i in range(self.batch_size) if i not in used)

    def move_data(self, src_batches):
        dst_tensors, _, dst_dims = self.get_tensor_groups()
        free_indices_gen = self.free_indices_generator()
        for src_b in src_batches:
            src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
            dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
            src_tensors, _, src_dims = src_b.get_tensor_groups()

            # Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
            # Both dst_tensors and src_tensors are lists of lists which follow this pattern:
            # [[input_ids], [attention_mask], [position_ids], past_keys, past_values]

            # move only past_keys
            dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
            # move only past_values
            dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
            # move only input_ids, attention_mask and position_ids
            dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
        self.set_tensor_groups(dst_tensors)

    @classmethod
    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
        total_requests = sum(len(b) for b in batches)
@@ -228,127 +377,59 @@ class CausalLMBatch(Batch):
        input_lengths = [b.input_length for b in batches]
        max_input_length = max(input_lengths)
        offsets = [max_input_length - b.input_length for b in batches]
        padding = [b.right_padding for b in batches]
        cur_padding = [b.right_padding for b in batches]
        # For prefill there is a space allocated only for first token
        # Need to add padding to the max total tokens before first decode
        extra_padding = [MAX_TOTAL_TOKENS - b.seq_length for b in batches]

        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
        target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
        reshape = (batches[dst_batch_idx].batch_size != new_bs)

        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
        # FIXME: max_seq_len for non optimized code
        if len(batches) > 1:
            scenario = 'CONCAT'
        elif batches[target_batch_idx].batch_size != new_bs:
        elif reshape:
            scenario = 'RESHAPE'
        elif padding[target_batch_idx] <= 0:
        elif cur_padding[dst_batch_idx] <= 0:
            scenario = 'SHIFT'
            offsets = [b.max_input_length - max_input_length for b in batches]
            max_input_length = max(b.max_input_length for b in batches)
            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
            max_input_length = max_input_length + offsets[dst_batch_idx]
        else:
            # Nothing to do
            return batches[0]

        inplace = (batches[target_batch_idx].batch_size == new_bs)

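Editor's note: a hypothetical worked example (values are assumptions, not from this commit) of how the destination batch is picked above: a batch that already has the target batch_size only needs the other batches' requests moved into it, so it minimizes moves_needed.

# Two batches being recombined; numbers are illustrative only.
batch_sizes = [8, 4]
request_counts = [5, 3]
new_bs = 8                            # assumed target bucketed batch size
total_requests = sum(request_counts)
moves_needed = [total_requests - reqs if bs == new_bs else total_requests
                for bs, reqs in zip(batch_sizes, request_counts)]
dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
assert moves_needed == [3, 8] and dst_batch_idx == 0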
        dbg_trace(
            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
            f' reqs:{[len(b) for b in batches]}'
            f' offsets:{offsets}'
            f' input_lengths:{input_lengths}'
            f' cur_padding:{padding}'
            f' inplace:{inplace}')
            f' cur_padding:{cur_padding}'
            f' dst_batch:{dst_batch_idx}')

        grouped_requests = [[req for req in batch.requests] for batch in batches]
        flat_requests = list(itertools.chain(*grouped_requests))
        if inplace:
            # The data is already present in the batch. No need to move it
            grouped_requests[target_batch_idx] = []
            free_indices = batches[target_batch_idx].free_indices()
        else:
            free_indices = itertools.count(0)

        def to_tensors(ind): return (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs]
                   for batch_reqs in grouped_requests]

        chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
        num_layers = len(batches[0].past_key_values)
        past_key_values_type = type(batches[0].past_key_values)

        seq_dim = -1
        if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1):
            # Case for Bloom
            key_dim = -1
        else:
            key_dim = -2
        value_dim = -2

        for b in batches:
            b.past_key_values = list(b.past_key_values)

        src = [b.input_ids for b in batches]
        for b in batches:
            del b.input_ids
        src = pad_tensors(src, extra_padding, seq_dim, pad_token_id)
        src = shift_all(src, seq_dim, offsets)
        input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
        input_ids = move_data(input_ids, 1, indices, src)

        src = [b.attention_mask for b in batches]
        for b in batches:
            del b.attention_mask
        src = pad_tensors(src, extra_padding, seq_dim, 0)
        src = shift_all(src, seq_dim, offsets)
        attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
        attention_mask = move_data(attention_mask, 1, indices, src)

        src = [b.position_ids for b in batches]
        for b in batches:
            del b.position_ids
        position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
        position_ids = move_data(position_ids, 1, indices, src)

        src = None
        src_keys = [[b.past_key_values[layer_num][0] for layer_num in range(num_layers)] for b in batches]
        src_values = [[b.past_key_values[layer_num][1] for layer_num in range(num_layers)] for b in batches]
        for b in batches:
            del b.past_key_values

        src_keys = [torch.stack(src) for src in src_keys]
        htorch.core.mark_step()
        src_keys = pad_tensors(src_keys, extra_padding, key_dim, 0)
        src_keys = shift_all(src_keys, key_dim, offsets)
        src_keys = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_keys]
        htorch.core.mark_step()

        dst_keys = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_keys[target_batch_idx]]
        dst_keys = [move_data(dst_keys[layer_num], chunk_size, indices,
                              [src[layer_num] for src in src_keys]) for layer_num in range(num_layers)]

        src_values = [torch.stack(src) for src in src_values]
        htorch.core.mark_step()
        src_values = pad_tensors(src_values, extra_padding, value_dim, 0)
        src_values = shift_all(src_values, value_dim, offsets)
        src_values = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_values]
        htorch.core.mark_step()

        dst_values = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_values[target_batch_idx]]
        dst_values = [move_data(dst_values[layer_num], chunk_size, indices,
                                [src[layer_num] for src in src_values]) for layer_num in range(num_layers)]

        past_key_values = past_key_values_type(zip(dst_keys, dst_values))
        for i in range(len(batches)):
            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
            batches[i].realign(target_bs, offsets[i], pad_token_id)
        batches[dst_batch_idx].split_kv_cache_if_needed()
        batches[dst_batch_idx].expand_bs(new_bs)
        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])

        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            [r.data.parameters for r in flat_requests],
            batches[0].next_token_chooser.dtype,
            batches[0].next_token_chooser.device
            batches[dst_batch_idx].next_token_chooser.dtype,
            batches[dst_batch_idx].next_token_chooser.device
        )

        max_seq_len = attention_mask.size(1)
        input_ids = batches[dst_batch_idx].input_ids
        attention_mask = batches[dst_batch_idx].attention_mask
        position_ids = batches[dst_batch_idx].position_ids
        past_key_values = batches[dst_batch_idx].past_key_values
        input_length = max_input_length

        htorch.core.mark_step()
@@ -360,6 +441,7 @@ class CausalLMBatch(Batch):
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            merged_kv_cache=False,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
@@ -448,6 +530,7 @@ class CausalLMBatch(Batch):
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            merged_kv_cache=False,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
@@ -491,13 +574,6 @@ class CausalLMBatch(Batch):
        max_total_tokens = self.attention_mask.size(1)
        return len(self.requests) * max_total_tokens

    def free_indices(self):
        used = set(req.idx for req in self.requests)
        for i in range(self.batch_size):
            if i in used:
                continue
            yield i

class CausalLM(Model):
    def __init__(