From 1dd0cf63df8741d8aeae3436cdd293bc417fd7de Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 2 Jun 2023 15:30:35 +0200 Subject: [PATCH] feat(server): only compute prefill logprobs when asked --- benchmark/src/generation.rs | 1 + clients/python/README.md | 40 +++++- clients/python/pyproject.toml | 2 +- clients/python/tests/test_client.py | 14 +- clients/python/text_generation/client.py | 10 ++ clients/python/text_generation/types.py | 6 +- proto/generate.proto | 2 + router/src/health.rs | 1 + router/src/lib.rs | 4 + router/src/queue.rs | 1 + router/src/server.rs | 17 ++- router/src/validation.rs | 5 + .../models/causal_lm.py | 4 +- .../custom_modeling/flash_llama_modeling.py | 3 + .../custom_modeling/flash_neox_modeling.py | 3 + .../custom_modeling/flash_rw_modeling.py | 3 + .../flash_santacoder_modeling.py | 3 + .../models/flash_causal_lm.py | 132 +++++++++++++----- .../models/seq2seq_lm.py | 2 +- 19 files changed, 199 insertions(+), 54 deletions(-) diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 17c72d26..b57c652b 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -136,6 +136,7 @@ async fn prefill( let requests = (0..batch_size) .map(|id| Request { id: id.into(), + prefill_logprobs: false, inputs: sequence.clone(), truncate: sequence_length, parameters: Some(parameters.clone()), diff --git a/clients/python/README.md b/clients/python/README.md index 99ff185a..49a5182d 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -107,6 +107,40 @@ print(text) ### Types ```python +# Request Parameters +class Parameters: + # Activate logits sampling + do_sample: bool + # Maximum number of generated tokens + max_new_tokens: int + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] + # Whether to prepend the prompt to the generated text + return_full_text: bool + # Stop generating tokens if a member of `stop_sequences` is generated + stop: List[str] + # Random sampling seed + seed: Optional[int] + # The value used to module the logits distribution. + temperature: Optional[float] + # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: Optional[int] + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation. 
+ top_p: Optional[float] + # truncate inputs tokens to the given size + truncate: Optional[int] + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] + # Generate best_of sequences and return the one if the highest token logprobs + best_of: Optional[int] + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + watermark: bool + # Get prompt token logprobs and ids + prefill_details: bool + # Prompt tokens class PrefillToken: # Token ID from the model tokenizer @@ -151,8 +185,8 @@ class BestOfSequence: generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Prompt tokens, empty if prefill_details is False + prefill: Optional[List[PrefillToken]] # Generated tokens tokens: List[Token] @@ -165,7 +199,7 @@ class Details: generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens + # Prompt tokens, empty if prefill_details is False prefill: List[PrefillToken] # Generated tokens tokens: List[Token] diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 06d5f9cb..a52bdd81 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.5.2" +version = "0.6.0" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 32462f14..10f0a825 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -7,7 +7,7 @@ from text_generation.types import FinishReason, PrefillToken, Token def test_generate(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", max_new_tokens=1) + response = client.generate("test", max_new_tokens=1, prefill_details=True) assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length @@ -17,13 +17,15 @@ def test_generate(flan_t5_xxl_url, hf_headers): assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == "" + assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special def test_generate_best_of(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True) + response = client.generate( + "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True + ) assert response.details.seed is not None assert response.details.best_of_sequences is not None @@ -73,7 +75,7 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): @pytest.mark.asyncio async def test_generate_async(flan_t5_xxl_url, hf_headers): client = AsyncClient(flan_t5_xxl_url, hf_headers) - response = await client.generate("test", max_new_tokens=1) + response = await client.generate("test", max_new_tokens=1, prefill_details=True) assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length @@ -83,7 +85,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers): assert response.details.prefill[0] == PrefillToken(id=0, text="", 
logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == "" + assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special @@ -91,7 +93,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers): async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers): client = AsyncClient(flan_t5_xxl_url, hf_headers) response = await client.generate( - "test", max_new_tokens=1, best_of=2, do_sample=True + "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True ) assert response.details.seed is not None diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 8b8742fc..be85d26b 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -74,6 +74,7 @@ class Client: truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + prefill_details: bool = False, ) -> Response: """ Given a prompt, generate the following text @@ -110,6 +111,8 @@ class Client: See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + prefill_details (`bool`): + Return the prefill token log probabilities Returns: Response: generated response @@ -130,6 +133,7 @@ class Client: truncate=truncate, typical_p=typical_p, watermark=watermark, + prefill_details=prefill_details, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -202,6 +206,7 @@ class Client: parameters = Parameters( best_of=None, details=True, + prefill_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, @@ -311,6 +316,7 @@ class AsyncClient: truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + prefill_details: bool = False, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -347,6 +353,8 @@ class AsyncClient: See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + prefill_details (`bool`): + Return the prefill token log probabilities Returns: Response: generated response @@ -355,6 +363,7 @@ class AsyncClient: parameters = Parameters( best_of=best_of, details=True, + prefill_details=prefill_details, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, @@ -437,6 +446,7 @@ class AsyncClient: parameters = Parameters( best_of=None, details=True, + prefill_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index ad3cd09b..2f78c033 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -37,6 +37,8 @@ class Parameters(BaseModel): watermark: bool = False # Get generation details details: bool = False + # Get prefill details + prefill_details: bool = False @validator("best_of") def valid_best_of(cls, field_value, values): @@ -173,7 +175,7 @@ class BestOfSequence(BaseModel): generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens + # Prompt tokens, empty if 
prefill_details is False prefill: List[PrefillToken] # Generated tokens tokens: List[Token] @@ -187,7 +189,7 @@ class Details(BaseModel): generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens + # Prompt tokens, empty if prefill_details is False prefill: List[PrefillToken] # Generated tokens tokens: List[Token] diff --git a/proto/generate.proto b/proto/generate.proto index 0c40e5bb..a0f5a75e 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -87,6 +87,8 @@ message Request { NextTokenChooserParameters parameters = 4; /// Stopping Criteria Parameters StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; } message Batch { diff --git a/router/src/health.rs b/router/src/health.rs index 45f50e9d..a3cacdcd 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -34,6 +34,7 @@ impl Health { id: LIVENESS_ID, inputs: "liveness".to_string(), truncate: 10, + prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, top_k: 0, diff --git a/router/src/lib.rs b/router/src/lib.rs index 080dc4f4..4efe66ce 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters { #[schema(default = "true")] pub details: bool, #[serde(default)] + #[schema(default = "true")] + pub prefill_details: bool, + #[serde(default)] #[schema( exclusive_minimum = 0, nullable = true, @@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters { truncate: None, watermark: false, details: false, + prefill_details: false, seed: None, } } diff --git a/router/src/queue.rs b/router/src/queue.rs index 94851e1c..b8470ebe 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -201,6 +201,7 @@ impl State { batch_requests.push(Request { id, + prefill_logprobs: entry.request.prefill_details, inputs: entry.request.inputs.clone(), truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), diff --git a/router/src/server.rs b/router/src/server.rs index fd6a66bb..f0f205c5 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -364,7 +364,17 @@ async fn generate_stream( let details = req.0.parameters.details; let best_of = req.0.parameters.best_of.unwrap_or(1); - if best_of == 1 { + if best_of != 1 { + let err = InferError::from(ValidationError::BestOfStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else if req.0.parameters.prefill_details { + let err = InferError::from(ValidationError::PrefillDetailsStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else { match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives Ok((_permit, mut response_stream)) => { @@ -474,11 +484,6 @@ async fn generate_stream( tracing::error!("{err}"); yield Ok(Event::from(err)); } - } else { - let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); } }; diff --git a/router/src/validation.rs b/router/src/validation.rs index cbb0d9cd..63bf78a3 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -145,6 +145,7 @@ impl Validation { truncate, seed, watermark, + prefill_details, .. 
} = request.parameters; @@ -261,6 +262,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + prefill_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, parameters, @@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest { pub inputs: String, pub input_length: u32, pub truncate: u32, + pub prefill_details: bool, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, } @@ -351,6 +354,8 @@ pub enum ValidationError { BestOfSeed, #[error("`best_of` != 1 is not supported when streaming tokens")] BestOfStream, + #[error("`prefill_details` == true is not supported when streaming tokens")] + PrefillDetailsStream, #[error("`temperature` must be strictly positive")] Temperature, #[error("`repetition_penalty` must be strictly positive")] diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 92622350..ba0853f5 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -104,7 +104,7 @@ class CausalLMBatch(Batch): ).to(device) for _ in pb.requests: input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(0) + prefix_offsets.append(input_len - 5) read_offsets.append(input_len) input_lengths = tokenized_inputs["attention_mask"].sum(1) @@ -617,7 +617,7 @@ class CausalLM(Model): generated_text = None # Prefill - if stopping_criteria.current_tokens == 1: + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [float("nan")] + torch.log_softmax( logits, -1 diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2dcb6ed8..f4116937 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.model( input_ids, @@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.model.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 26e21753..b798750a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.gpt_neox( input_ids, @@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) if self.gpt_neox.tp_embeddings: diff --git 
a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 545da26a..03487703 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.transformer( input_ids, @@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.transformer.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 9bded805..b61ec873 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.transformer( input_ids, @@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.transformer.tp_embeddings: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 35cbe174..5ff951b3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch): past_key_values: Optional[torch.Tensor] max_seqlen: int + # Prefill metadata tensors to efficiently compute logprobs + prefill_head_indices: Optional[torch.Tensor] + prefill_next_token_indices: Optional[torch.tensor] + prefill_cu_outlens: Optional[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch): all_input_ids = [] requests_idx_mapping = {} + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_length = 0 + prefill_out_cumulative_length = 0 max_tokens = 0 max_length = 0 @@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch): max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) - prefix_offsets.append(0) + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) all_input_ids.append(tokenized_input) # Position ids - position_ids.append(np.arange(0, input_length)) + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs cu_seqlens.append(cumulative_length + input_length) @@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch): max_new_tokens = 
stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], dtype=torch.int32 + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + # Update cumulative_length += input_length max_tokens += input_length + max_new_tokens @@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch): for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + # Create tensors on device - input_ids = torch.tensor( - np.concatenate(all_input_ids), dtype=torch.int64, device=device - ) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) - position_ids = torch.tensor( - np.concatenate(position_ids), dtype=torch.int32, device=device - ) + position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlens[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlens[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + return cls( batch_id=pb.id, requests=pb.requests, @@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch): cu_seqlens=cu_seqlens, cu_seqlens_q=None, max_seqlen=max_seqlen, + prefill_head_indices=prefill_head_indices, + prefill_next_token_indices=prefill_next_token_indices, + prefill_cu_outlens=prefill_cu_outlens, past_key_values=None, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch): cu_seqlens=cu_seqlens, cu_seqlens_q=cu_seqlens_q, max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, past_key_values=past_key_values, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch): cu_seqlens=cu_seqlens, cu_seqlens_q=cu_seqlens_q, max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, past_key_values=past_key_values, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -486,6 +545,7 @@ class FlashCausalLM(Model): max_s: int, past_key_values: Optional = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( @@ -496,6 
+556,7 @@ class FlashCausalLM(Model): max_s=max_s, past_key_values=past_key_values, pre_allocate_past_size=pre_allocate_past_size, + lm_head_indices=lm_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -503,9 +564,10 @@ class FlashCausalLM(Model): self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None + prefill_logprobs = batch.prefill_next_token_indices is not None single_request = len(batch) == 1 - if prefill and len(batch) == 1: + if prefill and single_request: # Ask to pre-allocate kv to its max size # == number of tokens + max_new_tokens pre_allocate_past_size = ( @@ -522,11 +584,12 @@ class FlashCausalLM(Model): batch.max_seqlen, batch.past_key_values, pre_allocate_past_size, + batch.prefill_head_indices, ) if prefill: next_token_logits = ( - out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1] + out[batch.prefill_next_token_indices] if prefill_logprobs else out ) else: next_token_logits = out @@ -536,10 +599,10 @@ class FlashCausalLM(Model): ) if prefill: - if len(batch) > 1: + if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids)) + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) # Create batch.cu_seqlens_q for decode batch.cu_seqlens_q = torch.arange( @@ -600,7 +663,6 @@ class FlashCausalLM(Model): # Zipped iterator iterator = zip( batch.input_lengths, - batch.stopping_criterias, batch.all_input_ids, ) @@ -611,29 +673,33 @@ class FlashCausalLM(Model): # For each member of the batch for i, ( input_length, - stopping_criteria, all_input_ids, ) in enumerate(iterator): - # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length if prefill: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + out_length = out_end_index - out_start_index + # Initialize position_ids # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices - if len(batch) > 1: - prefill_tokens_indices[ - start_index : end_index - 1 - ] = batch.input_ids[start_index + 1 : end_index] - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : end_index - ] + if prefill_logprobs: + if len(batch) > 1: + prefill_tokens_indices[ + out_start_index : out_end_index - 1 + ] = batch.input_ids[start_index + 1 : start_index + out_length] + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = batch.input_ids[ + start_index + 1 : start_index + out_length + ] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] @@ -644,7 +710,7 @@ class FlashCausalLM(Model): batch.position_ids = next_position_ids + 1 batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q - if prefill: + if prefill and prefill_logprobs: # Get prefill logprobs prefill_logprobs_tensor = torch.log_softmax(out, -1) prefill_logprobs = torch.gather( @@ -657,8 +723,6 @@ class FlashCausalLM(Model): next_token_logprobs = next_token_logprobs.tolist() next_token_ids = batch.input_ids.tolist() - cumulative_length = 0 - # Zipped iterator iterator = zip( batch.requests, @@ -688,9 +752,6 @@ 
class FlashCausalLM(Model): next_token_id, next_token_logprob, ) in enumerate(iterator): - start_index = cumulative_length - end_index = cumulative_length + input_length - # Append next token to all tokens all_input_ids.append(next_token_id) @@ -728,10 +789,13 @@ class FlashCausalLM(Model): generated_text = None # Prefill - if prefill: + if prefill and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - start_index : end_index - 1 + out_start_index : out_end_index - 1 ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( @@ -764,8 +828,10 @@ class FlashCausalLM(Model): batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids - cumulative_length += input_length + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 # No need to return a batch if we know that all requests stopped diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 68e59dc3..3ad5698c 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -688,7 +688,7 @@ class Seq2SeqLM(Model): generated_text = None # Prefill - if stopping_criteria.current_tokens == 1: + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: prefill_tokens = PrefillTokens( [self.tokenizer.bos_token_id], [float("nan")],
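
Usage note: with this change, prompt-token logprobs become opt-in. The bumped `text-generation` Python client (0.6.0) exposes this through the new `prefill_details` flag on `generate`; when it is set, `details.prefill` carries the prompt tokens with their ids, texts, and logprobs, and the server skips that work otherwise. The flag is rejected on the streaming path (`PrefillDetailsStream`), so it only applies to non-streaming generation. A minimal sketch, assuming a server is reachable at a placeholder URL:

```python
from text_generation import Client

# Placeholder URL: point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Ask for prompt-token ids, texts, and logprobs alongside the generated text.
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=16,
    prefill_details=True,
)

print(response.generated_text)
for token in response.details.prefill:
    # The first prompt token has no logprob (the server reports NaN / None for it).
    print(token.id, token.text, token.logprob)
```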
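
Server-side, the interesting piece is the indexing added in `FlashCausalLMBatch.from_pb` and threaded through the model `forward` as `lm_head_indices`: hidden states are sliced with `hidden_states[lm_head_indices]` before the LM head, so requests that did not ask for prefill logprobs only project their last prompt position to logits, and the prefill-logprob gather only runs over positions that requested it. The real code short-circuits the homogeneous cases (every request wants logprobs, or none do) with `cu_seqlens`-derived indices; the sketch below is a simplified mirror of the mixed-batch bookkeeping using plain Python lists instead of torch tensors, to show how the three pieces of prefill metadata relate:

```python
from typing import List, Tuple


def build_prefill_indices(
    input_lengths: List[int], wants_prefill_logprobs: List[bool]
) -> Tuple[List[int], List[int], List[int]]:
    """Simplified mirror of the bookkeeping in FlashCausalLMBatch.from_pb.

    head_indices: positions of the flattened prompt whose hidden states must go
        through the LM head (every prompt position when logprobs are wanted,
        otherwise only the last one).
    next_token_indices: for each request, the row of the LM-head output holding
        the logits used to sample its next token.
    cu_outlens: cumulative output lengths, later used to slice each request's
        prefill logprobs out of the gathered tensor.
    """
    head_indices: List[int] = []
    next_token_indices: List[int] = []
    cu_outlens: List[int] = [0]

    cumulative_length = 0      # offset into the flattened input_ids
    out_cumulative_length = 0  # offset into the (reduced) LM-head output

    for input_length, wants in zip(input_lengths, wants_prefill_logprobs):
        if wants:
            # Keep every prompt position: logprobs are needed for all of them.
            head_indices.extend(
                range(cumulative_length, cumulative_length + input_length)
            )
            next_token_indices.append(out_cumulative_length + input_length - 1)
            cu_outlens.append(out_cumulative_length + input_length)
            out_cumulative_length += input_length
        else:
            # Keep only the last prompt position: enough to sample the next token.
            head_indices.append(cumulative_length + input_length - 1)
            next_token_indices.append(out_cumulative_length)
            cu_outlens.append(out_cumulative_length + 1)
            out_cumulative_length += 1
        cumulative_length += input_length

    return head_indices, next_token_indices, cu_outlens


# Request 0 (4 prompt tokens) asked for prefill logprobs, request 1 (3 tokens) did not.
print(build_prefill_indices([4, 3], [True, False]))
# -> ([0, 1, 2, 3, 6], [3, 4], [0, 4, 5])
```

`build_prefill_indices` and its list-based return types are illustrative only; the batch class stores this metadata as tensors (`prefill_head_indices`, `prefill_next_token_indices`, `prefill_cu_outlens`) and drops it once the prefill step has produced its logprobs.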