Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

feat(server): only compute prefill logprobs when asked

parent e7248fe90e
commit 1dd0cf63df
@@ -136,6 +136,7 @@ async fn prefill(
    let requests = (0..batch_size)
        .map(|id| Request {
            id: id.into(),
            prefill_logprobs: false,
            inputs: sequence.clone(),
            truncate: sequence_length,
            parameters: Some(parameters.clone()),
@@ -107,6 +107,40 @@ print(text)
### Types

```python
# Request Parameters
class Parameters:
    # Activate logits sampling
    do_sample: bool
    # Maximum number of generated tokens
    max_new_tokens: int
    # The parameter for repetition penalty. 1.0 means no penalty.
    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float]
    # Whether to prepend the prompt to the generated text
    return_full_text: bool
    # Stop generating tokens if a member of `stop_sequences` is generated
    stop: List[str]
    # Random sampling seed
    seed: Optional[int]
    # The value used to modulate the logits distribution.
    temperature: Optional[float]
    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
    top_k: Optional[int]
    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
    # higher are kept for generation.
    top_p: Optional[float]
    # Truncate input tokens to the given size
    truncate: Optional[int]
    # Typical Decoding mass
    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
    typical_p: Optional[float]
    # Generate best_of sequences and return the one with the highest token logprobs
    best_of: Optional[int]
    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    watermark: bool
    # Get prompt token logprobs and ids
    prefill_details: bool

# Prompt tokens
class PrefillToken:
    # Token ID from the model tokenizer
@@ -151,8 +185,8 @@ class BestOfSequence:
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    prefill: List[PrefillToken]
    # Prompt tokens, empty if prefill_details is False
    prefill: Optional[List[PrefillToken]]
    # Generated tokens
    tokens: List[Token]

@@ -165,7 +199,7 @@ class Details:
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    # Prompt tokens, empty if prefill_details is False
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]
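For context, a minimal sketch of how a caller would request these prompt-token details once this change ships. The endpoint URL is hypothetical; `prefill_details` defaults to `False`, matching the types above:

```python
from text_generation import Client

# Hypothetical local endpoint; any running text-generation-inference server works.
client = Client("http://127.0.0.1:8080")

response = client.generate("test", max_new_tokens=1, prefill_details=True)

# `details.prefill` is only populated when prefill_details was requested.
if response.details.prefill is not None:
    for token in response.details.prefill:
        print(token.id, token.text, token.logprob)
```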
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation"
version = "0.5.2"
version = "0.6.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -7,7 +7,7 @@ from text_generation.types import FinishReason, PrefillToken, Token

def test_generate(flan_t5_xxl_url, hf_headers):
    client = Client(flan_t5_xxl_url, hf_headers)
    response = client.generate("test", max_new_tokens=1)
    response = client.generate("test", max_new_tokens=1, prefill_details=True)

    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
@@ -17,13 +17,15 @@ def test_generate(flan_t5_xxl_url, hf_headers):
    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0].id == 3
    assert response.details.tokens[0].text == ""
    assert response.details.tokens[0].text == " "
    assert not response.details.tokens[0].special


def test_generate_best_of(flan_t5_xxl_url, hf_headers):
    client = Client(flan_t5_xxl_url, hf_headers)
    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
    response = client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
    )

    assert response.details.seed is not None
    assert response.details.best_of_sequences is not None
@@ -73,7 +75,7 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    response = await client.generate("test", max_new_tokens=1)
    response = await client.generate("test", max_new_tokens=1, prefill_details=True)

    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
@@ -83,7 +85,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0].id == 3
    assert response.details.tokens[0].text == ""
    assert response.details.tokens[0].text == " "
    assert not response.details.tokens[0].special


@@ -91,7 +93,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    response = await client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True
        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
    )

    assert response.details.seed is not None
@@ -74,6 +74,7 @@ class Client:
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        prefill_details: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            prefill_details (`bool`):
                Return the prefill token log probabilities

        Returns:
            Response: generated response
@@ -130,6 +133,7 @@ class Client:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            prefill_details=prefill_details,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -202,6 +206,7 @@ class Client:
        parameters = Parameters(
            best_of=None,
            details=True,
            prefill_details=False,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
@@ -311,6 +316,7 @@ class AsyncClient:
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        prefill_details: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously
@@ -347,6 +353,8 @@ class AsyncClient:
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            prefill_details (`bool`):
                Return the prefill token log probabilities

        Returns:
            Response: generated response
@@ -355,6 +363,7 @@ class AsyncClient:
        parameters = Parameters(
            best_of=best_of,
            details=True,
            prefill_details=prefill_details,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
@@ -437,6 +446,7 @@ class AsyncClient:
        parameters = Parameters(
            best_of=None,
            details=True,
            prefill_details=False,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
@@ -37,6 +37,8 @@ class Parameters(BaseModel):
    watermark: bool = False
    # Get generation details
    details: bool = False
    # Get prefill details
    prefill_details: bool = False

    @validator("best_of")
    def valid_best_of(cls, field_value, values):
@@ -173,7 +175,7 @@ class BestOfSequence(BaseModel):
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    # Prompt tokens, empty if prefill_details is False
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]
@@ -187,7 +189,7 @@ class Details(BaseModel):
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    # Prompt tokens, empty if prefill_details is False
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]
@@ -87,6 +87,8 @@ message Request {
  NextTokenChooserParameters parameters = 4;
  /// Stopping Criteria Parameters
  StoppingCriteriaParameters stopping_parameters = 5;
  /// Return prefill logprobs
  bool prefill_logprobs = 6;
}

message Batch {
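A sketch of what the new proto field looks like from Python, assuming the generated stubs are importable as `text_generation_server.pb.generate_pb2` (the module path is an assumption; only the field names come from this diff):

```python
# Illustrative only: build a Request that asks the model server for prompt logprobs.
from text_generation_server.pb import generate_pb2

request = generate_pb2.Request(
    id=0,
    inputs="test",
    truncate=10,
    prefill_logprobs=True,  # the new field 6
)
print(request.prefill_logprobs)
```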
@@ -34,6 +34,7 @@ impl Health {
            id: LIVENESS_ID,
            inputs: "liveness".to_string(),
            truncate: 10,
            prefill_logprobs: false,
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
@@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
    #[schema(default = "true")]
    pub details: bool,
    #[serde(default)]
    #[schema(default = "true")]
    pub prefill_details: bool,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0,
        nullable = true,
@@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
        truncate: None,
        watermark: false,
        details: false,
        prefill_details: false,
        seed: None,
    }
}
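Since `prefill_details` is a plain serde field on `GenerateParameters`, it can be set directly in the JSON body sent to the router's `/generate` route. A minimal sketch (host and port are hypothetical):

```python
import requests

# Hypothetical router address; point this at a running text-generation-inference router.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "test",
        "parameters": {"max_new_tokens": 1, "details": True, "prefill_details": True},
    },
)
print(resp.json()["details"]["prefill"])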
@@ -201,6 +201,7 @@ impl State {

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.prefill_details,
                inputs: entry.request.inputs.clone(),
                truncate: entry.request.truncate,
                parameters: Some(entry.request.parameters.clone()),
@@ -364,7 +364,17 @@ async fn generate_stream(
    let details = req.0.parameters.details;

    let best_of = req.0.parameters.best_of.unwrap_or(1);
    if best_of == 1 {
    if best_of != 1 {
        let err = InferError::from(ValidationError::BestOfStream);
        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
        tracing::error!("{err}");
        yield Ok(Event::from(err));
    } else if req.0.parameters.prefill_details {
        let err = InferError::from(ValidationError::PrefillDetailsStream);
        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
        tracing::error!("{err}");
        yield Ok(Event::from(err));
    } else {
        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
            // Keep permit as long as generate_stream lives
            Ok((_permit, mut response_stream)) => {
@@ -474,11 +484,6 @@ async fn generate_stream(
                tracing::error!("{err}");
                yield Ok(Event::from(err));
            }
        } else {
            let err = InferError::from(ValidationError::BestOfStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
    };

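Streaming therefore never carries prefill details: the Python client pins `prefill_details=False` on its streaming path (see the `Parameters(... prefill_details=False ...)` calls above) and the router rejects streaming requests that set it. A hedged sketch of the intended usage split, with a hypothetical endpoint:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # hypothetical endpoint

# Streaming: tokens only, no prefill details are ever returned.
for stream_response in client.generate_stream("test", max_new_tokens=4):
    print(stream_response.token.text, end="")

# Non-streaming: request prompt-token logprobs explicitly.
response = client.generate("test", max_new_tokens=4, prefill_details=True)
print(response.details.prefill)
```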
@@ -145,6 +145,7 @@ impl Validation {
            truncate,
            seed,
            watermark,
            prefill_details,
            ..
        } = request.parameters;

@@ -261,6 +262,7 @@ impl Validation {

        Ok(ValidGenerateRequest {
            inputs,
            prefill_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
@@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
    pub inputs: String,
    pub input_length: u32,
    pub truncate: u32,
    pub prefill_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
}
@@ -351,6 +354,8 @@ pub enum ValidationError {
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
    #[error("`prefill_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            prefix_offsets.append(input_len - 5)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
@@ -617,7 +617,7 @@ class CausalLM(Model):
            generated_text = None

            # Prefill
            if stopping_criteria.current_tokens == 1:
            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + torch.log_softmax(
                    logits, -1
@@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        hidden_states, present = self.model(
            input_ids,
@@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            past_key_values,
            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)

        if self.model.tp_embeddings:
@@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        hidden_states, present = self.gpt_neox(
            input_ids,
@@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
            past_key_values,
            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.embed_out(hidden_states)

        if self.gpt_neox.tp_embeddings:
@@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        hidden_states, present = self.transformer(
            input_ids,
@@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
            past_key_values,
            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)

        if self.transformer.tp_embeddings:
@@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        hidden_states, present = self.transformer(
            input_ids,
@@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
            past_key_values,
            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)

        if self.transformer.tp_embeddings:
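The `lm_head_indices` argument added to each flash model serves one purpose: when a request did not ask for prefill logprobs, only the last hidden state of its prompt is pushed through the LM head. A toy sketch of the saving (shapes and indices are made up for illustration):

```python
import torch

hidden_size, vocab_size = 16, 32000
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Flattened prefill hidden states for two prompts of lengths 3 and 4 (7 rows total).
hidden_states = torch.randn(7, hidden_size)

# Neither request asked for prefill logprobs, so only the last position of each
# prompt is needed to pick the next token.
lm_head_indices = torch.tensor([2, 6])

logits = lm_head(hidden_states[lm_head_indices])
print(logits.shape)  # torch.Size([2, 32000]) instead of torch.Size([7, 32000])
```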
@@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor
@@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        prefill_out_cumulative_length = 0

        max_tokens = 0
        max_length = 0
@@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(0)
            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(np.arange(0, input_length))
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)
@@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens
@@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]

        # Create tensors on device
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlens[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlens[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
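To make the bookkeeping above concrete, here is a minimal re-derivation of the mixed-batch branch for two hypothetical requests (prompt lengths 3 and 2, only the first asking for prefill logprobs). It is a standalone toy, not the TGI code path:

```python
import torch

def build_prefill_indices(lengths, wants_logprobs):
    """Toy version of the mixed-batch index bookkeeping shown above."""
    head, next_tok, cu_outlens = [], [], [0]
    cumulative_length = 0   # position offset in the flattened prefill batch
    out_cum = 0             # offset in the (smaller) kept-logits output
    for length, wants in zip(lengths, wants_logprobs):
        if wants:
            # keep every prompt position of this request
            head.append(torch.arange(cumulative_length, cumulative_length + length))
            next_tok.append(out_cum + length - 1)
            cu_outlens.append(out_cum + length)
            out_cum += length
        else:
            # keep only the last position (needed to sample the next token)
            head.append(torch.tensor([cumulative_length + length - 1]))
            next_tok.append(out_cum)
            cu_outlens.append(out_cum + 1)
            out_cum += 1
        cumulative_length += length
    return torch.cat(head), torch.tensor(next_tok), cu_outlens

print(build_prefill_indices([3, 2], [True, False]))
# -> (tensor([0, 1, 2, 4]), tensor([2, 3]), [0, 3, 4])
```

Only four of the five prompt positions reach the LM head, and the two "next token" rows sit at offsets 2 and 3 of that reduced output.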
@@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
@@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
@@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
@@ -486,6 +545,7 @@ class FlashCausalLM(Model):
        max_s: int,
        past_key_values: Optional = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
@@ -496,6 +556,7 @@ class FlashCausalLM(Model):
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
@@ -503,9 +564,10 @@ class FlashCausalLM(Model):
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.past_key_values is None
        prefill_logprobs = batch.prefill_next_token_indices is not None
        single_request = len(batch) == 1

        if prefill and len(batch) == 1:
        if prefill and single_request:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
@@ -522,11 +584,12 @@ class FlashCausalLM(Model):
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
            batch.prefill_head_indices,
        )

        if prefill:
            next_token_logits = (
                out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out
@@ -536,10 +599,10 @@ class FlashCausalLM(Model):
        )

        if prefill:
            if len(batch) > 1:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
@@ -600,7 +663,6 @@ class FlashCausalLM(Model):
        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

@@ -611,29 +673,33 @@ class FlashCausalLM(Model):
        # For each member of the batch
        for i, (
            input_length,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

@@ -644,7 +710,7 @@ class FlashCausalLM(Model):
        batch.position_ids = next_position_ids + 1
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

        if prefill:
        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
@@ -657,8 +723,6 @@ class FlashCausalLM(Model):
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
@@ -688,9 +752,6 @@ class FlashCausalLM(Model):
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

@@ -728,10 +789,13 @@ class FlashCausalLM(Model):
            generated_text = None

            # Prefill
            if prefill:
            if prefill and request.prefill_logprobs:
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]

                # Remove generated token to only have prefill and add nan for first prompt token
                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                    start_index : end_index - 1
                    out_start_index : out_end_index - 1
                ]
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
@@ -764,8 +828,10 @@ class FlashCausalLM(Model):
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids
            cumulative_length += input_length

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        # No need to return a batch if we know that all requests stopped
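A toy restatement of the gather above (made-up sizes and token ids, not the TGI tensors): once every prompt position of a request is kept in `out`, row i scores prompt token i + 1, and the first prompt token gets NaN because nothing precedes it:

```python
import torch

vocab_size = 10
prompt_ids = torch.tensor([3, 7, 2, 5, 1])      # 5 hypothetical prompt token ids
out = torch.randn(len(prompt_ids), vocab_size)  # one row of logits per kept prompt position

logprobs = torch.log_softmax(out, -1)
# row i predicts prompt token i + 1, so gather token i + 1's logprob from row i
scores = torch.gather(logprobs[:-1], 1, prompt_ids[1:].unsqueeze(1)).squeeze(1)

# the first prompt token has no preceding position: report NaN for it
prefill_logprobs = [float("nan")] + scores.tolist()
print(len(prefill_logprobs))  # 5, one logprob per prompt token
```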
@@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
            generated_text = None

            # Prefill
            if stopping_criteria.current_tokens == 1:
            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                prefill_tokens = PrefillTokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],