Mirror of https://github.com/huggingface/text-generation-inference.git
commit 34f5dc525e ("working")
parent 173bc99ab3

@@ -16,7 +16,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
+from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union
 
 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -119,7 +119,9 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]
 
     # Decoder values
-    input_ids: torch.Tensor
+    # Can be a list for easy filtering
+    # If `input_ids` is a list, it needs to be materialized to a tensor first
+    input_ids: Union[torch.Tensor, List[List[int]]]
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     position_ids: Optional[torch.Tensor]
     speculative_ids: Optional[torch.Tensor]
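
Note on the widened type: while a batch is still prefilling, token ids stay in per-request Python lists (cheap to filter and concatenate), and are only turned into a flat tensor once, right before the forward pass. A minimal sketch of that contract, with a hypothetical `materialize` helper that is not part of this commit:

```python
from typing import List, Union

import torch

def materialize(input_ids: Union[torch.Tensor, List[List[int]]]) -> torch.Tensor:
    # Decode path: already a flat tensor holding one current token per request
    if isinstance(input_ids, torch.Tensor):
        return input_ids
    # Prefill path: flatten the per-request id lists into one 1D int64 tensor
    return torch.tensor(
        [token for request in input_ids for token in request],
        dtype=torch.int64,
    )

print(materialize([[1, 2, 3], [4, 5]]))  # tensor([1, 2, 3, 4, 5])
```
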
@@ -178,7 +180,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     postfix_lengths_tensor: Optional[torch.Tensor]
     prefix_lengths_tensor: Optional[torch.Tensor]
-    prompt_lengths_tensor: Optional[torch.Tensor]
+    prompt_lengths_tensor: torch.Tensor
 
     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
@@ -350,12 +352,6 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
 
-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
-        else:
-            input_ids = all_postfix_ids[0]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
         )
@@ -366,12 +362,15 @@ class FlashCausalLMBatch(Batch):
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
+        )
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
+            input_ids=all_postfix_ids,
 
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
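
As read from the two hunks above, `from_pb` no longer eagerly concatenates the postfix ids into a tensor; it stores the raw `all_postfix_ids` lists and leaves materialization to `prepare_for_prefill` (the `-973` hunk below). A self-contained sketch, with hypothetical values, of why deferring pays off when requests are filtered out mid-prefill:

```python
import torch

all_postfix_ids = [[101, 7], [55, 9, 3], [42]]  # stored as-is by from_pb

# A request is dropped before any tensor was built: plain list indexing.
all_postfix_ids = [all_postfix_ids[i] for i in (0, 2)]

# Materialization happens once, later, in prepare_for_prefill.
flat = [token for request in all_postfix_ids for token in request]
input_ids = torch.tensor(flat, dtype=torch.int64)  # tensor([101, 7, 42])
```
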
@@ -395,6 +394,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            prompt_lengths_tensor=prompt_lengths_tensor,
 
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids=None,
@@ -408,7 +408,6 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             prefix_lengths_tensor=None,
             postfix_lengths_tensor=None,
-            prompt_lengths_tensor=None,
             adapter_meta=None,
         )
 
@@ -455,6 +454,7 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        input_ids = []
 
         prompt_lengths = []
         postfix_lengths = []
@@ -473,7 +473,6 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
-        prefilling = False
 
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
@@ -484,9 +483,13 @@ class FlashCausalLMBatch(Batch):
 
             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
-            prefilling = prefilling or request_prefilling
             prefilling_mask.append(request_prefilling)
 
+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+
             # Get length
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
@@ -538,32 +541,48 @@ class FlashCausalLMBatch(Batch):
 
             max_blocks = max(max_blocks, len(request_block_table))
 
-        # Index into tensors
-        input_ids = self.input_ids[indices]
-        position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
-        slots = self.slots[slot_filtering_indices]
-        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
+        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
 
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Move to GPU now that we have the whole tensor
-        slot_indices = slot_indices.to(device)
-
-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(
-            adapter_segments, dtype=torch.int32, device=device
-        )
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
+        if self.prefilling:
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids = None
+            start_slots = None
+            slot_indices = None
+            slots = None
+            prefix_lengths_tensor = None
+            postfix_lengths_tensor = None
+            adapter_meta = None
+        else:
+            # Index into tensors
+            input_ids = self.input_ids[indices]
+            position_ids = self.position_ids[indices]
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+            slots = self.slots[slot_filtering_indices]
+            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
+
+            start_slots = torch.tensor(start_slots, dtype=torch.int64)
+
+            # Move to GPU now that we have the whole tensor
+            slot_indices = slot_indices.to(device)
+
+            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+            adapter_segments = torch.tensor(
+                adapter_segments, dtype=torch.int32, device=device
+            )
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
 
         return type(self)(
             batch_id=self.batch_id,
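
The rewritten `filter` now splits on batch state: a batch that is still prefilling has no position/slot/adapter tensors yet (those are built later by `prepare_for_prefill`), so only the per-request lists are carried forward and everything else is reset to None; only a decoding batch pays for tensor indexing. A toy reduction of that branch, with hypothetical names:

```python
import torch

def filter_input_ids(prefilling: bool, input_ids, indices):
    if prefilling:
        # Still prefilling: input_ids is List[List[int]], so plain list indexing;
        # derived tensors stay None until prepare_for_prefill builds them.
        return [input_ids[i] for i in indices]
    # Decoding: input_ids is a flat tensor; fancy-index it on its device.
    return input_ids[torch.as_tensor(indices)]

print(filter_input_ids(True, [[1, 2], [3], [4, 5]], [0, 2]))  # [[1, 2], [4, 5]]
```
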
@@ -580,7 +599,7 @@ class FlashCausalLMBatch(Batch):
             slots=slots,
             max_postfix_length=max_postfix_length,
             max_current_length=max_current_length,
-            prefilling=prefilling,
+            prefilling=self.prefilling,
             prefilling_mask=prefilling_mask,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -604,12 +623,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )
 
     @classmethod
@@ -652,38 +666,51 @@ class FlashCausalLMBatch(Batch):
             )
             prefilling = prefilling or b.prefilling
 
-        input_ids = batches[0].input_ids.new_empty(total_batch_size)
-        position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots = batches[0].slots.new_empty(total_slots)
-        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        if prefilling:
+            input_ids = []
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids = None
+            start_slots = None
+            slots = None
+            slot_indices = None
+            prefix_lengths_tensor = None
+            postfix_lengths_tensor = None
+            adapter_meta = None
+            adapter_segment_builder = None
+        else:
+            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            position_ids = batches[0].position_ids.new_empty(total_batch_size)
+            start_slots = []
+            slots = batches[0].slots.new_empty(total_slots)
+            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+            postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
+                total_batch_size
+            )
+            prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+                total_batch_size
+            )
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()
 
         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
         )
-        postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
-            total_batch_size
-        )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
-            total_batch_size
-        )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
             total_batch_size,
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_set = set()
-        adapter_segment_builder = SegmentConcatBuilder()
-
-        start_slots = []
         block_tables = []
         prefix_lengths = []
         all_input_ids = []
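
On the decode path, `concatenate` still pre-allocates its destination buffers, and the `Tensor.new_empty` idiom keeps the dtype and device of the source batch without initializing memory, which is safe here because every position is overwritten by the copy loop below. For example:

```python
import torch

src = torch.zeros(4, dtype=torch.int32)  # stands in for batches[0].slots
dst = src.new_empty(10)                  # same dtype/device, contents undefined
dst[0:4] = src                           # each region is then fully overwritten
```
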
@@ -723,29 +750,7 @@ class FlashCausalLMBatch(Batch):
             slots_end_index = cumulative_slots + len(batch.slots)
 
             # Copy tensors (GPU)
-            input_ids[start_index:end_index] = batch.input_ids
-            position_ids[start_index:end_index] = batch.position_ids
-            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
-            postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
-
-            # Copy over adapter indices
-            adapter_start_index = cumulative_adapter_indices_size
-            adapter_end_index = (
-                cumulative_adapter_indices_size
-                + batch.adapter_meta.adapter_indices.shape[0]
-            )
-            adapter_indices[adapter_start_index:adapter_end_index] = (
-                batch.adapter_meta.adapter_indices
-            )
-            cumulative_adapter_indices_size = adapter_end_index
-            adapter_set.update(batch.adapter_meta.adapter_set)
-            adapter_segment_builder.concat(
-                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
-            )
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
@@ -753,12 +758,38 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
+            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
 
-            prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
+            if not prefilling:
+                input_ids[start_index:end_index] = batch.input_ids
+                position_ids[start_index:end_index] = batch.position_ids
+                slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
+                postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
+                slots[slots_start_index:slots_end_index] = batch.slots
+
+                # Copy over adapter indices
+                adapter_start_index = cumulative_adapter_indices_size
+                adapter_end_index = (
+                    cumulative_adapter_indices_size
+                    + batch.adapter_meta.adapter_indices.shape[0]
+                )
+                adapter_indices[adapter_start_index:adapter_end_index] = (
+                    batch.adapter_meta.adapter_indices
+                )
+                cumulative_adapter_indices_size = adapter_end_index
+                adapter_set.update(batch.adapter_meta.adapter_set)
+                adapter_segment_builder.concat(
+                    batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+                )
+                prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
+
+                start_slots.append(batch.start_slots + cumulative_slots)
+            else:
+                if isinstance(batch.input_ids, torch.Tensor):
+                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
+                input_ids.extend(batch.input_ids)
 
-            prefilling_mask = prefilling_mask.extend(batch.prefilling_mask)
+            prefilling_mask.extend(batch.prefilling_mask)
             block_tables.extend(batch.block_tables)
             prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
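
The new `else` branch handles folding a decoding batch into a prefilling concatenation: a decoding batch carries a flat tensor with one current token per request, and `view(-1, 1).tolist()` turns it back into per-request lists so it can extend the list-typed `input_ids`. Concretely:

```python
import torch

ids = torch.tensor([11, 22, 33])  # decode batch: one current token per request
print(ids.view(-1, 1).tolist())   # [[11], [22], [33]], per-request lists again
```

Note the removed side also fixes a latent bug: `prefilling_mask = prefilling_mask.extend(...)` assigned the `None` that `list.extend` returns.
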
@@ -781,7 +812,8 @@ class FlashCausalLMBatch(Batch):
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
 
-        start_slots = torch.concat(start_slots)
+        if start_slots is not None:
+            start_slots = torch.concat(start_slots)
 
         # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
 
@@ -799,7 +831,14 @@ class FlashCausalLMBatch(Batch):
             else None
         )
 
-        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+        if adapter_segment_builder is not None:
+            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
 
         return cls(
             batch_id=batches[0].batch_id,
@@ -840,12 +879,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )
 
     def prepare_for_prefill(self):
@@ -973,9 +1007,16 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)
 
-        device = self.input_ids.device
+        device = self.block_tables_tensor.device
         self.start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
+        if isinstance(self.input_ids, list):
+            if len(self) > 1:
+                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
+            else:
+                input_ids = self.input_ids[0]
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
         if len(self) > 1:
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
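
`prepare_for_prefill` is where the deferred list finally becomes a tensor, and `device` is now taken from `block_tables_tensor` because `input_ids` may still be a Python list with no device of its own. The single-request fast path skips the NumPy concatenation entirely. The materialization step in isolation:

```python
import numpy as np
import torch

postfix_ids = [[1, 2, 3], [4, 5]]                   # List[List[int]] from prefill
flat = np.concatenate(postfix_ids, dtype=np.int64)  # array([1, 2, 3, 4, 5])
input_ids = torch.tensor(flat, dtype=torch.int64)   # device=... in the real code
```
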
@@ -1865,7 +1906,8 @@ class FlashCausalLM(Model):
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.postfix_lengths_tensor += accepted_ids
+        batch.prefix_lengths_tensor += batch.postfix_lengths_tensor
+        batch.postfix_lengths_tensor = accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
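
Read literally, this hunk changes what the length tensors mean after a speculative step: the tokens just run through the model are rolled into the cached prefix, and the postfix is reset to the number of accepted tokens, rather than the postfix growing without bound. With hypothetical per-request numbers:

```python
prefix_length, postfix_length, accepted = 10, 1, 2  # one speculative decode step

prefix_length += postfix_length  # 11: processed tokens now live in the KV cache
postfix_length = accepted        # 2: tokens to feed the next forward pass
```
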
@@ -1929,11 +1971,9 @@
                 # This request is done prefilling, the new id is the one selected the sampling method
                 postfix_ids = [next_token_id]
 
-            all_postfix_ids.extend(postfix_ids)
+            all_postfix_ids.append(postfix_ids)
 
-        batch.input_ids = batch.input_ids.new_tensor(
-            all_postfix_ids, dtype=torch.int64
-        )
+        batch.input_ids = all_postfix_ids
 
         start_decode = time.time_ns()
 
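
The switch from `extend` to `append` is load-bearing: `batch.input_ids` must stay a `List[List[int]]` (one inner list per request) for the list-based filter and concatenate paths above, which is also why the eager `new_tensor` call is gone. The difference in one line each:

```python
next_token_id = 42

a = []
a.append([next_token_id])  # [[42]]: keeps per-request nesting
b = []
b.extend([next_token_id])  # [42]: flattens, losing request boundaries
```
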
@@ -2014,7 +2054,7 @@
                 )
 
                 prefill_tokens = Tokens(
-                    prefix_ids + prefill_token_ids,
+                    prefill_token_ids,
                     request_prefill_logprobs,
                     prefill_texts,
                     is_special=[],