feat(server): only compute prefill logprobs when asked

OlivierDehaene 2023-06-02 15:30:35 +02:00
parent e7248fe90e
commit 1dd0cf63df
19 changed files with 199 additions and 54 deletions

View File

@@ -136,6 +136,7 @@ async fn prefill(
     let requests = (0..batch_size)
         .map(|id| Request {
             id: id.into(),
+            prefill_logprobs: false,
             inputs: sequence.clone(),
             truncate: sequence_length,
             parameters: Some(parameters.clone()),

View File

@@ -107,6 +107,40 @@ print(text)
 ### Types

 ```python
+# Request Parameters
+class Parameters:
+    # Activate logits sampling
+    do_sample: bool
+    # Maximum number of generated tokens
+    max_new_tokens: int
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float]
+    # Whether to prepend the prompt to the generated text
+    return_full_text: bool
+    # Stop generating tokens if a member of `stop_sequences` is generated
+    stop: List[str]
+    # Random sampling seed
+    seed: Optional[int]
+    # The value used to modulate the logits distribution.
+    temperature: Optional[float]
+    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    top_k: Optional[int]
+    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+    # higher are kept for generation.
+    top_p: Optional[float]
+    # Truncate input tokens to the given size
+    truncate: Optional[int]
+    # Typical Decoding mass
+    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+    typical_p: Optional[float]
+    # Generate best_of sequences and return the one with the highest token logprobs
+    best_of: Optional[int]
+    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+    watermark: bool
+    # Get prompt token logprobs and ids
+    prefill_details: bool
+
 # Prompt tokens
 class PrefillToken:
     # Token ID from the model tokenizer
@@ -151,8 +185,8 @@ class BestOfSequence:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
-    prefill: List[PrefillToken]
+    # Prompt tokens, empty if prefill_details is False
+    prefill: Optional[List[PrefillToken]]
     # Generated tokens
     tokens: List[Token]
@@ -165,7 +199,7 @@ class Details:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
+    # Prompt tokens, empty if prefill_details is False
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
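For reference, a minimal usage sketch of the new `prefill_details` flag with the Python client; the server URL and prompt below are placeholders, and the sketch assumes a text-generation-inference server built from this commit is running:

```python
from text_generation import Client

# Placeholder URL: point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Ask the server to also return prompt (prefill) token ids and logprobs.
response = client.generate(
    "The capital of France is",
    max_new_tokens=4,
    prefill_details=True,
)

print(response.generated_text)
# details.prefill is only populated because prefill_details=True was passed.
for token in response.details.prefill:
    print(token.id, repr(token.text), token.logprob)
```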

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.5.2"
+version = "0.6.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -7,7 +7,7 @@ from text_generation.types import FinishReason, PrefillToken, Token
 def test_generate(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", max_new_tokens=1)
+    response = client.generate("test", max_new_tokens=1, prefill_details=True)
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
@@ -17,13 +17,15 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == ""
+    assert response.details.tokens[0].text == " "
     assert not response.details.tokens[0].special
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
+    response = client.generate(
+        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
+    )
     assert response.details.seed is not None
     assert response.details.best_of_sequences is not None
@@ -73,7 +75,7 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
 async def test_generate_async(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
-    response = await client.generate("test", max_new_tokens=1)
+    response = await client.generate("test", max_new_tokens=1, prefill_details=True)
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
@@ -83,7 +85,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
     assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == ""
+    assert response.details.tokens[0].text == " "
     assert not response.details.tokens[0].special
@@ -91,7 +93,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
 async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate(
-        "test", max_new_tokens=1, best_of=2, do_sample=True
+        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
     )
     assert response.details.seed is not None

View File

@@ -74,6 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        prefill_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            prefill_details (`bool`):
+                Return the prefill token log probabilities
         Returns:
             Response: generated response
@@ -130,6 +133,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            prefill_details=prefill_details,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -202,6 +206,7 @@ class Client:
         parameters = Parameters(
             best_of=None,
             details=True,
+            prefill_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -311,6 +316,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        prefill_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -347,6 +353,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            prefill_details (`bool`):
+                Return the prefill token log probabilities
         Returns:
             Response: generated response
@@ -355,6 +363,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=best_of,
             details=True,
+            prefill_details=prefill_details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -437,6 +446,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=None,
             details=True,
+            prefill_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
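The asynchronous client mirrors the same flag; a hedged sketch, with URL and prompt as placeholders:

```python
import asyncio

from text_generation import AsyncClient


async def main():
    # Placeholder URL: point this at a running text-generation-inference server.
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "The capital of France is", max_new_tokens=4, prefill_details=True
    )
    # Prompt token logprobs are only present because prefill_details=True.
    print([(t.id, t.logprob) for t in response.details.prefill])


asyncio.run(main())
```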

View File

@@ -37,6 +37,8 @@ class Parameters(BaseModel):
     watermark: bool = False
     # Get generation details
     details: bool = False
+    # Get prefill details
+    prefill_details: bool = False

     @validator("best_of")
     def valid_best_of(cls, field_value, values):
@@ -173,7 +175,7 @@ class BestOfSequence(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
+    # Prompt tokens, empty if prefill_details is False
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
@@ -187,7 +189,7 @@ class Details(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
+    # Prompt tokens, empty if prefill_details is False
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]

View File

@@ -87,6 +87,8 @@ message Request {
   NextTokenChooserParameters parameters = 4;
   /// Stopping Criteria Parameters
   StoppingCriteriaParameters stopping_parameters = 5;
+  /// Return prefill logprobs
+  bool prefill_logprobs = 6;
 }

 message Batch {
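To illustrate the wire-level change, here is a hedged sketch of building a `Request` from Python stubs generated from this proto. The stub module path and the `max_new_tokens` stopping field are assumptions; the other field names appear in this diff.

```python
# Assumed module path for the protoc-generated stubs of generate.proto.
from text_generation_server.pb import generate_pb2


def make_request(req_id: int, prompt: str, want_prefill_logprobs: bool) -> generate_pb2.Request:
    # Only requests that set prefill_logprobs=True pay for the prompt logprob computation.
    return generate_pb2.Request(
        id=req_id,
        inputs=prompt,
        truncate=1024,
        prefill_logprobs=want_prefill_logprobs,
        parameters=generate_pb2.NextTokenChooserParameters(
            temperature=1.0,
            top_k=0,
        ),
        stopping_parameters=generate_pb2.StoppingCriteriaParameters(
            max_new_tokens=10,  # assumed field name
        ),
    )
```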

View File

@@ -34,6 +34,7 @@ impl Health {
             id: LIVENESS_ID,
             inputs: "liveness".to_string(),
             truncate: 10,
+            prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,

View File

@@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
+    #[schema(default = "true")]
+    pub prefill_details: bool,
+    #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
         nullable = true,
@@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
         truncate: None,
         watermark: false,
         details: false,
+        prefill_details: false,
         seed: None,
     }
 }
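Over HTTP, the new parameter rides along in the usual `parameters` object; a hedged sketch against the router's `/generate` route, with URL and prompt as placeholders:

```python
import requests

# Placeholder URL for a locally running router.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "The capital of France is",
        "parameters": {
            "max_new_tokens": 4,
            "details": True,
            # New flag: also return prompt token ids and logprobs.
            "prefill_details": True,
        },
    },
)
resp.raise_for_status()
print(resp.json()["details"]["prefill"])
```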

View File

@@ -201,6 +201,7 @@ impl State {
             batch_requests.push(Request {
                 id,
+                prefill_logprobs: entry.request.prefill_details,
                 inputs: entry.request.inputs.clone(),
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),

View File

@@ -364,7 +364,17 @@ async fn generate_stream(
         let details = req.0.parameters.details;
         let best_of = req.0.parameters.best_of.unwrap_or(1);
-        if best_of == 1 {
+        if best_of != 1 {
+            let err = InferError::from(ValidationError::BestOfStream);
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            yield Ok(Event::from(err));
+        } else if req.0.parameters.prefill_details {
+            let err = InferError::from(ValidationError::PrefillDetailsStream);
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            yield Ok(Event::from(err));
+        } else {
             match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
@@ -474,11 +484,6 @@ async fn generate_stream(
                     tracing::error!("{err}");
                     yield Ok(Event::from(err));
                 }
-            } else {
-                let err = InferError::from(ValidationError::BestOfStream);
-                metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-                tracing::error!("{err}");
-                yield Ok(Event::from(err));
             }
         };

View File

@@ -145,6 +145,7 @@ impl Validation {
             truncate,
             seed,
             watermark,
+            prefill_details,
             ..
         } = request.parameters;
@@ -261,6 +262,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
+            prefill_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
@@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
     pub truncate: u32,
+    pub prefill_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
@@ -351,6 +354,8 @@ pub enum ValidationError {
     BestOfSeed,
     #[error("`best_of` != 1 is not supported when streaming tokens")]
     BestOfStream,
+    #[error("`prefill_details` == true is not supported when streaming tokens")]
+    PrefillDetailsStream,
     #[error("`temperature` must be strictly positive")]
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]

View File

@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
-            prefix_offsets.append(0)
+            prefix_offsets.append(input_len - 5)
             read_offsets.append(input_len)
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
@@ -617,7 +617,7 @@ class CausalLM(Model):
             generated_text = None
             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
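For context, a standalone sketch of what that now-guarded branch computes: log-softmax the prompt logits, gather the logprob of each actual next prompt token, and prepend NaN for the first token, which has no preceding context. Names and shapes below are illustrative, not the server's internals.

```python
import torch


def prompt_token_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> list:
    # logits: [seq_len, vocab] prompt scores, input_ids: [seq_len] prompt token ids.
    log_probs = torch.log_softmax(logits, dim=-1)
    # Position t predicts token t+1, so rows 0..seq_len-2 score tokens 1..seq_len-1.
    gathered = log_probs[:-1].gather(1, input_ids[1:].unsqueeze(1)).squeeze(1)
    # The first prompt token has no context: report NaN, as the server does.
    return [float("nan")] + gathered.tolist()


# Toy usage with random scores: 5 prompt tokens, vocabulary of 32.
scores = torch.randn(5, 32)
ids = torch.randint(0, 32, (5,))
print(prompt_token_logprobs(scores, ids))
```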

View File

@@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.model(
             input_ids,
@@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
         if self.model.tp_embeddings:
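The same two-line pattern is added to each flash model below. A toy sketch of why it helps: during prefill, hidden states exist for every prompt position, but when no prompt logprobs were requested only the last position of each sequence needs logits, so slicing with `lm_head_indices` shrinks the vocabulary projection. Sizes below are placeholders.

```python
import torch

torch.manual_seed(0)
hidden_size, vocab_size = 16, 100
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Flattened hidden states for a prefill batch of two prompts (lengths 3 and 4).
hidden_states = torch.randn(7, hidden_size)
cu_seqlens = torch.tensor([0, 3, 7])

# Without prompt logprobs, only the last position of each sequence needs logits,
# so the lm_head matmul runs on 2 rows instead of 7.
lm_head_indices = cu_seqlens[1:] - 1
logits = lm_head(hidden_states[lm_head_indices])
print(logits.shape)  # torch.Size([2, 100])
```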

View File

@@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids,
@@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.embed_out(hidden_states)
         if self.gpt_neox.tp_embeddings:

View File

@@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
@@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
         if self.transformer.tp_embeddings:

View File

@@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
@@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
         if self.transformer.tp_embeddings:

View File

@@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
     past_key_values: Optional[torch.Tensor]
     max_seqlen: int

+    # Prefill metadata tensors to efficiently compute logprobs
+    prefill_head_indices: Optional[torch.Tensor]
+    prefill_next_token_indices: Optional[torch.tensor]
+    prefill_cu_outlens: Optional[List[int]]

     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor
@@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         requests_idx_mapping = {}

+        all_prefill_logprobs = True
+        no_prefill_logprobs = True
+        prefill_head_indices = []
+        prefill_next_token_indices = []
+        prefill_cu_outlens = [0]

         next_token_chooser_parameters = []
         stopping_criterias = []

         # Cumulative length
         cumulative_length = 0
+        prefill_out_cumulative_length = 0

         max_tokens = 0
         max_length = 0
@@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)

             input_lengths.append(input_length)
-            prefix_offsets.append(0)
+            prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)

             all_input_ids.append(tokenized_input)

             # Position ids
-            position_ids.append(np.arange(0, input_length))
+            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+            position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

+            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
+            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
+
+            if r.prefill_logprobs:
+                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_next_token_indices.append(
+                    prefill_out_cumulative_length + input_length - 1
+                )
+                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+                prefill_out_cumulative_length += input_length
+            else:
+                prefill_head_indices.append(
+                    torch.tensor(
+                        [cumulative_length + input_length - 1], dtype=torch.int32
+                    )
+                )
+                prefill_next_token_indices.append(prefill_out_cumulative_length)
+                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+                prefill_out_cumulative_length += 1

             # Update
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens
@@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids

+        if len(pb.requests) > 1:
+            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            position_ids = torch.cat(position_ids)
+        else:
+            input_ids = all_input_ids[0]
+            position_ids = position_ids[0]

         # Create tensors on device
-        input_ids = torch.tensor(
-            np.concatenate(all_input_ids), dtype=torch.int64, device=device
-        )
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
-        position_ids = torch.tensor(
-            np.concatenate(position_ids), dtype=torch.int32, device=device
-        )
+        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

+        if all_prefill_logprobs:
+            prefill_head_indices = None
+            prefill_next_token_indices = cu_seqlens[1:] - 1
+        elif no_prefill_logprobs:
+            prefill_head_indices = cu_seqlens[1:] - 1
+            prefill_next_token_indices = None
+        else:
+            prefill_head_indices = torch.tensor(
+                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
+            )
+            prefill_next_token_indices = torch.tensor(
+                prefill_next_token_indices, dtype=torch.int64, device=device
+            )

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=None,
             max_seqlen=max_seqlen,
+            prefill_head_indices=prefill_head_indices,
+            prefill_next_token_indices=prefill_next_token_indices,
+            prefill_cu_outlens=prefill_cu_outlens,
             past_key_values=None,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
+            prefill_head_indices=None,
+            prefill_next_token_indices=None,
+            prefill_cu_outlens=None,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
+            prefill_head_indices=None,
+            prefill_next_token_indices=None,
+            prefill_cu_outlens=None,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -486,6 +545,7 @@ class FlashCausalLM(Model):
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -496,6 +556,7 @@ class FlashCausalLM(Model):
             max_s=max_s,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
+            lm_head_indices=lm_head_indices,
         )

     @tracer.start_as_current_span("generate_token")
@@ -503,9 +564,10 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
+        prefill_logprobs = batch.prefill_next_token_indices is not None
         single_request = len(batch) == 1

-        if prefill and len(batch) == 1:
+        if prefill and single_request:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
@@ -522,11 +584,12 @@ class FlashCausalLM(Model):
             batch.max_seqlen,
             batch.past_key_values,
             pre_allocate_past_size,
+            batch.prefill_head_indices,
         )

         if prefill:
             next_token_logits = (
-                out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
+                out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
         else:
             next_token_logits = out
@@ -536,10 +599,10 @@ class FlashCausalLM(Model):
         )

         if prefill:
-            if len(batch) > 1:
+            if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

             # Create batch.cu_seqlens_q for decode
             batch.cu_seqlens_q = torch.arange(
@@ -600,7 +663,6 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.input_lengths,
-            batch.stopping_criterias,
             batch.all_input_ids,
         )
@@ -611,28 +673,32 @@ class FlashCausalLM(Model):
         # For each member of the batch
         for i, (
             input_length,
-            stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
+            # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

             if prefill:
+                # Indexing metadata
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+                out_length = out_end_index - out_start_index

                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
+                if prefill_logprobs:
                     if len(batch) > 1:
                         prefill_tokens_indices[
-                            start_index : end_index - 1
-                        ] = batch.input_ids[start_index + 1 : end_index]
+                            out_start_index : out_end_index - 1
+                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                     else:
                         # Set prefill_tokens_indices to the correct slice
                         prefill_tokens_indices = batch.input_ids[
-                            start_index + 1 : end_index
+                            start_index + 1 : start_index + out_length
                         ]

             batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
@@ -644,7 +710,7 @@ class FlashCausalLM(Model):
         batch.position_ids = next_position_ids + 1
         batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

-        if prefill:
+        if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(
@@ -657,8 +723,6 @@ class FlashCausalLM(Model):
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = batch.input_ids.tolist()

-        cumulative_length = 0

         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -688,9 +752,6 @@ class FlashCausalLM(Model):
             next_token_id,
             next_token_logprob,
         ) in enumerate(iterator):
-            start_index = cumulative_length
-            end_index = cumulative_length + input_length

             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -728,10 +789,13 @@ class FlashCausalLM(Model):
             generated_text = None

             # Prefill
-            if prefill:
+            if prefill and request.prefill_logprobs:
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index : end_index - 1
+                    out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
@@ -764,8 +828,10 @@ class FlashCausalLM(Model):
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
-            cumulative_length += input_length

+        batch.prefill_cu_outlens = None
+        batch.prefill_head_indices = None
+        batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1

         # No need to return a batch if we know that all requests stopped
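To make the index bookkeeping above concrete, here is a small worked example that mirrors the batching logic for a mixed batch, one request asking for prompt logprobs and one not; the lengths and printed values are illustrative only.

```python
import torch

# Toy batch: request 0 wants prompt logprobs, request 1 does not.
input_lengths = [3, 4]
wants_prefill_logprobs = [True, False]

prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
cumulative_length = 0
prefill_out_cumulative_length = 0

for input_length, wants in zip(input_lengths, wants_prefill_logprobs):
    request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
    if wants:
        # Keep every prompt position: its hidden state is needed for logprobs.
        prefill_head_indices.append(request_position_ids + cumulative_length)
        prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1)
        prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
        prefill_out_cumulative_length += input_length
    else:
        # Keep only the last prompt position: enough to sample the next token.
        prefill_head_indices.append(
            torch.tensor([cumulative_length + input_length - 1], dtype=torch.int32)
        )
        prefill_next_token_indices.append(prefill_out_cumulative_length)
        prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
        prefill_out_cumulative_length += 1
    cumulative_length += input_length

# Of the 7 flattened prompt positions, only 4 hidden states reach the lm_head.
print(torch.cat(prefill_head_indices).tolist())  # [0, 1, 2, 6]
# Rows of the lm_head output holding each request's next-token logits.
print(prefill_next_token_indices)                # [2, 3]
# Per-request slices of that output used to gather prompt logprobs.
print(prefill_cu_outlens)                        # [0, 3, 4]
```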

View File

@@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
             generated_text = None

             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],