OlivierDehaene 2024-09-25 13:57:18 +02:00
parent 7169cbae6d
commit de043b53c4
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
4 changed files with 27 additions and 104 deletions

View File

@@ -36,9 +36,9 @@ impl BackendV3 {
         speculate: u32,
     ) -> Self {
         let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
+            std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
         let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
+        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
         let attention: Attention = attention
             .parse()

View File

@@ -2,10 +2,6 @@ import pytest
 import os
 from text_generation_server.pb import generate_pb2

-os.environ["USE_PREFIX_CACHING"] = "1"
-os.environ["ATTENTION"] = "flashinfer"

 @pytest.fixture
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(

View File

@@ -149,26 +149,11 @@ class FlashCausalLMBatch(Batch):
     max_seqlen: int

-    # Prefill metadata tensors
+    # Prefill metadata tensors to efficiently compute logprobs
     prefill_head_indices: Optional[torch.Tensor]
-    prefill_next_token_indices: Optional[torch.Tensor]
+    prefill_next_token_indices: Optional[torch.tensor]
     prefill_cu_outlens: Optional[List[int]]
-    # Whether at least one request is prefilling/chunking
-    # == any(prefilling_mask)
-    prefilling: bool
-    # For each request, whether they are still prefilling/chunking
-    prefilling_mask: List[bool]
-    # For each request, whether the model output should be used or discarded
-    # If we are chunking, we don't care about the output as it might be different
-    # from the token in the prompt
-    use_output_token: List[bool]
-    # If the request is decoding, `next_chunk_length = 1`
-    # `None if not batch.prefilling`
-    next_chunk_lengths: Optional[List[int]]
-    next_chunk_lengths_tensor: Optional[torch.Tensor]

     # Prefixes
     prefix_ids: List[List[int]]
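For reference, the three surviving prefill metadata fields describe how the flattened prefill output maps back to individual requests. A minimal sketch for a hypothetical batch of two requests with lengths 3 and 2, where only the first one asked for prefill logprobs (the no-logprobs branch is not shown in these hunks, so its keep-only-the-last-row behaviour is an assumption):

import torch

# Request 0: length 3, wants prefill logprobs -> keep positions 0, 1, 2
# Request 1: length 2, no prefill logprobs    -> keep only its last position (4)
prefill_head_indices = torch.tensor([0, 1, 2, 4])

# Index of each request's next-token logits inside the kept rows.
prefill_next_token_indices = torch.tensor([2, 3])

# Cumulative number of kept rows per request: 3 for request 0, 1 for request 1.
prefill_cu_outlens = [0, 3, 4]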
@@ -247,14 +232,11 @@ class FlashCausalLMBatch(Batch):
         prefix_ids = []
         requests_idx_mapping = {}

-        chunking = False
         all_prefill_logprobs = True
         no_prefill_logprobs = True
         prefill_head_indices = []
         prefill_next_token_indices = []
         prefill_cu_outlens = [0]
-        next_chunk_lengths = []
-        use_output_token = []

         next_token_chooser_parameters = []
         stopping_criterias = []
@@ -294,7 +276,6 @@ class FlashCausalLMBatch(Batch):
                 assert prefix_len > 0
                 prefix_len -= 1

             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_len])
@@ -303,18 +284,9 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

-            if True:
-                # This request only requires one prefill and no chunking
-                use_output_token.append(True)
-                next_chunk_lengths.append(1)
-            else:
-                chunking = True
-                raise NotImplementedError

             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)

-            # FIXME: use all input tokens not just postfix ones
             all_input_ids.append(tokenized_input)

             # Position ids
@@ -385,7 +357,6 @@ class FlashCausalLMBatch(Batch):
             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
-                raise NotImplementedError
                 request_prefill_cache_indices = torch.arange(
                     cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
@@ -397,7 +368,6 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

             if r.prefill_logprobs:
-                raise NotImplementedError
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
@@ -475,12 +445,6 @@ class FlashCausalLMBatch(Batch):
             adapter_segments, dtype=torch.int32, device=device
         )

-        if chunking:
-            next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
-        else:
-            next_chunk_lengths = None
-            next_chunk_lengths_tensor = None

         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
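When every request in the batch asks for prefill logprobs there is nothing to gather: the whole flattened output is needed anyway, so the head indices can stay `None` and only the rows holding each request's next-token logits need explicit indices. A minimal sketch of that fast path:

import torch

cu_seqlen_prefill = torch.tensor([0, 3, 5])   # two requests, lengths 3 and 2
all_prefill_logprobs = True

if all_prefill_logprobs:
    # None means "keep the full flattened output", no gathering required.
    prefill_head_indices = None
    # Last position of each request: rows 2 and 4.
    prefill_next_token_indices = cu_seqlen_prefill[1:] - 1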
@@ -527,11 +491,6 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            prefilling=True,
-            prefilling_mask=[True] * pb.requests.len(),
-            use_output_token=use_output_token,
-            next_chunk_lengths=next_chunk_lengths,
-            next_chunk_lengths_tensor=next_chunk_lengths_tensor,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
@@ -1467,7 +1426,7 @@ class FlashCausalLM(Model):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

-        if not batch.prefilling and self.max_past() is not None:
+        if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
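The hunk is cut off before the body of this branch; the usual correction is a simple clamp of `max_s` to the sliding-window length. A minimal sketch, assuming `max_past` is that window size:

# Hypothetical values: sliding window of 4096, longest sequence in the batch 5000.
max_past = 4096
max_s = 5000

cu_seqlen_prefill = None  # decode step
if cu_seqlen_prefill is None and max_past is not None:
    # The KV cache is a circular buffer of size max_past, so at most max_past
    # cached positions can be attended to during decode.
    max_s = min(max_s, max_past)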
@@ -1481,7 +1440,7 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None

-        if batch.prefilling or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1516,7 +1475,6 @@ class FlashCausalLM(Model):
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:
-                raise NotImplementedError
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
@@ -1570,6 +1528,7 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
+        prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)
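With the dedicated `prefilling` flag gone, the batch counts as being in the prefill phase exactly while `cu_seqlen_prefill` (the cumulative sequence boundaries of the prefill tokens) is set; a later hunk clears it right after the prefill forward pass. A toy sketch:

import torch

# Prefill step for two requests of lengths 3 and 2: boundaries [0, 3, 5].
cu_seqlen_prefill = torch.tensor([0, 3, 5], dtype=torch.int32)
prefill = cu_seqlen_prefill is not None   # True: prefill step

# After the prefill forward pass the batch clears the field
# (batch.cu_seqlen_prefill = None), so the next iteration takes the decode path.
cu_seqlen_prefill = None
prefill = cu_seqlen_prefill is not None   # False: decode step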
@@ -1595,13 +1554,13 @@ class FlashCausalLM(Model):
         adapter_data = AdapterBatchData.from_meta(
             adapter_meta,
             self.layer_to_adapter_weights,
-            batch.prefilling,
+            prefill,
             batch.prefill_head_indices,
         )

         out, speculative_logits = self.forward(batch, adapter_data)
-        if batch.prefilling:
+        if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
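During prefill with logprobs enabled, `out` holds one row of logits per kept prefill position across the whole ragged batch, and the next-token logits of each request are simply the rows at its last position. A minimal sketch with a hypothetical vocabulary of 10 (the `prefill_logprobs` case, where every position's logits were kept):

import torch

vocab_size = 10
cu_seqlen_prefill = torch.tensor([0, 3, 5])      # two requests, lengths 3 and 2
out = torch.randn(5, vocab_size)                 # one logit row per prefill position

# Rows 2 and 4 are the last positions of request 0 and request 1.
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
next_token_logits = out[prefill_next_token_indices]   # shape (2, vocab_size)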
@@ -1638,31 +1597,22 @@ class FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

-        if batch.prefilling:
+        if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            if batch.next_chunk_lengths is None:
-                # We are done prefilling after this forward
-                next_position_ids = batch.position_ids.new_empty(len(batch))
-                # [BATCH_SIZE]
-                # Last slot for each request, will be incremented later
-                batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            else:
-                # We still have prefill chunks to go through
-                next_forward_size = sum(batch.next_chunk_lengths)
-                next_position_ids = batch.position_ids.new_empty(next_forward_size)
-                batch.slot_indices = batch.slot_indices.new_empty(next_forward_size)
-                batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
+            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
+            # We do not need cu_seqlen_prefill anymore
+            batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids

         # Cumulative length
         cumulative_length = 0
-        cumulative_chunk_lengths = 0

         # Results
         generations: List[Generation] = []
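With chunking gone, a prefill always finishes in one forward pass, so every request collapses to a single position afterwards: the next position id and the KV-cache slot index are both taken from its last prefill token, and both are later advanced by the number of accepted tokens. A toy sketch of the boundary gathering, with hypothetical slot values:

import torch

cu_seqlen_prefill = torch.tensor([0, 3, 5])        # two requests, lengths 3 and 2
position_ids = torch.tensor([0, 1, 2, 0, 1])       # flattened prefill positions
slot_indices = torch.tensor([10, 11, 12, 40, 41])  # hypothetical KV-cache slots

# Keep only the entry belonging to each request's last prefill token.
last = cu_seqlen_prefill[1:] - 1                   # tensor([2, 4])
next_position_ids = position_ids[last]             # tensor([2, 1])
slot_indices = slot_indices[last]                  # tensor([12, 41])
# Later both advance by the accepted-token count, matching
# batch.position_ids = next_position_ids + accepted_ids and slot_indices += accepted_ids.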
@@ -1675,32 +1625,21 @@ class FlashCausalLM(Model):
         # one, we need to first do a GPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time

-        index = 0
         # For each member of the batch
+        index = 0
         for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

-            if batch.prefilling:
-                if batch.next_chunk_lengths is not None:
-                    next_chunk_length = batch.next_chunk_lengths[i]
-                else:
-                    next_chunk_length = 1
+            if prefill:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
                 out_length = out_end_index - out_start_index

-                position_start_index =
                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
-                next_position_ids
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Initialize adapter indices
@@ -1712,7 +1651,6 @@ class FlashCausalLM(Model):
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
-                    raise NotImplementedError
                     if len(batch) > 1:
                         prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                             batch.input_ids[start_index + 1 : start_index + out_length]
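The assignment is cut off here, but the pattern is the usual shift by one: the logit row at position t scores token t + 1, so the target ids gathered for prefill logprobs are the request's input ids shifted left by one. A minimal sketch for a single toy request:

import torch

# Toy request: tokens [5, 8, 9] occupy rows 0..2 of the flattened output.
input_ids = torch.tensor([5, 8, 9])
out_start_index, out_end_index = 0, 3
start_index, out_length = 0, 3

prefill_tokens_indices = torch.zeros(3, dtype=torch.long)
# Targets are the input ids shifted left by one.
prefill_tokens_indices[out_start_index : out_end_index - 1] = input_ids[
    start_index + 1 : start_index + out_length
]
# prefill_tokens_indices -> tensor([8, 9, 0]); the trailing slot is handled separately.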
@@ -1730,15 +1668,6 @@ class FlashCausalLM(Model):
             cumulative_length += input_length

         # Update values
-        if batch.next_prefilling_chunk_lengths is None:
-            # We are done prefilling
-            batch.prefilling = False
-            batch.next_prefilling_chunk_lengths = None
-            batch.next_prefilling_chunk_lengths_tensor = None
-            # We do not need cu_seqlen_prefill anymore
-            batch.cu_seqlen_prefill = None
-        if not batch.prefilling:
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
@@ -1746,7 +1675,7 @@ class FlashCausalLM(Model):
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices

-        if batch.prefilling:
+        if prefill:
             # adjust segment lengths to account for all request lengths being 1 during decoding
             adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
             batch.adapter_meta.adapter_segments = torch.tensor(
@@ -1755,8 +1684,7 @@ class FlashCausalLM(Model):
                 device=batch.adapter_meta.adapter_segments.device,
             )

-        if batch.prefilling and prefill_logprobs:
-            raise NotImplementedError
+        if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(
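The `torch.gather` call is truncated in this hunk; a minimal sketch of the full pattern, log-softmax over the vocabulary followed by picking the log-probability of each target id (the index values are hypothetical):

import torch

vocab_size = 10
out = torch.randn(5, vocab_size)                        # logits for 5 prefill positions
prefill_tokens_indices = torch.tensor([8, 9, 0, 3, 0])  # target id for each position

prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
    prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
)
# shape (5, 1): one log-probability per prefill position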
@@ -1867,8 +1795,7 @@ class FlashCausalLM(Model):
                 generated_text = None

             # Prefill
-            if batch.prefilling and request.prefill_logprobs:
-                raise NotImplementedError
+            if prefill and request.prefill_logprobs:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
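Per request, `prefill_cu_outlens` marks the span of that request's rows inside the flattened prefill-logprobs tensor, so extracting its logprobs is a plain slice. A toy sketch with hypothetical values:

import torch

prefill_cu_outlens = [0, 3, 4]    # request 0 kept 3 rows, request 1 kept 1 row
prefill_logprobs = torch.tensor([-0.1, -2.3, -0.7, -1.5])  # one value per kept row

i = 0                             # first request in the batch
out_start_index = prefill_cu_outlens[i]
out_end_index = prefill_cu_outlens[i + 1]
request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index]
# tensor([-0.1000, -2.3000, -0.7000])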

View File

@@ -5,9 +5,9 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-ATTENTION = os.getenv("ATTENTION")
+ATTENTION = os.getenv("ATTENTION", "flashinfer")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
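Together with the router change above and the removed test environment exports, an unset environment now resolves to prefix caching enabled and flashinfer attention on both the Rust and Python sides. Overriding still works, but since these module-level values are evaluated once at import time, the variables have to be set before the server modules are imported. A minimal sketch (the exact module path is an assumption based on the imports shown above):

import os

# Hypothetical override: must happen before text_generation_server modules are
# imported, because PREFIX_CACHING and ATTENTION are read at import time.
os.environ["USE_PREFIX_CACHING"] = "0"
os.environ["ATTENTION"] = "flashdecoding"

from text_generation_server.models import globals as tgi_globals  # assumed path

assert tgi_globals.ATTENTION == "flashdecoding"
assert tgi_globals.PREFIX_CACHING is False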