From 2cf1f5c00e074b8b557bc1750e80581ebe23908c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 27 Aug 2024 20:02:35 +0200
Subject: [PATCH] Fixing the issue with `add_special_tokens` not being passed around.

---
 backends/client/src/v3/client.rs             |  2 ++
 backends/client/src/v3/sharded_client.rs     |  1 +
 backends/v3/src/client/grpc_client.rs        |  1 +
 backends/v3/src/client/sharded_client.rs     |  1 +
 backends/v3/src/queue.rs                     |  1 +
 benchmark/src/generation.rs                  |  1 +
 proto/v3/generate.proto                      |  2 ++
 router/src/validation.rs                     |  2 ++
 .../custom_modeling/flash_qwen2_modeling.py  | 25 +++++++-------
 .../models/flash_causal_lm.py                | 34 +++++++++++--------
 10 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index b321278c1..479d31bf2 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -153,6 +153,8 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Most request will have that
+                add_special_tokens: true,
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index 1cc173e33..645c076a2 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -221,6 +221,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index 6282759e8..648662db3 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -149,6 +149,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 inputs,
+                add_special_tokens: true,
                 input_chunks: Some(Input {
                     chunks: input_chunks,
                 }),
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 2f78da034..ea77a6966 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -222,6 +222,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 4002b83f0..53439bf69 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -387,6 +387,7 @@ impl State {
                 }),
                 inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
+                add_special_tokens: entry.request.add_special_tokens,
                 parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
                 )),
diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 7494d5b5d..789c7b514 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -148,6 +148,7 @@ async fn prefill(
             }),
             inputs: sequence.clone(),
             truncate: sequence_length,
+            add_special_tokens: true,
             parameters: Some(parameters.clone()),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: decode_length,
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 68eea7ac9..34894bdab 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -137,6 +137,8 @@ message Request {
     optional string adapter_id = 11;
     /// Prefix length that can be retrieved from the KV cache.
     uint32 prefix_len = 12;
+    /// Context truncation
+    bool add_special_tokens = 13;
 }
 
 message Batch {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 3c2e706b1..054276c82 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -415,6 +415,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
             input_ids: input_ids.map(Arc::new),
+            add_special_tokens: request.add_special_tokens,
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -738,6 +739,7 @@ pub struct ValidGenerateRequest {
     pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
+    pub add_special_tokens: bool,
     pub decoder_input_details: bool,
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 879b8abd7..5aac28a30 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )
 
@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 22e0ada18..6b0e1e86f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -189,15 +189,22 @@ class FlashCausalLMBatch(Batch):
         cls, requests: Iterable[generate_pb2.Request], tokenizer
     ):
         batch_inputs = []
-        max_truncation = 0
+        max_length = 0
+        all_input_ids = []
+        batch_size = 0
         for r in requests:
+            batch_size += 1
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
-            max_truncation = max(max_truncation, r.truncate)
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-        return batch_tokenized_inputs
+            input_ids = tokenizer(
+                batch_inputs,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"][0]
+            max_length = max(max_length, len(input_ids))
+            all_input_ids.append(input_ids)
+        return all_input_ids
 
     @classmethod
     def from_tokenized(
@@ -256,20 +263,17 @@ class FlashCausalLMBatch(Batch):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate :]
-            if (
-                tokenized_input[0] == tokenizer.bos_token_id
-                and tokenized_input[1] == tokenizer.bos_token_id
-            ):
-                tokenized_input = tokenized_input[1:]
+            # tokenized_input = tokenized_input[-r.truncate :]
+            # if (
+            #     tokenized_input[0] == tokenizer.bos_token_id
+            #     and tokenized_input[1] == tokenizer.bos_token_id
+            # ):
+            #     tokenized_input = tokenized_input[1:]
 
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
             assert prefix_len <= orig_input_length
-            if prefix_len == orig_input_length:
-                assert prefix_len > 0
-                prefix_len -= 1
 
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
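
For reference, a minimal sketch (not part of the patch) of the per-request tokenization that `batch_tokenized_inputs` switches to above: each request is tokenized on its own with its own `truncate` limit and `add_special_tokens` flag, instead of one batched call that always added special tokens. `FakeRequest` is a stand-in for the gRPC `generate_pb2.Request`, and the tokenizer name is only an example; the real server code works on chunked inputs and takes `["input_ids"][0]` from the batched return value.

# Illustrative sketch only: mirrors the new per-request loop in
# FlashCausalLMBatch.batch_tokenized_inputs (simplified to plain strings).
from dataclasses import dataclass
from transformers import AutoTokenizer


@dataclass
class FakeRequest:
    inputs: str
    truncate: int
    add_special_tokens: bool


def batch_tokenized_inputs(requests, tokenizer):
    # One tokenizer call per request so each request keeps its own
    # truncation limit and its own add_special_tokens choice.
    all_input_ids = []
    for r in requests:
        input_ids = tokenizer(
            r.inputs,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
        )["input_ids"]
        all_input_ids.append(input_ids)
    return all_input_ids


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
    requests = [
        FakeRequest("Hello world", truncate=16, add_special_tokens=True),
        # e.g. a templated prompt that already contains its special tokens
        FakeRequest("<|user|>Hi<|assistant|>", truncate=16, add_special_tokens=False),
    ]
    print(batch_tokenized_inputs(requests, tokenizer))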