Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
refactor to use prefix/postfix naming + fix all_input_ids_tensor
This commit is contained in:
parent de043b53c4
commit 838756eb18
@@ -9,7 +9,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 
     @dataclass
     class Seqlen:
-        input_lengths: torch.Tensor
+        postfix_lengths: torch.Tensor
         prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
@@ -18,16 +18,16 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 
         def __init__(
             self,
-            input_lengths,
+            postfix_lengths,
             prefix_lengths,
             cu_seqlen_q=None,
             max_q=None,
             max_k=None,
         ):
-            self.input_lengths = input_lengths
+            self.postfix_lengths = postfix_lengths
             self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
+            device = self.postfix_lengths.device
+            shape = self.postfix_lengths.shape
             if cu_seqlen_q is None:
                 cu_seqlen_q = torch.arange(
                     shape[0] + 1,
@@ -43,7 +43,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
+            total = self.postfix_lengths + self.prefix_lengths
             torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
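
For context on the hunk above: `cu_seqlen_k` holds the running offsets of the flattened batch, where entry i marks where sequence i starts and the last entry is the total token count, computed over the full prefix + postfix length of each sequence. A minimal standalone sketch of that computation (plain PyTorch; the lengths are invented for illustration):

    import torch

    # Three hypothetical sequences: tokens already in the KV cache (prefix)
    # plus tokens still to be processed (postfix).
    prefix_lengths = torch.tensor([4, 0, 2], dtype=torch.int32)
    postfix_lengths = torch.tensor([3, 5, 1], dtype=torch.int32)

    # Keys/values span the whole sequence, so totals are postfix + prefix.
    total = postfix_lengths + prefix_lengths  # tensor([7, 5, 3])

    # cu_seqlen_k[0] stays 0; cumsum fills the remaining offsets in place,
    # mirroring the torch.cumsum(..., out=cu_seqlen_k[1:]) call above.
    cu_seqlen_k = torch.zeros(len(total) + 1, dtype=torch.int32)
    torch.cumsum(total, -1, out=cu_seqlen_k[1:])
    print(cu_seqlen_k)  # tensor([ 0,  7, 12, 15], dtype=torch.int32)
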
@@ -59,7 +59,7 @@ else:
 
     @dataclass
     class Seqlen:
-        input_lengths: torch.Tensor
+        postfix_lengths: torch.Tensor
         prefix_lengths: torch.Tensor
         cu_seqlen_q: torch.Tensor
         max_q: int
@@ -143,9 +143,6 @@ class FlashCausalLMBatch(Batch):
     block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: torch.Tensor
-    # size [b], containing the number of blocks that can be retrieved from the cache
-    prefix_lens: List[int]
-    prefix_lens_tensor: torch.Tensor
 
     max_seqlen: int
 
@@ -162,8 +159,14 @@ class FlashCausalLMBatch(Batch):
     all_input_ids_tensor: torch.Tensor
 
     # Lengths of all generations present in the batch
-    input_lengths: List[int]
-    input_lengths_tensor: torch.Tensor
+    postfix_lengths: List[int]
+    postfix_lengths_tensor: torch.Tensor
+    # size [b], containing the number of blocks that can be retrieved from the cache
+    prefix_lengths: List[int]
+    prefix_lengths_tensor: torch.Tensor
+    prompt_lengths: List[int]
+    prompt_lengths_tensor: torch.Tensor
+
     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
 
@@ -225,10 +228,13 @@ class FlashCausalLMBatch(Batch):
         slot_indices = []
         prefill_cache_indices = []
 
-        input_lengths = []
+        prefix_lengths = []
+        postfix_lengths = []
+        prompt_lengths = []
         prefix_offsets = []
         read_offsets = []
         all_input_ids = []
+        all_postfix_ids = []
         prefix_ids = []
         requests_idx_mapping = {}
 
@@ -257,7 +263,6 @@ class FlashCausalLMBatch(Batch):
 
         block_tables = []
         slots = []
-        prefix_lens = []
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -266,37 +271,39 @@ class FlashCausalLMBatch(Batch):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            orig_input_length = len(tokenized_input)
+            prompt_length = len(tokenized_input)
+            prompt_lengths.append(prompt_length)
 
-            prefix_len = r.prefix_len
+            prefix_length = r.prefix_len
             assert (
-                prefix_len <= orig_input_length
-            ), f"Prefix {prefix_len} vs input {orig_input_length}"
-            if prefix_len == orig_input_length:
-                assert prefix_len > 0
-                prefix_len -= 1
+                prefix_length <= prompt_length
+            ), f"Prefix {prefix_length} vs input {prompt_length}"
+            if prefix_length == prompt_length:
+                assert prefix_length > 0
+                prefix_length -= 1
 
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
-            prefix_ids.append(tokenized_input[:prefix_len])
-            tokenized_input = tokenized_input[prefix_len:]
+            prefix_ids.append(tokenized_input[:prefix_length])
+            postfix_ids = tokenized_input[prefix_length:]
 
-            input_length = len(tokenized_input)
-            input_lengths.append(input_length)
+            postfix_length = len(postfix_ids)
+            postfix_lengths.append(postfix_length)
 
-            prefix_offsets.append(input_length - 5)
-            read_offsets.append(input_length)
+            prefix_offsets.append(postfix_length - 5)
+            read_offsets.append(postfix_length)
 
+            all_postfix_ids.append(postfix_ids)
             all_input_ids.append(tokenized_input)
 
             # Position ids
             request_position_ids = torch.arange(
-                prefix_len, orig_input_length, dtype=torch.int32
+                prefix_length, prompt_length, dtype=torch.int32
             )
             position_ids.append(request_position_ids)
 
             # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
+            cu_seqlen_prefill.append(cumulative_length + postfix_length)
 
             next_token_chooser_parameters.append(r.parameters)
 
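
The renaming above encodes the invariant prompt = prefix + postfix: the prefix is the slice of the prompt already present in the KV cache, and the postfix is the slice that still has to go through prefill. A small sketch of the split with made-up token ids (not taken from the diff):

    # Hypothetical tokenized prompt; prefix_length would come from r.prefix_len.
    tokenized_input = [101, 7592, 2088, 2003, 3835, 102]
    prefix_length = 4  # tokens already present in the KV cache

    # Mirrors the guard above: never treat the entire prompt as cached, keep
    # at least one postfix token so prefill has something to compute.
    if prefix_length == len(tokenized_input):
        assert prefix_length > 0
        prefix_length -= 1

    prefix_ids = tokenized_input[:prefix_length]   # read from the cache
    postfix_ids = tokenized_input[prefix_length:]  # run through the model

    assert len(prefix_ids) + len(postfix_ids) == len(tokenized_input)
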
@@ -309,7 +316,7 @@ class FlashCausalLMBatch(Batch):
 
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
-            adapter_indices_list.append(torch.full((input_length,), adapter_index))
+            adapter_indices_list.append(torch.full((postfix_length,), adapter_index))
             adapter_set.add(adapter_index)
 
             # Paged attention
@@ -318,11 +325,11 @@ class FlashCausalLMBatch(Batch):
             speculative_length = 0 if speculative_length is None else speculative_length
 
             # Tokens that need to be mapped to blocks.
-            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
 
             # Tokens that need to be mapped to slots. We don't need slots for the
             # cached prefix (if present).
-            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+            slot_tokens = postfix_length + max_new_tokens - 1 + speculative_length
 
             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
@@ -338,19 +345,19 @@ class FlashCausalLMBatch(Batch):
             else:
                 request_blocks = r.blocks
                 request_slots = r.slots[
-                    prefix_len:  #: orig_input_length + max_new_tokens + speculative_length
+                    prefix_length:  #: orig_input_length + max_new_tokens + speculative_length
                 ]
 
             block_tables.append(request_blocks)
 
             slots.extend(request_slots)
-            prefix_lens.append(prefix_len)
+            prefix_lengths.append(prefix_length)
             num_blocks += len(request_blocks)
             start_slots.append(cumulative_slot_tokens)
 
             request_slot_indices = torch.arange(
                 cumulative_slot_tokens,
-                cumulative_slot_tokens + input_length,
+                cumulative_slot_tokens + postfix_length,
                 dtype=torch.int64,
             )
             slot_indices.append(request_slot_indices)
@@ -358,8 +365,8 @@ class FlashCausalLMBatch(Batch):
             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
+                    cumulative_length + max(0, postfix_length - sliding_window),
+                    cumulative_length + postfix_length,
                     dtype=torch.int64,
                 )
                 prefill_cache_indices.append(request_prefill_cache_indices)
@@ -370,14 +377,16 @@ class FlashCausalLMBatch(Batch):
             if r.prefill_logprobs:
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
+                    prefill_out_cumulative_length + postfix_length - 1
                 )
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
+                prefill_cu_outlens.append(
+                    prefill_out_cumulative_length + postfix_length
+                )
+                prefill_out_cumulative_length += postfix_length
             else:
                 prefill_head_indices.append(
                     torch.tensor(
-                        [cumulative_length + input_length - 1], dtype=torch.int32
+                        [cumulative_length + postfix_length - 1], dtype=torch.int32
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -385,12 +394,13 @@ class FlashCausalLMBatch(Batch):
                 prefill_out_cumulative_length += 1
 
             # Update
-            cumulative_length += input_length
+            cumulative_length += postfix_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, input_length)
+            max_seqlen = max(max_seqlen, postfix_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
-                max_length, input_length + max_new_tokens + speculative_length
+                max_length,
+                prefix_length + postfix_length + max_new_tokens + speculative_length,
             )
 
         adapter_indices = torch.cat(adapter_indices_list).to(
@@ -415,13 +425,13 @@ class FlashCausalLMBatch(Batch):
         )
 
         if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
             if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
-            input_ids = all_input_ids[0]
+            input_ids = all_postfix_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
             if sliding_window is not None:
@@ -436,8 +446,11 @@ class FlashCausalLMBatch(Batch):
             prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
+        postfix_lengths_tensor = torch.tensor(
+            postfix_lengths, dtype=torch.int32, device=device
+        )
+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
         )
 
         adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
@@ -470,7 +483,9 @@ class FlashCausalLMBatch(Batch):
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
-        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        prefix_lengths_tensor = torch.tensor(
+            prefix_lengths, dtype=torch.int32, device=device
+        )
 
         return cls(
             batch_id=pb.id,
@@ -485,14 +500,16 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -556,8 +573,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
 
-        input_lengths = []
-        prefix_lens = []
+        postfix_lengths = []
+        prefix_lengths = []
        prefix_offsets = []
         read_offsets = []
 
@@ -578,15 +595,15 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
 
             # Get length
-            request_input_length = self.input_lengths[idx]
-            prefix_len = self.prefix_lens[idx]
+            request_input_length = self.postfix_lengths[idx]
+            prefix_length = self.prefix_lengths[idx]
             max_seqlen = max(max_seqlen, request_input_length)
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            input_lengths.append(request_input_length)
-            prefix_lens.append(prefix_len)
+            postfix_lengths.append(request_input_length)
+            prefix_lengths.append(prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -629,9 +646,9 @@ class FlashCausalLMBatch(Batch):
         adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        input_lengths_tensor = self.input_lengths_tensor[indices]
+        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
-        prefix_lens_tensor = self.prefix_lens_tensor[indices]
+        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
@@ -666,10 +683,10 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -720,7 +737,7 @@ class FlashCausalLMBatch(Batch):
                 + speculative_length
                 - stopping_criteria.current_tokens
                 for input_length, stopping_criteria in zip(
-                    b.input_lengths, b.stopping_criterias
+                    b.postfix_lengths, b.stopping_criterias
                 )
             ),
         )
@@ -729,13 +746,15 @@ class FlashCausalLMBatch(Batch):
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
         slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
-        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
+        postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
             total_batch_size
         )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
+        prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+            total_batch_size
+        )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
@@ -753,11 +772,11 @@ class FlashCausalLMBatch(Batch):
 
         start_slots = []
         block_tables = []
-        prefix_lens = []
+        prefix_lengths = []
         all_input_ids = []
         prefix_ids = []
 
-        input_lengths = []
+        postfix_lengths = []
         prefix_offsets = []
         read_offsets = []
 
@@ -790,7 +809,7 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
@@ -817,16 +836,16 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
 
-            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+            prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
 
             start_slots.append(batch.start_slots + cumulative_slots)
 
             block_tables.extend(batch.block_tables)
-            prefix_lens.extend(batch.prefix_lens)
+            prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
 
-            input_lengths.extend(batch.input_lengths)
+            postfix_lengths.extend(batch.postfix_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
 
@@ -872,15 +891,15 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -1100,9 +1119,9 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = [max_s] * bs
+        postfix_lengths = [max_s] * bs
         prefix_lengths = [0] * bs
-        input_lengths_tensor = (
+        postfix_lengths_tensor = (
             torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         )
         prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
@@ -1114,8 +1133,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=input_lengths,
-                prefix_lens=prefix_lengths,
+                postfix_lengths=postfix_lengths,
+                prefix_lengths=prefix_lengths,
             )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
@@ -1143,7 +1162,7 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "input_lengths": input_lengths_tensor,
+            "postfix_lengths": postfix_lengths_tensor,
             "prefix_lengths": prefix_lengths_tensor,
             "state": state,
             "graph": graph,
@@ -1154,12 +1173,12 @@ class FlashCausalLM(Model):
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths_tensor=postfix_lengths_tensor,
             state=state,
-            prefix_lens_tensor=prefix_lengths_tensor,
+            prefix_lengths_tensor=prefix_lengths_tensor,
         ):
             seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
+                postfix_lengths=postfix_lengths_tensor,
                 prefix_lengths=prefix_lengths_tensor,
                 cu_seqlen_q=None,
                 max_q=1,
@@ -1183,7 +1202,7 @@ class FlashCausalLM(Model):
 
             with torch.cuda.graph(graph, pool=MEM_POOL):
                 seqlen = Seqlen(
-                    input_lengths=input_lengths_tensor,
+                    postfix_lengths=postfix_lengths_tensor,
                     prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=None,
                     max_q=1,
@@ -1340,15 +1359,17 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
 
         # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        prefix_lengths_tensor = torch.zeros(
+            seqlen, dtype=torch.int32, device=self.device
+        )
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            prefix_lengths=prefix_lens_tensor,
+            postfix_lengths=postfix_lengths,
+            prefix_lengths=prefix_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
             max_q=1,
             max_k=seqlen,
@@ -1379,7 +1400,7 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -1396,11 +1417,11 @@ class FlashCausalLM(Model):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)
 
             # Add Copy the block tables for all members
@@ -1421,8 +1442,8 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -1444,19 +1465,19 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                postfix_lengths=batch.postfix_lengths,
+                prefix_lengths=batch.prefix_lengths,
             )
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            input_lengths_tensor=input_lengths,
-            prefix_lens_tensor=prefix_lens_tensor,
+            postfix_lengths_tensor=postfix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
         ):
-            max_k = (input_lengths + prefix_lens_tensor).max().item()
+            max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
             seqlen = Seqlen(
-                input_lengths=input_lengths,
-                prefix_lengths=prefix_lens_tensor,
+                postfix_lengths=postfix_lengths,
+                prefix_lengths=prefix_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
                 max_k=max_k,
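
Worth noting in the hunk above: the query bound `max_q` stays the postfix-only `max_s`, while the key bound `max_k` must also span the cached prefix, hence the elementwise sum before taking the max. A tiny worked example with invented lengths:

    import torch

    postfix_lengths = torch.tensor([3, 5], dtype=torch.int32)
    prefix_lengths = torch.tensor([4, 0], dtype=torch.int32)

    # Keys cover prefix + postfix; queries only cover the postfix.
    max_k = (postfix_lengths + prefix_lengths).max().item()
    assert max_k == 7  # max(3 + 4, 5 + 0)
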
@@ -1485,8 +1506,8 @@ class FlashCausalLM(Model):
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
                 # assert block_tables.shape[0] >= slots.shape[0]
                 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@@ -1499,16 +1520,18 @@ class FlashCausalLM(Model):
             # so it doesn't matter if we override it with bogus values.
             cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
-            cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["postfix_lengths"].zero_()
+            cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
             cuda_graph["prefix_lengths"].zero_()
-            cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
+            cuda_graph["prefix_lengths"][
+                : prefix_lengths_tensor.shape[0]
+            ] = prefix_lengths_tensor
 
             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],
                 cu_seqlen_prefill=None,
-                input_lengths_tensor=cuda_graph["input_lengths"],
-                prefix_lens_tensor=cuda_graph["prefix_lengths"],
+                postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+                prefix_lengths_tensor=cuda_graph["prefix_lengths"],
                 state=cuda_graph["state"],
             ):
                 # Replay the graph
@@ -1586,7 +1609,7 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen],
+            batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)],
             next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -1619,7 +1642,12 @@ class FlashCausalLM(Model):
             stopped = True
 
         # Zipped iterator
-        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
+        iterator = zip(
+            batch.prefix_lengths,
+            batch.postfix_lengths,
+            batch.all_input_ids,
+            accepted_ids,
+        )
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a GPU <-> CPU sync
@@ -1627,10 +1655,15 @@ class FlashCausalLM(Model):
 
         # For each member of the batch
        index = 0
-        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
+        for i, (
+            prefix_length,
+            postfix_length,
+            all_input_ids,
+            n_accepted_ids,
+        ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
-            end_index = cumulative_length + input_length
+            end_index = cumulative_length + postfix_length
 
             if prefill:
                 # Indexing metadata
@@ -1662,16 +1695,18 @@ class FlashCausalLM(Model):
                 ]
 
             for j in range(n_accepted_ids):
-                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = (
+                    next_input_ids[index]
+                )
                 index += 1
 
-            cumulative_length += input_length
+            cumulative_length += postfix_length
 
         # Update values
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.input_lengths_tensor += accepted_ids
+        batch.postfix_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
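
The hunk above is the `all_input_ids_tensor` fix from the commit title: an accepted token belongs at absolute position `prefix_length + postfix_length + j` in its row, because the row holds the whole sequence including the cached prefix; offsetting by the postfix length alone would overwrite prompt tokens whenever a prefix is cached. A toy reproduction of the corrected indexing (all values invented):

    import torch

    prefix_length, postfix_length = 4, 3  # 7 tokens already in the sequence
    all_input_ids_tensor = torch.zeros(1, 16, dtype=torch.int64)
    next_input_ids = torch.tensor([42, 43])  # two accepted speculative tokens

    index = 0
    for j in range(len(next_input_ids)):
        # The old indexing used input_length + j (positions 3 and 4 here),
        # which lands inside the prompt once prefix_length > 0.
        all_input_ids_tensor[0, prefix_length + postfix_length + j] = (
            next_input_ids[index]
        )
        index += 1

    print(all_input_ids_tensor[0, :10].tolist())
    # [0, 0, 0, 0, 0, 0, 0, 42, 43, 0]
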
@@ -1702,7 +1737,7 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.input_lengths,
+            batch.postfix_lengths,
             batch.prefix_offsets,
             batch.read_offsets,
             batch.stopping_criterias,
@@ -1867,9 +1902,9 @@ class FlashCausalLM(Model):
             )
 
             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids
-            if batch.input_lengths[i] > batch.max_seqlen:
-                batch.max_seqlen = batch.input_lengths[i]
+            batch.postfix_lengths[i] = input_length + n_accepted_ids
+            if batch.postfix_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.postfix_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1893,8 +1928,8 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths_tensor: torch.Tensor,
-        prefix_lens_tensor: torch.Tensor,
+        postfix_lengths_tensor: torch.Tensor,
+        prefix_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
@@ -1905,7 +1940,7 @@ class FlashCausalLM(Model):
             use_prefill_with_paged_kv_state,
         )
 
-        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
+        # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths)
 
         if cu_seqlen_prefill is not None:
             return use_prefill_with_paged_kv_state(
@@ -1914,12 +1949,12 @@ class FlashCausalLM(Model):
                 ),
                 # block_tables=block_tables_to_ragged(
                 #     block_tables=block_tables,
-                #     input_lengths=input_lengths,
-                #     prefix_lens=prefix_lens,
+                #     postfix_lengths=postfix_lengths,
+                #     prefix_lengths=prefix_lengths,
                 # ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
@@ -1928,10 +1963,10 @@ class FlashCausalLM(Model):
                 window_left=self.sliding_window,
             )
         else:
-            assert input_lengths_tensor is not None
+            assert postfix_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
@@ -1943,19 +1978,21 @@ class FlashCausalLM(Model):
 
 
 def block_tables_to_ragged(
-    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+    *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int]
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
-    assert len(input_lengths) == len(prefix_lens)
+    assert len(postfix_lengths) == len(prefix_lengths)
 
-    total_len = sum(input_lengths) + sum(prefix_lens)
+    total_len = sum(postfix_lengths) + sum(prefix_lengths)
     block_tables_ragged = torch.empty(
         total_len, dtype=torch.int32, device=block_tables.device
     )
 
     offset = 0
-    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
-        seq_len = prefix_len + input_length
+    for i, (input_length, prefix_length) in enumerate(
+        zip(postfix_lengths, prefix_lengths)
+    ):
+        seq_len = prefix_length + input_length
         block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
         offset += seq_len
 
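
As a closing note on `block_tables_to_ragged`: FlashInfer consumes one flat array of block ids whose per-sequence segment covers the full prefix + postfix length, rather than the padded [batch, max_blocks] matrix used elsewhere. A self-contained sketch of the renamed function with dummy tables (values invented):

    import torch
    from typing import List

    def block_tables_to_ragged(
        *,
        block_tables: torch.Tensor,
        postfix_lengths: List[int],
        prefix_lengths: List[int],
    ) -> torch.Tensor:
        """Flatten a padded block table into a ragged, FlashInfer-style layout."""
        assert len(postfix_lengths) == len(prefix_lengths)
        total_len = sum(postfix_lengths) + sum(prefix_lengths)
        out = torch.empty(total_len, dtype=torch.int32, device=block_tables.device)
        offset = 0
        for i, (postfix_length, prefix_length) in enumerate(
            zip(postfix_lengths, prefix_lengths)
        ):
            seq_len = prefix_length + postfix_length
            out[offset : offset + seq_len] = block_tables[i][:seq_len]
            offset += seq_len
        return out

    # Two sequences padded to 4 blocks each; only prefix + postfix entries survive.
    tables = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
    ragged = block_tables_to_ragged(
        block_tables=tables, postfix_lengths=[2, 1], prefix_lengths=[1, 1]
    )
    print(ragged.tolist())  # [10, 11, 12, 20, 21]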