Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

Commit de043b53c4 ("rollback"), parent 7169cbae6d
@@ -36,9 +36,9 @@ impl BackendV3 {
         speculate: u32,
     ) -> Self {
         let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
+            std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
         let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
+        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());

         let attention: Attention = attention
             .parse()
@@ -2,10 +2,6 @@ import pytest
 import os
 from text_generation_server.pb import generate_pb2

-os.environ["USE_PREFIX_CACHING"] = "1"
-os.environ["ATTENTION"] = "flashinfer"
-
-
 @pytest.fixture
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
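Note (not part of the commit): with the `os.environ` pins removed from the test fixtures, the suite relies on the defaults now baked into the server (`USE_PREFIX_CACHING=1`, `ATTENTION=flashinfer`). A hypothetical sketch of how a single test could still override them, assuming pytest's standard `monkeypatch` fixture; a module that reads the variable at import time (as the globals module further down does) will not see a per-test override:

    import os

    def test_attention_env_override(monkeypatch):
        # monkeypatch restores the environment after the test finishes.
        monkeypatch.setenv("ATTENTION", "paged")
        monkeypatch.setenv("USE_PREFIX_CACHING", "0")
        assert os.environ["ATTENTION"] == "paged"
        assert os.environ["USE_PREFIX_CACHING"] == "0"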
@@ -149,26 +149,11 @@ class FlashCausalLMBatch(Batch):

     max_seqlen: int

-    # Prefill metadata tensors
+    # Prefill metadata tensors to efficiently compute logprobs
     prefill_head_indices: Optional[torch.Tensor]
-    prefill_next_token_indices: Optional[torch.Tensor]
+    prefill_next_token_indices: Optional[torch.tensor]
     prefill_cu_outlens: Optional[List[int]]

-    # Whether at least one request is prefilling/chunking
-    # == any(prefilling_mask)
-    prefilling: bool
-    # For each request, whether they are still prefilling/chunking
-    prefilling_mask: List[bool]
-    # For each request, whether the model output should be used or discarded
-    # If we are chunking, we don't care about the output as it might be different
-    # from the token in the prompt
-    use_output_token: List[bool]
-
-    # If the request is decoding, `next_chunk_length = 1`
-    # `None if not batch.prefilling`
-    next_chunk_lengths: Optional[List[int]]
-    next_chunk_lengths_tensor: Optional[torch.Tensor]
-
     # Prefixes
     prefix_ids: List[List[int]]

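Aside (illustrative, not from the repo): `prefill_cu_outlens` is a cumulative-offset list over per-request prefill output lengths (it is seeded with `[0]` in `from_tokenized` below), so request `i`'s slice of the flattened prefill output is `[cu[i], cu[i + 1])`. A minimal sketch of that bookkeeping:

    from typing import List

    def cumulative_outlens(out_lengths: List[int]) -> List[int]:
        # Build [0, l0, l0 + l1, ...]; request i's outputs live at [cu[i], cu[i + 1]).
        cu = [0]
        for length in out_lengths:
            cu.append(cu[-1] + length)
        return cu

    assert cumulative_outlens([3, 1, 4]) == [0, 3, 4, 8]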
@@ -247,14 +232,11 @@ class FlashCausalLMBatch(Batch):
         prefix_ids = []
         requests_idx_mapping = {}

-        chunking = False
         all_prefill_logprobs = True
         no_prefill_logprobs = True
         prefill_head_indices = []
         prefill_next_token_indices = []
         prefill_cu_outlens = [0]
-        next_chunk_lengths = []
-        use_output_token = []

         next_token_chooser_parameters = []
         stopping_criterias = []
@@ -294,7 +276,6 @@ class FlashCausalLMBatch(Batch):
                 assert prefix_len > 0
                 prefix_len -= 1

-
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_len])
@@ -303,18 +284,9 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

-            if True:
-                # This request only requires one prefill and no chunking
-                use_output_token.append(True)
-                next_chunk_lengths.append(1)
-            else:
-                chunking = True
-                raise NotImplementedError
-
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)

-            # FIXME: use all input tokens not just postfix ones
             all_input_ids.append(tokenized_input)

             # Position ids
@@ -385,7 +357,6 @@ class FlashCausalLMBatch(Batch):

             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
-                raise NotImplementedError
                 request_prefill_cache_indices = torch.arange(
                     cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
@@ -397,7 +368,6 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

             if r.prefill_logprobs:
-                raise NotImplementedError
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
@@ -475,12 +445,6 @@ class FlashCausalLMBatch(Batch):
             adapter_segments, dtype=torch.int32, device=device
         )

-        if chunking:
-            next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
-        else:
-            next_chunk_lengths = None
-            next_chunk_lengths_tensor = None
-
         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
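Aside (a minimal sketch, assuming `torch`; the pad-based construction is only for illustration): when every request wants prefill logprobs, `prefill_head_indices` can stay `None` and the next-token indices are simply the last position of each sequence in the flattened batch, which is exactly what `cu_seqlen_prefill[1:] - 1` computes:

    import torch
    import torch.nn.functional as F

    # Example: three prompts of lengths 4, 2 and 3 flattened into one dimension.
    lengths = torch.tensor([4, 2, 3])
    cu_seqlen_prefill = F.pad(lengths.cumsum(0), (1, 0))  # tensor([0, 4, 6, 9])
    last_token_indices = cu_seqlen_prefill[1:] - 1        # tensor([3, 5, 8])
    assert last_token_indices.tolist() == [3, 5, 8]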
@@ -527,11 +491,6 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            prefilling=True,
-            prefilling_mask=[True] * pb.requests.len(),
-            use_output_token=use_output_token,
-            next_chunk_lengths=next_chunk_lengths,
-            next_chunk_lengths_tensor=next_chunk_lengths_tensor,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
@@ -1467,7 +1426,7 @@ class FlashCausalLM(Model):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

-        if not batch.prefilling and self.max_past() is not None:
+        if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
@@ -1481,7 +1440,7 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None

-        if batch.prefilling or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1516,7 +1475,6 @@ class FlashCausalLM(Model):
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:
-                raise NotImplementedError
                 batch.prefill_cache_indices = None
             return logits, speculative_logits

@@ -1570,6 +1528,7 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
+        prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)
@@ -1595,13 +1554,13 @@ class FlashCausalLM(Model):
         adapter_data = AdapterBatchData.from_meta(
             adapter_meta,
             self.layer_to_adapter_weights,
-            batch.prefilling,
+            prefill,
             batch.prefill_head_indices,
         )

         out, speculative_logits = self.forward(batch, adapter_data)

-        if batch.prefilling:
+        if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
@@ -1638,31 +1597,22 @@ class FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

-        if batch.prefilling:
+        if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            if batch.next_chunk_lengths is None:
-                # We are done prefilling after this forward
-                next_position_ids = batch.position_ids.new_empty(len(batch))
-                # [BATCH_SIZE]
-                # Last slot for each request, will be incremented later
-                batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            else:
-                # We still have prefill chunks to go through
-                next_forward_size = sum(batch.next_chunk_lengths)
-                next_position_ids = batch.position_ids.new_empty(next_forward_size)
-                batch.slot_indices = batch.slot_indices.new_empty(next_forward_size)
-                batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
+            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
+            # We do not need cu_seqlen_prefill anymore
+            batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids

         # Cumulative length
         cumulative_length = 0
-        cumulative_chunk_lengths = 0

         # Results
         generations: List[Generation] = []
@@ -1675,32 +1625,21 @@ class FlashCausalLM(Model):
         # one, we need to first do a GPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time

-        index = 0
         # For each member of the batch
+        index = 0
         for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

-            if batch.prefilling:
-                if batch.next_chunk_lengths is not None:
-                    next_chunk_length = batch.next_chunk_lengths[i]
-                else:
-                    next_chunk_length = 1
-
+            if prefill:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
                 out_length = out_end_index - out_start_index

-                position_start_index =
-
                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
-
-
-                next_position_ids
-
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Initialize adapter indices
@@ -1712,7 +1651,6 @@ class FlashCausalLM(Model):
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
-                    raise NotImplementedError
                     if len(batch) > 1:
                         prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                             batch.input_ids[start_index + 1 : start_index + out_length]
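Aside (illustrative, not from the commit): the slice copy above is the usual shift-by-one alignment for prompt logprobs: the logits at position `j` score the token at position `j + 1`, so the targets gathered for output positions `[out_start_index, out_end_index - 1)` are the input ids shifted left by one. A small sketch of that alignment for a single request, assuming `torch`:

    import torch

    # Toy setup: vocabulary of size 5, one prompt of 4 token ids.
    prompt = torch.tensor([2, 3, 1, 4])
    logits = torch.randn(len(prompt), 5)      # one row of logits per prompt position
    logprobs = torch.log_softmax(logits, dim=-1)

    # Row j predicts token j + 1: drop the last row and shift the targets.
    targets = prompt[1:]
    prompt_logprobs = torch.gather(logprobs[:-1], 1, targets.unsqueeze(1)).squeeze(1)
    assert prompt_logprobs.shape == (len(prompt) - 1,)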
@@ -1730,23 +1668,14 @@ class FlashCausalLM(Model):
             cumulative_length += input_length

         # Update values
-        if batch.next_prefilling_chunk_lengths is None:
-            # We are done prefilling
-            batch.prefilling = False
-            batch.next_prefilling_chunk_lengths = None
-            batch.next_prefilling_chunk_lengths_tensor = None
-            # We do not need cu_seqlen_prefill anymore
-            batch.cu_seqlen_prefill = None
+        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+        batch.speculative_ids = speculative_ids
+        batch.position_ids = next_position_ids + accepted_ids
+        batch.input_lengths_tensor += accepted_ids
+        batch.slot_indices += accepted_ids
+        batch.adapter_meta.adapter_indices = next_adapter_indices

-        if not batch.prefilling:
-            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
-            batch.speculative_ids = speculative_ids
-            batch.position_ids = next_position_ids + accepted_ids
-            batch.input_lengths_tensor += accepted_ids
-            batch.slot_indices += accepted_ids
-            batch.adapter_meta.adapter_indices = next_adapter_indices
-
-        if batch.prefilling:
+        if prefill:
             # adjust segment lengths to account for all request lengths being 1 during decoding
             adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
             batch.adapter_meta.adapter_segments = torch.tensor(
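Aside (a minimal sketch, assuming `torch`): with speculative decoding, `next_input_ids` is flattened across the tokens accepted per request, so `accepted_ids.cumsum(dim=-1) - 1` indexes the last accepted token of each request, which becomes that request's next input id:

    import torch

    # Two requests: request 0 accepted 3 tokens, request 1 accepted 1 token.
    accepted_ids = torch.tensor([3, 1])
    next_input_ids = torch.tensor([11, 12, 13, 21])  # flattened accepted tokens

    last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
    assert last_accepted.tolist() == [13, 21]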
@@ -1755,8 +1684,7 @@ class FlashCausalLM(Model):
                 device=batch.adapter_meta.adapter_segments.device,
             )

-        if batch.prefilling and prefill_logprobs:
-            raise NotImplementedError
+        if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(
@@ -1867,8 +1795,7 @@ class FlashCausalLM(Model):
             generated_text = None

             # Prefill
-            if batch.prefilling and request.prefill_logprobs:
-                raise NotImplementedError
+            if prefill and request.prefill_logprobs:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]

@@ -5,9 +5,9 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-ATTENTION = os.getenv("ATTENTION")
+ATTENTION = os.getenv("ATTENTION", "flashinfer")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
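Aside (a minimal standalone sketch of the pattern above; the assertion message is illustrative): default both environment variables, normalize the prefix-caching flag, and fail fast on an unknown attention backend:

    import os

    PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
    ATTENTION = os.getenv("ATTENTION", "flashinfer")

    _expected = {"paged", "flashdecoding", "flashinfer"}
    assert ATTENTION in _expected, f"Unexpected ATTENTION value {ATTENTION!r}, expected one of {_expected}"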
|