mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Supersedes #1459. The fix works as follows: we updated `next_token_chooser` to return all logprobs, and `batch_top_n_tokens` now also gets `accepted_ids` + the speculated length, so it knows how to interpret the flat logprobs. We then update the code to return the lists of `Tokens` that it expects.
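For illustration, a minimal sketch of the idea follows; `split_top_tokens` is a hypothetical helper, not the actual `batch_top_n_tokens` signature in TGI. With speculation, the logprobs tensor is flat over all accepted positions, so the helper needs `accepted_ids` to know how many rows belong to each request before taking the per-position top-n:

```python
import torch

def split_top_tokens(top_n, flat_logprobs, accepted_ids):
    """For each request, return one (ids, logprobs) pair per accepted speculative token."""
    results, offset = [], 0
    for n, accepted in zip(top_n, accepted_ids.tolist()):
        rows = flat_logprobs[offset: offset + accepted]  # this request's slice of the flat tensor
        offset += accepted
        if n == 0:
            results.append([])
            continue
        values, indices = rows.topk(k=n, dim=-1)  # top-n per accepted position
        results.append(list(zip(indices.tolist(), values.tolist())))
    return results
```

Each inner `(ids, logprobs)` pair would then map onto one `Tokens` message per accepted token, which is the shape the caller expects back.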
1123 lines
42 KiB
Python
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import bisect
from dataclasses import dataclass
from functools import wraps
import itertools
import math
import os
import tempfile
import time
from typing import Dict, List, Optional, Tuple, Type

import torch
from loguru import logger
from opentelemetry import trace

import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.utils import HabanaProfile
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from optimum.habana.checkpoint_utils import (
    get_repo_root,
    model_on_meta,
    write_checkpoints_json,
)
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)

from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
    make_tokenizer_optional,
    is_tokenizer_transparent,
)
from text_generation_server.utils.debug import dbg_trace
from text_generation_server.utils.speculate import get_speculate

tracer = trace.get_tracer(__name__)

MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]


def round_up(number, k):
    return (number + k - 1) // k * k


def to_tensor_indices(indices, device):
    return torch.tensor(indices, dtype=torch.int32, device=device)


# Greedily decompose an arbitrary offset into a sum of signed CHUNK_SIZES entries,
# so that shifts can be performed with a small, bucketed set of roll sizes.
def calculate_chunks(offset):
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
        result.append(best_chunk)
        offset = offset - best_chunk
    return result


def biggest_single_chunk(offset):
    if offset != 0:
        idx = bisect.bisect(CHUNK_SIZES, abs(offset))
        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
    else:
        return 0


def grouped_pad(tensor_groups, dims, values):
    grouped_result = []
    for tensors, dim, value in zip(tensor_groups, dims, values):
        padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
        if padding > 0:
            assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
            result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
        else:
            result = [t for t in tensors]
        grouped_result.append(result)
    htorch.core.mark_step()
    return grouped_result


def roll(tensor, chunk, dim, merge_graphs):
    if dim is None:
        return tensor
    tensor = torch.roll(tensor, chunk, dim)
    if not merge_graphs:
        htorch.core.mark_step()
    return tensor


def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
    tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
    if merge_graphs:
        htorch.core.mark_step()
    return tensor_groups


def grouped_shift(tensor_groups, dims, offset, merge_graphs):
    chunks = calculate_chunks(offset)
    for c in chunks:
        tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs)
    return tensor_groups


def move(dst_tensors, dst_indices, src_tensors):
    bs_dim = 0
    num_indices = dst_indices.size(0)
    for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)):
        if src_t.size(bs_dim) != num_indices:
            src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
        dst_t.index_copy_(bs_dim, dst_indices, src_t)
    htorch.core.mark_step()


def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
    for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups):
        move(dst_tensors, dst_indices, src_tensors)


def extend_tensor(tensor, padding, dim):
    result = torch.cat([tensor, padding], dim=dim)
    htorch.core.mark_step()
    return result


def extend_batch(tensors, target_bs, dim):
    diff = target_bs - tensors[0].size(dim)
    # TODO: add support for shrinking bs
    if diff <= 0:
        return tensors
    shape = list(tensors[0].shape)
    shape[dim] = diff
    padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
    tensors = [extend_tensor(t, padding, dim) for t in tensors]
    return tensors


def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
    tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
    return tensor_groups


def merge(tensor_group):
    tensor_group = [torch.stack(tensor_group)]
    htorch.core.mark_step()
    return tensor_group


def split(tensor_group, clone_data):
    tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
    if clone_data:
        tensor_group = [t.clone() for t in tensor_group]
    htorch.core.mark_step()
    return tensor_group


# Patch module.forward so that decode steps (past_key_values already allocated) return
# a single logits tensor instead of an output tuple that also carries the KV cache.
def remove_kv_cache_from_output(module):
    orig_fwd = module.forward

    @wraps(orig_fwd)
    def forward(*args, **kwargs):
        if kwargs["past_key_values"] is not None:
            kwargs["return_dict"] = False
            output = orig_fwd(*args, **kwargs)
            first_value, second_value, *_ = output
            if first_value.nelement() < 2:
                return second_value
            else:
                return first_value
        else:
            kwargs["return_dict"] = True
            return orig_fwd(*args, **kwargs)

    module.forward = forward
    return module


@dataclass
class CausalLMRequest:
    idx: int
    data: generate_pb2.Request
    input_length: int
    prefix_offset: int
    read_offset: int
    stopping_criteria: StoppingCriteria

    all_input_ids: torch.Tensor

    @classmethod
    def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase):
        return cls(
            idx=idx,
            data=data,
            input_length=None,
            prefix_offset=None,
            read_offset=None,
            stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer),
            all_input_ids=None,)

    def update_idx(self, new_idx):
        prev = self.idx
        self.idx = new_idx
        return (new_idx, prev)


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[CausalLMRequest]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]
    merged_kv_cache: bool

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    input_length: int

    logits = None
    past = None

    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.data.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    def detach_kv_cache(self):
        past_keys = [past[0] for past in self.past_key_values]
        past_values = [past[1] for past in self.past_key_values]
        del self.past_key_values
        return past_keys, past_values

    def attach_kv_cache(self, past_keys, past_values):
        # TODO: Add support for models that don't store kv_cache in a list
        self.past_key_values = list(zip(past_keys, past_values))

    def merge_kv_cache_if_needed(self, target_bs, offset):
        pad_needed = self.seq_length < MAX_TOTAL_TOKENS
        shift_needed = offset != 0
        expand_needed = target_bs > self.batch_size
        # Very simple heuristic to determine whether we should merge tensors
        # this needs tuning for other models/scenarios
        small_bs = len(self.past_key_values) > self.batch_size
        if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
            past_keys, past_values = self.detach_kv_cache()
            past_keys = merge(past_keys)
            past_values = merge(past_values)
            self.attach_kv_cache(past_keys, past_values)
            self.merged_kv_cache = True

    def split_kv_cache_if_needed(self, clone_data):
        if self.merged_kv_cache:
            past_keys, past_values = self.detach_kv_cache()
            past_keys = split(past_keys, clone_data)
            past_values = split(past_values, clone_data)
            self.attach_kv_cache(past_keys, past_values)
            self.merged_kv_cache = False

    def get_tensor_groups(self):
        past_keys, past_values = self.detach_kv_cache()
        seq_dim = -1
        key_dim = -2 if self.keys_head_dim_last else -1
        value_dim = -2
        tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
        # We don't need to align position_ids
        seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
        bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
        return tensors, seq_dims, bs_dims

    def set_tensor_groups(self, tensors):
        self.input_ids = tensors.pop(0)[0]
        self.attention_mask = tensors.pop(0)[0]
        self.position_ids = tensors.pop(0)[0]
        past_keys = tensors.pop(0)
        past_values = tensors.pop(0)
        self.attach_kv_cache(past_keys, past_values)

    def realign(self, target_bs, offset, pad_token_id):
        tensors, seq_dims, _ = self.get_tensor_groups()
        tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0])
        tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache)
        self.set_tensor_groups(tensors)

    def expand_bs(self, target_bs):
        tensors, _, bs_dims = self.get_tensor_groups()
        tensors = grouped_extend_batch(tensors, target_bs, bs_dims)
        self.set_tensor_groups(tensors)

    def used_indices(self):
        return [req.idx for req in self.requests]

    def update_indices(self, new_indices):
        for req, new_idx in zip(self.requests, new_indices):
            req.idx = new_idx
        return self.used_indices()

    def free_indices_generator(self):
        used = set(req.idx for req in self.requests)
        return (i for i in range(self.batch_size) if i not in used)

    def move_data(self, src_batches):
        dst_tensors, _, dst_dims = self.get_tensor_groups()
        free_indices_gen = self.free_indices_generator()
        for src_b in src_batches:
            dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
            src_tensors, _, src_dims = src_b.get_tensor_groups()
            grouped_move(dst_tensors, dst_indices, src_tensors)
        self.set_tensor_groups(dst_tensors)

    # Pad, shift and/or expand the given batches into a single bucketed batch.
    # The batch that needs the fewest data moves is picked as the destination.
    @classmethod
    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
        if not all(b.past_key_values is not None for b in batches):
            raise ValueError("KV cache not allocated! Cannot recombine before prefill!")

        total_requests = sum(len(b) for b in batches)
        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
        batch_id = batches[0].batch_id
        device = batches[0].input_ids.device

        input_lengths = [b.input_length for b in batches]
        max_input_length = max(input_lengths)
        offsets = [max_input_length - b.input_length for b in batches]
        cur_padding = [b.right_padding for b in batches]
        # For prefill there is a space allocated only for first token
        # Need to add padding to the max total tokens before first decode

        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
        reshape = (batches[dst_batch_idx].batch_size != new_bs)

        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
        # FIXME: max_seq_len for non optimized code
        if len(batches) > 1:
            scenario = 'CONCAT'
        elif reshape:
            scenario = 'RESHAPE'
        elif cur_padding[dst_batch_idx] <= 0:
            scenario = 'SHIFT'
            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
            max_input_length = max_input_length + offsets[dst_batch_idx]
        else:
            # Nothing to do
            return batches[0]

        dbg_trace(
            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
                      f' reqs:{[len(b) for b in batches]}'
                      f' offsets:{offsets}'
                      f' input_lengths:{input_lengths}'
                      f' cur_padding:{cur_padding}'
                      f' dst_batch:{dst_batch_idx}')

        grouped_requests = [[req for req in batch.requests] for batch in batches]
        flat_requests = list(itertools.chain(*grouped_requests))

        for i in range(len(batches)):
            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
            batches[i].realign(target_bs, offsets[i], pad_token_id)
            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
        batches[dst_batch_idx].expand_bs(new_bs)
        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])

        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

        parameters = [r.data.parameters for r in flat_requests]
        if len(flat_requests) < new_bs:
            for i in range(new_bs - len(flat_requests)):
                # append the dummy parameters for dummy request
                parameters.append(parameters[0])

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            parameters,
            batches[dst_batch_idx].next_token_chooser.dtype,
            batches[dst_batch_idx].next_token_chooser.device,
            hq_env.is_quantization_enabled
        )

        input_ids = batches[dst_batch_idx].input_ids
        attention_mask = batches[dst_batch_idx].attention_mask
        position_ids = batches[dst_batch_idx].position_ids
        past_key_values = batches[dst_batch_idx].past_key_values
        input_length = max_input_length

        htorch.core.mark_step()

        return cls(
            batch_id=batch_id,
            requests=flat_requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            merged_kv_cache=False,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            input_length=input_length,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
        requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]

        max_input_length = max(r.data.truncate for r in requests)
        max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)

        # TODO: Add support for sparse batches
        top_n_tokens = [r.top_n_tokens for r in pb.requests]
        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

        # TODO: by tokenizing all inputs at once we lose information on actual input lengths
        # this means that we cannot shift inputs to the left after a long input sequence
        # was filtered out
        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
        dummy_inputs = ["?"] * (new_bs - len(requests))
        parameters = [r.parameters for r in pb.requests]
        if len(pb.requests) < new_bs:
            for i in range(new_bs - len(pb.requests)):
                # append the dummy parameters for dummy request
                parameters.append(parameters[0])

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            parameters, dtype, device, hq_env.is_quantization_enabled
        )
        tokenized_inputs = tokenizer(
            [r.data.inputs for r in requests] + dummy_inputs,
            return_tensors="pt",
            padding="longest",
            return_token_type_ids=False,
            truncation=True,
            max_length=max_input_length,
        )

        input_len = tokenized_inputs["input_ids"].shape[1]

        bucket_size = max_input_length
        left_padding = max_input_length - input_len
        if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
            if rounded_seq_len <= max_input_length:
                bucket_size = rounded_seq_len - 1
            else:
                bucket_size = max_input_length - 1
            left_padding = bucket_size - input_len

        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        # Allocate space for first token
        input_ids = torch.nn.functional.pad(
            input_ids, (left_padding, 1), value=tokenizer.pad_token_id
        )
        attention_mask = torch.nn.functional.pad(
            attention_mask, (left_padding, 1), value=0
        )
        all_input_ids = torch.nn.functional.pad(
            input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
        ).T.split(1, dim=1)

        # New input length after left padding
        input_len = bucket_size
        for r in requests:
            r.input_length = input_len
            r.prefix_offset = input_len - 5
            r.read_offset = input_len
            r.all_input_ids = all_input_ids[r.idx]

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        htorch.core.mark_step()

        return cls(
            batch_id=pb.id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            merged_kv_cache=False,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            input_length=input_len,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
        request_ids = set(request_ids)
        self.requests = [req for req in self.requests if req.data.id in request_ids]
        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
        return cls.recombine(batches, pad_token_id)

    def __len__(self):
        return len(self.requests)

    @property
    def max_input_length(self):
        return max(req.input_length for req in self.requests)

    @property
    def batch_size(self):
        return self.attention_mask.size(0)

    @property
    def seq_length(self):
        return self.attention_mask.size(1)

    @property
    def right_padding(self):
        return self.seq_length - self.input_length

    # Maximum number of tokens this batch will grow to
    @property
    def max_tokens(self):
        max_total_tokens = self.attention_mask.size(1)
        return len(self.requests) * max_total_tokens


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        adapt_transformers_to_gaudi()

        # Create tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
        )
        make_tokenizer_optional(tokenizer)

        # Create model
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        rank = int(os.getenv("RANK", "0"))
        dtype = torch.bfloat16 if dtype is None else dtype
        device = torch.device("hpu")

        if hq_env.is_quantization_enabled:
            htorch.core.hpu_set_env()

        if world_size > 1:
            model = self.get_deepspeed_model(
                model_id, dtype, revision
            )
            model = self.prepare_model_for_quantization(model)
        else:
            get_repo_root(model_id)

            # Check support for rope scaling
            model_kwargs = {}
            config = AutoConfig.from_pretrained(
                model_id
            )
            if hasattr(config, "rope_scaling"):
                model_kwargs["rope_scaling"] = self.get_rope_scaling()

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                **model_kwargs
            )
            model = self.prepare_model_for_quantization(model)
            model = model.eval().to(device)

        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
        model = remove_kv_cache_from_output(model)
        if self.enable_hpu_graph:
            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)

        model = self.setup_quantization(model)

        if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
            raise ValueError(f"Model type {model.config.model_type} is not supported!")

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        kwargs = {
            "use_cache": True,
            "return_dict": True,
        }

        if model.config.model_type == "llama":
            kwargs["attn_softmax_bf16"] = True
            kwargs["trim_logits"] = True

        self.speculate = get_speculate()

        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            kwargs=kwargs,
        )

        # Create profiler
        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
        record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
        output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
        if self.profiling_steps > 0:
            self.hb_profiler = HabanaProfile(
                wait=self.profiling_wait_steps,
                warmup=self.profiling_warmup_steps,
                active=self.profiling_steps,
                output_dir=output_dir,
                record_shapes=record_shapes
            )
            self.hb_profiler.start()
        else:
            self.hb_profiler = None
        self.step = 0

    def get_deepspeed_model(
        self,
        model_id: str,
        dtype: torch.dtype,
        revision: Optional[str] = None
    ) -> torch.nn.Module:
        import deepspeed
        from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

        world_size, rank, local_rank = initialize_distributed_hpu()
        model_kwargs = {
            "revision": revision
        }

        # Initialize process(es) for DeepSpeed
        deepspeed.init_distributed(dist_backend="hccl")
        logger.info(
            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
        )
        config = AutoConfig.from_pretrained(model_id, **model_kwargs)
        load_to_meta = model_on_meta(config)

        # Check support for rope scaling
        if hasattr(config, "rope_scaling"):
            config.rope_scaling = self.get_rope_scaling()
            model_kwargs["rope_scaling"] = self.get_rope_scaling()

        if load_to_meta:
            # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
            with deepspeed.OnDevice(dtype=dtype, device="meta"):
                model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
        else:
            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
            # TODO: revisit placement on CPU when auto-injection is possible
            with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
        model = model.eval()

        # Initialize the model
        ds_inference_kwargs = {"dtype": dtype}
        ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
        ds_inference_kwargs["enable_cuda_graph"] = False

        if load_to_meta:
            # model loaded to meta is managed differently
            checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
            write_checkpoints_json(model_id, local_rank, checkpoints_json)
            ds_inference_kwargs["checkpoint"] = checkpoints_json.name
        model = deepspeed.init_inference(model, **ds_inference_kwargs)

        return model.module

    def get_rope_scaling(self) -> Optional[Dict]:
        rope_scaling = os.getenv("ROPE_SCALING", None)
        if rope_scaling is None:
            return None

        rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
        return {
            'type': rope_scaling, 'factor': float(rope_factor)
        }

    def setup_quantization(self, model):
        if hq_env.is_quantization_enabled:
            htorch.core.quantization._mark_params_as_const(model)
            htorch.core.quantization._check_params_as_const(model)
            htorch.core.hpu_initialize(model)
        return model

    def prepare_model_for_quantization(self, model):
        if hq_env.is_quantization_enabled:
            if model.config.model_type == "llama":
                self.patch_scoped_linear_all_reduce(model)
            import habana_quantization_toolkit
            habana_quantization_toolkit.prep_model(model)
        return model

    def finish_quantization_measurements(self, model):
        if hq_env.is_quantization_enabled:
            import habana_quantization_toolkit
            habana_quantization_toolkit.finish_measurements(self.model)
        return model

    def patch_scoped_linear_all_reduce(self, model):
        from deepspeed.module_inject.layers import LinearAllreduce
        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
        for name, module in model.named_children():
            if type(module) is LinearAllreduce:
                SL = ScopedLinearAllReduce(mod=module)
                setattr(model, name, SL)
            self.patch_scoped_linear_all_reduce(module)

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        if is_tokenizer_transparent(self.tokenizer):
            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
            return new_text, read_offset, len(all_input_ids)
        else:
            return super().decode_token(all_input_ids, prefix_offset, read_offset)

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_idx,
        past_key_values: Optional = None,
        bypass_hpu_graph: Optional = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "token_idx": token_idx
        }

        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        if bypass_hpu_graph is not None:
            kwargs["bypass_hpu_graphs"] = bypass_hpu_graph

        kwargs.update(self.kwargs)
        if past_key_values is not None:
            return self.model.forward(**kwargs)
        else:
            outputs = self.model.forward(**kwargs)
            return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batches: List[CausalLMBatch]
    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        # Results
        generations: List[Generation] = []
        prev_batches = []
        requests_to_generate = []
        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
        # Stage 1. Collect next token ids of any previously started generations
        for batch_id, batch in enumerate(batches):
            if batch.logits is not None:
                logits = batch.logits
                past = batch.past
                prefill = batch.past_key_values is None
                if prefill:
                    # no right padding for prefill
                    token_idx_scalar = batch.attention_mask.shape[-1] - 1
                    token_idx = torch.tensor(token_idx_scalar).to(self.device)
                else:
                    token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
                    token_idx = torch.tensor(token_idx_scalar).to(self.device)

                # Select next token
                input_length = batch.input_length
                if logits.shape[-2] > 1:
                    next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
                        batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
                    )
                else:
                    next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
                        batch.input_ids, logits.squeeze(-2), self.speculate
                    )
                # Speculation is not active for causal
                accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                    batch.top_n_tokens,
                    batch.top_n_tokens_tensor,
                    logprobs,
                    accepted_ids,
                )

                prev_batches.append({
                    'next_token_ids': next_token_ids,
                    'next_token_logprobs': next_token_logprobs,
                })

                for req_idx, req in enumerate(batch.requests):
                    requests_to_generate.append({
                        'req': req,
                        'prev_req_idx': req.idx,
                        'batch_id': batch_id,
                        'seed': batch.next_token_chooser.seeds[req_idx],
                        'do_sample': batch.next_token_chooser.do_sample[req_idx],
                        'top_n_tokens': batch.top_n_tokens[req_idx],
                        'top_token_ids': batch_top_token_ids[req_idx],
                        'top_token_logprobs': batch_top_token_logprobs[req_idx],
                    })

                htorch.core.mark_step()

                # Add new token into input_ids
                batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))

                # Update attention_mask as we added a new token to input_ids
                batch.attention_mask.index_fill_(1, token_idx, 1)

                # Adjust lengths
                batch.input_length += 1

                # Update position_ids
                if prefill:
                    batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
                else:
                    batch.position_ids += 1
                # Update past key values
                if prefill:
                    batch.past_key_values = past

        htorch.core.mark_step()

        # Stage 2. Prepare new batch for speculative scheduling
        if len(batches) > 1:
            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
        else:
            batch = batches[0]

        prefill = batch.past_key_values is None

        # Check if we need to do any bookkeeping first
        if not prefill:
            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)

        scenario = 'PREFILL' if prefill else 'GENERATE'
        dbg_trace(
            scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
        assert batch.right_padding > 0, 'No more room for next token!'

        # Execute batch
        if prefill:
            # no right padding for prefill
            token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
            batch.logits, batch.past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.position_ids,
                token_idx,
                batch.past_key_values,
                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
            )
        else:
            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
            batch.logits = self.forward(
                input_ids,
                batch.attention_mask,
                batch.position_ids,
                token_idx,
                batch.past_key_values,
                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
            )

        htorch.core.mark_step()

        start_decode = time.time_ns()

        # Stage 3. Finish and return previous generations
        stopped = len(requests_to_generate) > 0
        for prev_batch in prev_batches:
            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
        htorch.core.mark_step()

        for req_data in requests_to_generate:
            req = req_data['req']
            i = req_data['prev_req_idx']
            prev_batch_id = req_data['batch_id']
            assert len(prev_batches) > prev_batch_id
            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']

            request = req.data
            input_length = req.input_length
            prefix_offset = req.prefix_offset
            read_offset = req.read_offset
            do_sample = req_data['do_sample']
            seed = req_data['seed']
            stopping_criteria = req.stopping_criteria
            all_input_ids = req.all_input_ids
            next_token_id = next_token_ids_cpu[i]
            next_token_logprob = next_token_logprobs[i]
            top_n_tokens = req_data['top_n_tokens']
            top_token_ids = req_data['top_token_ids']
            top_token_logprobs = req_data['top_token_logprobs']

            # Append next token to all tokens
            all_input_ids[input_length] = next_token_id
            new_input_length = input_length + 1

            # Generated token
            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
                next_token_text = ''
            else:
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
                )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    if is_tokenizer_transparent(self.tokenizer):
                        output_text = None
                    else:
                        output_text = self.decode(
                            all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
                        )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + next_token_logprobs
                    prefill_token_ids = all_input_ids[0: new_input_length - 1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            req.all_input_ids = all_input_ids
            req.input_length = new_input_length
            req.prefix_offset = prefix_offset
            req.read_offset = read_offset

        htorch.core.mark_step()
        self.step = self.step + 1
        if self.hb_profiler is not None:
            if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
                self.hb_profiler.stop()
            else:
                self.hb_profiler.step()

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch if not stopped else None, (forward_ns, decode_ns)

    def warmup(self, batches: List[CausalLMBatch]) -> None:
        # prefill
        _, prefill_batch, _ = self.generate_token([batches.pop(0)])
        # decode
        _, decode_batch, _ = self.generate_token([prefill_batch])
        # shifts
        self.shifting_warmup(decode_batch)

        # if decode bs is 1 warmup ends here
        if len(batches) == 0:
            return

        # prefill
        _, prefill_batch, _ = self.generate_token([batches.pop(0)])
        # concatenate and decode
        _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
        # decodes
        while decode_batch is not None:
            _, decode_batch, _ = self.generate_token([decode_batch])

    def shifting_warmup(self, batch: CausalLMBatch) -> None:
        chunk_sizes = CHUNK_SIZES.copy()
        chunk_sizes.extend([-chunk for chunk in chunk_sizes])

        for chunk in chunk_sizes:
            batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
            batch.realign(batch.batch_size, chunk, 0)
            batch.split_kv_cache_if_needed(True)