mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
rollback
This commit is contained in:
parent
7169cbae6d
commit
de043b53c4
@ -36,9 +36,9 @@ impl BackendV3 {
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let prefix_caching =
|
||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
||||
std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
|
||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
||||
let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
|
||||
|
||||
let attention: Attention = attention
|
||||
.parse()
|
||||
|
@ -2,10 +2,6 @@ import pytest
|
||||
import os
|
||||
from text_generation_server.pb import generate_pb2
|
||||
|
||||
os.environ["USE_PREFIX_CACHING"] = "1"
|
||||
os.environ["ATTENTION"] = "flashinfer"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_parameters():
|
||||
return generate_pb2.NextTokenChooserParameters(
|
||||
|
@ -149,26 +149,11 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
max_seqlen: int
|
||||
|
||||
# Prefill metadata tensors
|
||||
# Prefill metadata tensors to efficiently compute logprobs
|
||||
prefill_head_indices: Optional[torch.Tensor]
|
||||
prefill_next_token_indices: Optional[torch.Tensor]
|
||||
prefill_next_token_indices: Optional[torch.tensor]
|
||||
prefill_cu_outlens: Optional[List[int]]
|
||||
|
||||
# Whether at least one request is prefilling/chunking
|
||||
# == any(prefilling_mask)
|
||||
prefilling: bool
|
||||
# For each request, whether they are still prefilling/chunking
|
||||
prefilling_mask: List[bool]
|
||||
# For each request, whether the model output should be used or discarded
|
||||
# If we are chunking, we don't care about the output as it might be different
|
||||
# from the token in the prompt
|
||||
use_output_token: List[bool]
|
||||
|
||||
# If the request is decoding, `next_chunk_length = 1`
|
||||
# `None if not batch.prefilling`
|
||||
next_chunk_lengths: Optional[List[int]]
|
||||
next_chunk_lengths_tensor: Optional[torch.Tensor]
|
||||
|
||||
# Prefixes
|
||||
prefix_ids: List[List[int]]
|
||||
|
||||
@ -247,14 +232,11 @@ class FlashCausalLMBatch(Batch):
|
||||
prefix_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
chunking = False
|
||||
all_prefill_logprobs = True
|
||||
no_prefill_logprobs = True
|
||||
prefill_head_indices = []
|
||||
prefill_next_token_indices = []
|
||||
prefill_cu_outlens = [0]
|
||||
next_chunk_lengths = []
|
||||
use_output_token = []
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
@ -294,7 +276,6 @@ class FlashCausalLMBatch(Batch):
|
||||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
|
||||
|
||||
# Commented as it's costly.
|
||||
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
|
||||
prefix_ids.append(tokenized_input[:prefix_len])
|
||||
@ -303,18 +284,9 @@ class FlashCausalLMBatch(Batch):
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
||||
if True:
|
||||
# This request only requires one prefill and no chunking
|
||||
use_output_token.append(True)
|
||||
next_chunk_lengths.append(1)
|
||||
else:
|
||||
chunking = True
|
||||
raise NotImplementedError
|
||||
|
||||
prefix_offsets.append(input_length - 5)
|
||||
read_offsets.append(input_length)
|
||||
|
||||
# FIXME: use all input tokens not just postfix ones
|
||||
all_input_ids.append(tokenized_input)
|
||||
|
||||
# Position ids
|
||||
@ -385,7 +357,6 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - sliding_window),
|
||||
cumulative_length + input_length,
|
||||
@ -397,7 +368,6 @@ class FlashCausalLMBatch(Batch):
|
||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||
|
||||
if r.prefill_logprobs:
|
||||
raise NotImplementedError
|
||||
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||
prefill_next_token_indices.append(
|
||||
prefill_out_cumulative_length + input_length - 1
|
||||
@ -475,12 +445,6 @@ class FlashCausalLMBatch(Batch):
|
||||
adapter_segments, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if chunking:
|
||||
next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
|
||||
else:
|
||||
next_chunk_lengths = None
|
||||
next_chunk_lengths_tensor = None
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
@ -527,11 +491,6 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
prefill_cu_outlens=prefill_cu_outlens,
|
||||
prefilling=True,
|
||||
prefilling_mask=[True] * pb.requests.len(),
|
||||
use_output_token=use_output_token,
|
||||
next_chunk_lengths=next_chunk_lengths,
|
||||
next_chunk_lengths_tensor=next_chunk_lengths_tensor,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
prefix_offsets=prefix_offsets,
|
||||
@ -1467,7 +1426,7 @@ class FlashCausalLM(Model):
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if not batch.prefilling and self.max_past() is not None:
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
@ -1481,7 +1440,7 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
cuda_graph = None
|
||||
|
||||
if batch.prefilling or cuda_graph is None:
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
@ -1516,7 +1475,6 @@ class FlashCausalLM(Model):
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
raise NotImplementedError
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, speculative_logits
|
||||
|
||||
@ -1570,6 +1528,7 @@ class FlashCausalLM(Model):
|
||||
self, batch: FlashCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
prefill = batch.cu_seqlen_prefill is not None
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
|
||||
# Update adapter indices for speculative tokens (if present)
|
||||
@ -1595,13 +1554,13 @@ class FlashCausalLM(Model):
|
||||
adapter_data = AdapterBatchData.from_meta(
|
||||
adapter_meta,
|
||||
self.layer_to_adapter_weights,
|
||||
batch.prefilling,
|
||||
prefill,
|
||||
batch.prefill_head_indices,
|
||||
)
|
||||
|
||||
out, speculative_logits = self.forward(batch, adapter_data)
|
||||
|
||||
if batch.prefilling:
|
||||
if prefill:
|
||||
next_token_logits = (
|
||||
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
||||
)
|
||||
@ -1638,31 +1597,22 @@ class FlashCausalLM(Model):
|
||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
|
||||
)
|
||||
|
||||
if batch.prefilling:
|
||||
if prefill:
|
||||
if len(batch) > 1 and prefill_logprobs:
|
||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||
# When batch == 1, we will just use the batch.input_ids values directly
|
||||
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
||||
|
||||
if batch.next_chunk_lengths is None:
|
||||
# We are done prefilling after this forward
|
||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||
# [BATCH_SIZE]
|
||||
# Last slot for each request, will be incremented later
|
||||
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
|
||||
else:
|
||||
# We still have prefill chunks to go through
|
||||
next_forward_size = sum(batch.next_chunk_lengths)
|
||||
next_position_ids = batch.position_ids.new_empty(next_forward_size)
|
||||
batch.slot_indices = batch.slot_indices.new_empty(next_forward_size)
|
||||
batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)
|
||||
# We do not need cu_seqlen_prefill anymore
|
||||
batch.cu_seqlen_prefill = None
|
||||
else:
|
||||
prefill_logprobs = None
|
||||
next_position_ids = batch.position_ids
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_chunk_lengths = 0
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
@ -1675,32 +1625,21 @@ class FlashCausalLM(Model):
|
||||
# one, we need to first do a GPU <-> CPU sync
|
||||
# It is faster if we delay this sync for the maximum amount of time
|
||||
|
||||
index = 0
|
||||
# For each member of the batch
|
||||
index = 0
|
||||
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
|
||||
# Indexing metadata
|
||||
start_index = cumulative_length
|
||||
end_index = cumulative_length + input_length
|
||||
|
||||
if batch.prefilling:
|
||||
if batch.next_chunk_lengths is not None:
|
||||
next_chunk_length = batch.next_chunk_lengths[i]
|
||||
else:
|
||||
next_chunk_length = 1
|
||||
|
||||
if prefill:
|
||||
# Indexing metadata
|
||||
out_start_index = batch.prefill_cu_outlens[i]
|
||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||
out_length = out_end_index - out_start_index
|
||||
|
||||
position_start_index =
|
||||
|
||||
# Initialize position_ids
|
||||
# In decode, we do not need this as we can just increment position ids
|
||||
|
||||
|
||||
next_position_ids
|
||||
|
||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||
|
||||
# Initialize adapter indices
|
||||
@ -1712,7 +1651,6 @@ class FlashCausalLM(Model):
|
||||
# Used to gather prefill logprobs
|
||||
# Copy batch.input_ids to prefill_token_indices
|
||||
if prefill_logprobs:
|
||||
raise NotImplementedError
|
||||
if len(batch) > 1:
|
||||
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
|
||||
batch.input_ids[start_index + 1 : start_index + out_length]
|
||||
@ -1730,15 +1668,6 @@ class FlashCausalLM(Model):
|
||||
cumulative_length += input_length
|
||||
|
||||
# Update values
|
||||
if batch.next_prefilling_chunk_lengths is None:
|
||||
# We are done prefilling
|
||||
batch.prefilling = False
|
||||
batch.next_prefilling_chunk_lengths = None
|
||||
batch.next_prefilling_chunk_lengths_tensor = None
|
||||
# We do not need cu_seqlen_prefill anymore
|
||||
batch.cu_seqlen_prefill = None
|
||||
|
||||
if not batch.prefilling:
|
||||
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||
batch.speculative_ids = speculative_ids
|
||||
batch.position_ids = next_position_ids + accepted_ids
|
||||
@ -1746,7 +1675,7 @@ class FlashCausalLM(Model):
|
||||
batch.slot_indices += accepted_ids
|
||||
batch.adapter_meta.adapter_indices = next_adapter_indices
|
||||
|
||||
if batch.prefilling:
|
||||
if prefill:
|
||||
# adjust segment lengths to account for all request lengths being 1 during decoding
|
||||
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
|
||||
batch.adapter_meta.adapter_segments = torch.tensor(
|
||||
@ -1755,8 +1684,7 @@ class FlashCausalLM(Model):
|
||||
device=batch.adapter_meta.adapter_segments.device,
|
||||
)
|
||||
|
||||
if batch.prefilling and prefill_logprobs:
|
||||
raise NotImplementedError
|
||||
if prefill and prefill_logprobs:
|
||||
# Get prefill logprobs
|
||||
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
||||
prefill_logprobs = torch.gather(
|
||||
@ -1867,8 +1795,7 @@ class FlashCausalLM(Model):
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if batch.prefilling and request.prefill_logprobs:
|
||||
raise NotImplementedError
|
||||
if prefill and request.prefill_logprobs:
|
||||
out_start_index = batch.prefill_cu_outlens[i]
|
||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||
|
||||
|
@ -5,9 +5,9 @@ from typing import Dict, Optional
|
||||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
|
||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
|
||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||
ATTENTION = os.getenv("ATTENTION")
|
||||
ATTENTION = os.getenv("ATTENTION", "flashinfer")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
|
Loading…
Reference in New Issue
Block a user