From 47e93409f3c045e79a4de0b1e7aa43fe89f77bfe Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 4 Apr 2023 12:35:29 +0200 Subject: [PATCH] optional rust validation --- benchmark/Cargo.toml | 3 +- proto/generate.proto | 6 +- router/src/main.rs | 11 +- router/src/queue.rs | 2 + router/src/server.rs | 2 +- router/src/validation.rs | 73 ++++--- .../models/causal_lm.py | 5 + .../custom_modeling/flash_llama_modeling.py | 191 ++++++++++++++++- .../custom_modeling/flash_neox_modeling.py | 198 ++++++++++++++++-- .../models/custom_modeling/linear.py | 22 -- .../models/custom_modeling/rotary.py | 42 ---- .../models/custom_modeling/tensor_parallel.py | 124 ----------- .../models/flash_causal_lm.py | 5 +- .../models/flash_llama.py | 2 - .../models/flash_neox.py | 6 +- .../models/galactica.py | 4 + server/text_generation_server/models/model.py | 22 +- .../models/seq2seq_lm.py | 5 + 18 files changed, 453 insertions(+), 270 deletions(-) delete mode 100644 server/text_generation_server/models/custom_modeling/linear.py delete mode 100644 server/text_generation_server/models/custom_modeling/rotary.py delete mode 100644 server/text_generation_server/models/custom_modeling/tensor_parallel.py diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index d3badcd8..f2a82935 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -27,7 +27,8 @@ serde = {version = "1.0.142", features = ["derive"]} serde_json = "1.0" text-generation-client = { path = "../router/client" } thiserror = "1.0.38" -tokenizers = "0.13.2" +#tokenizers = "0.13.2" +tokenizers = { git = "https://github.com/huggingface/tokenizers.git" } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]} tracing = "0.1.37" diff --git a/proto/generate.proto b/proto/generate.proto index 86393026..cc14cbf8 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -63,10 +63,12 @@ message Request { uint64 id = 1; /// The generation context string inputs = 2; + /// Context truncation + uint32 truncate = 3; /// Next Token Chooser Parameters - NextTokenChooserParameters parameters = 3; + NextTokenChooserParameters parameters = 4; /// Stopping Criteria Parameters - StoppingCriteriaParameters stopping_parameters = 4; + StoppingCriteriaParameters stopping_parameters = 5; } message Batch { diff --git a/router/src/main.rs b/router/src/main.rs index bad3df93..3ff72cde 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -94,11 +94,11 @@ fn main() -> Result<(), std::io::Error> { if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists() { // Load local tokenizer - Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap() + Tokenizer::from_file(local_path.join("tokenizer.json")).ok() } else { // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime - Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap() + Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok() }; // Launch Tokio runtime @@ -109,6 +109,13 @@ fn main() -> Result<(), std::io::Error> { .block_on(async { init_logging(otlp_endpoint, json_output); + if tokenizer.is_none() { + tracing::warn!( + "Could not find a fast tokenizer implementation for {tokenizer_name}" + ); + tracing::warn!("Rust input length validation and truncation is disabled"); + } + // Get pipeline tag let model_info = 
reqwest::get(format!( "https://huggingface.co/api/models/{tokenizer_name}" diff --git a/router/src/queue.rs b/router/src/queue.rs index 2899ccd4..77f8461b 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -174,6 +174,7 @@ impl State { batch_requests.push(Request { id, inputs: entry.request.inputs.clone(), + truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), }); @@ -226,6 +227,7 @@ mod tests { Entry { request: ValidGenerateRequest { inputs: "".to_string(), + truncate: 0, parameters: NextTokenChooserParameters { temperature: 0.0, top_k: 0, diff --git a/router/src/server.rs b/router/src/server.rs index f7850053..88c40565 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -455,7 +455,7 @@ pub async fn run( max_batch_size: usize, max_waiting_tokens: usize, client: ShardedClient, - tokenizer: Tokenizer, + tokenizer: Option, validation_workers: usize, addr: SocketAddr, allow_origin: Option, diff --git a/router/src/validation.rs b/router/src/validation.rs index ec67cefd..a0b8b98e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -23,7 +23,7 @@ pub struct Validation { impl Validation { pub(crate) fn new( workers: usize, - tokenizer: Tokenizer, + tokenizer: Option, max_best_of: usize, max_stop_sequences: usize, max_input_length: usize, @@ -85,7 +85,7 @@ impl Validation { /// Load balance the validation requests between multiple validation workers async fn validation_task( workers: usize, - tokenizer: Tokenizer, + tokenizer: Option, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -95,7 +95,7 @@ async fn validation_task( // Create workers for _ in 0..workers { - let tokenizer_clone: Tokenizer = tokenizer.clone().into(); + let tokenizer_clone: Option = tokenizer.clone().into(); // Create channel to communicate with worker let (worker_sender, worker_receiver) = mpsc::channel(workers); workers_senders.push(worker_sender); @@ -127,7 +127,7 @@ async fn validation_task( /// Check the parameters inside the payload and get the number of tokens inside the input using /// the tokenizer fn validation_worker( - tokenizer: Tokenizer, + tokenizer: Option, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -143,7 +143,7 @@ fn validation_worker( .send( validate( request, - &tokenizer, + tokenizer.as_ref(), max_stop_sequences, max_input_length, max_total_tokens, @@ -162,7 +162,7 @@ fn validation_worker( fn validate( request: GenerateRequest, - tokenizer: &Tokenizer, + tokenizer: Option<&Tokenizer>, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -272,35 +272,43 @@ fn validate( }) .unwrap_or(Ok(None))?; - // Get the number of tokens in the input - let mut encoding = tokenizer - .encode(request.inputs.clone(), true) - .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; - - let (inputs, input_length) = if let Some(truncate) = truncate { - // truncate encoding and decode new inputs - encoding.truncate(truncate, 0, TruncationDirection::Left); - let inputs = tokenizer - .decode(Vec::from(encoding.get_ids()), false) + // If we have a fast tokenizer + let inputs = if let Some(tokenizer) = tokenizer { + // Get the number of tokens in the input + let mut encoding = tokenizer + .encode(request.inputs.clone(), true) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; - (inputs, encoding.len()) + + let (inputs, input_length) = if let Some(truncate) = truncate { + // 
truncate encoding and decode new inputs + encoding.truncate(truncate, 0, TruncationDirection::Left); + let inputs = tokenizer + .decode(Vec::from(encoding.get_ids()), false) + .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; + (inputs, encoding.len()) + } else { + (request.inputs, encoding.len()) + }; + + if input_length > max_input_length { + return Err(ValidationError::InputLength(max_input_length, input_length)); + } + + let total_tokens = input_length + max_new_tokens as usize; + if total_tokens > max_total_tokens { + return Err(ValidationError::MaxTotalTokens( + max_total_tokens, + input_length, + max_new_tokens, + )); + } + + metrics::histogram!("tgi_request_input_length", input_length as f64); + inputs } else { - (request.inputs, encoding.len()) + request.inputs }; - if input_length > max_input_length { - return Err(ValidationError::InputLength(max_input_length, input_length)); - } - - let total_tokens = input_length + max_new_tokens as usize; - if total_tokens > max_total_tokens { - return Err(ValidationError::MaxTotalTokens( - max_total_tokens, - input_length, - max_new_tokens, - )); - } - // Return ValidGenerateRequest let parameters = NextTokenChooserParameters { temperature, @@ -318,11 +326,11 @@ fn validate( ignore_eos_token: false, }; - metrics::histogram!("tgi_request_input_length", input_length as f64); metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); Ok(ValidGenerateRequest { inputs, + truncate: truncate.unwrap_or(max_input_length) as u32, parameters, stopping_parameters, }) @@ -337,6 +345,7 @@ type ValidationRequest = ( #[derive(Debug)] pub(crate) struct ValidGenerateRequest { pub inputs: String, + pub truncate: u32, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, } diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index c2ad0587..7cd49239 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -66,6 +66,7 @@ class CausalLMBatch(Batch): stopping_criterias = [] # Parse batch + max_truncation = 0 padding_right_offset = 0 for r in pb.requests: inputs.append(r.inputs) @@ -74,6 +75,7 @@ class CausalLMBatch(Batch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + max_truncation = max(max_truncation, r.truncate) padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -83,6 +85,8 @@ class CausalLMBatch(Batch): return_tensors="pt", padding=True, return_token_type_ids=False, + truncation=True, + max_length=max_truncation, ).to(device) input_lengths = tokenized_inputs["attention_mask"].sum(1) @@ -388,6 +392,7 @@ class CausalLM(Model): next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() next_token_text = self.decode_token( + all_input_ids[-2, 0], next_token_id_squeezed, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index ee977948..6e2fd4b7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -21,21 +21,18 @@ import torch import torch.distributed +from torch.nn import functional as F + from torch import nn from transformers.activations import ACT2FN -from text_generation_server.models.custom_modeling.tensor_parallel 
import ( - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, -) -from text_generation_server.models.custom_modeling.linear import FastLinear -from text_generation_server.models.custom_modeling.rotary import PositionRotaryEmbedding - # Flash attention imports +import rotary_emb import flash_attn_cuda import dropout_layer_norm +from flash_attn.layers.rotary import RotaryEmbedding + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -87,6 +84,184 @@ class LlamaRMSNorm(nn.Module): return normed_hidden_states, res +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class TensorParallelColumnLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + +class TensorParallelRowLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + reduce=True, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + self.reduce = reduce + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super(TensorParallelRowLinear, self).forward(input) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + + 
def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 93f1b0ca..f3517c47 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -21,23 +21,20 @@ import torch import torch.distributed +from torch.nn import functional as F + from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig -from text_generation_server.models.custom_modeling.tensor_parallel import ( - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, -) -from text_generation_server.models.custom_modeling.linear import FastLinear -from text_generation_server.models.custom_modeling.rotary import PositionRotaryEmbedding - # Flash attention imports +import rotary_emb import flash_attn_cuda import dropout_layer_norm +from flash_attn.layers.rotary import RotaryEmbedding + class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): @@ -75,6 +72,184 @@ class FastLayerNorm(nn.LayerNorm): return normed_hidden_states, residual +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def 
transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class TensorParallelColumnLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + +class TensorParallelRowLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + reduce=True, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + self.reduce = reduce + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super(TensorParallelRowLinear, self).forward(input) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = 
torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv + + class FlashNeoxAttention(torch.nn.Module): def __init__( self, @@ -201,12 +376,7 @@ class FlashMLP(nn.Module): self.act = ( ACT2FN[act] if "gelu" not in act - else lambda x: torch.nn.functional.gelu( - x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else None, - ) + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") ) if process_group is None: diff --git a/server/text_generation_server/models/custom_modeling/linear.py b/server/text_generation_server/models/custom_modeling/linear.py deleted file mode 100644 index 1ca8e3a9..00000000 --- a/server/text_generation_server/models/custom_modeling/linear.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -from torch import nn - - -class FastLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) - - def transpose_weight(self): - self.weight = nn.Parameter(self.weight.T) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) diff --git a/server/text_generation_server/models/custom_modeling/rotary.py b/server/text_generation_server/models/custom_modeling/rotary.py deleted file mode 100644 index 69f97558..00000000 --- a/server/text_generation_server/models/custom_modeling/rotary.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import rotary_emb - -from flash_attn.layers.rotary import RotaryEmbedding - - -class PositionRotaryEmbedding(RotaryEmbedding): - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): - """ - Return cos and sin for the asked position ids - """ - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) - - cos = 
torch.index_select(self._cos_cached, 0, position_ids) - sin = torch.index_select(self._sin_cached, 0, position_ids) - return cos.unsqueeze(1), sin.unsqueeze(1) - - def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - rotary_dim = cos.shape[-1] - q1 = qkv[:, 0, :, :rotary_dim] - q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] - k1 = qkv[:, 1, :, :rotary_dim] - k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return qkv diff --git a/server/text_generation_server/models/custom_modeling/tensor_parallel.py b/server/text_generation_server/models/custom_modeling/tensor_parallel.py deleted file mode 100644 index e01b8835..00000000 --- a/server/text_generation_server/models/custom_modeling/tensor_parallel.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -import torch.distributed -from torch import nn -from torch.nn import functional as F - -from text_generation_server.models.custom_modeling.linear import FastLinear - - -class TensorParallelColumnLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - assert out_features % self.tp_world_size == 0 - out_features = out_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - -class TensorParallelRowLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - reduce=True, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - self.reduce = reduce - assert in_features % self.tp_world_size == 0 - in_features = in_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - out = super(TensorParallelRowLinear, self).forward(input) - if self.reduce: - torch.distributed.all_reduce(out, group=self.process_group) - - return out - - -class TensorParallelEmbedding(nn.Embedding): - def __init__( - self, - num_embeddings, - embedding_dim, - process_group: torch.distributed.ProcessGroup, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - _weight=None, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - - self.original_num_embeddings = num_embeddings - - assert num_embeddings % self.tp_world_size == 0 - block_size = num_embeddings // self.tp_world_size - # inputs in `[min_id, max_id[` are handled by `self` to get embeddings - self.min_id = self.tp_rank * block_size - self.max_id = (self.tp_rank + 1) * block_size - - # Additional entry that will map to zero - # Used for masking - self.null_idx = block_size - - super().__init__( - block_size, - embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=_weight, - device=device, - dtype=dtype, - ) - - def add_null_idx(self): - """Additional 0 entry used for masking""" - self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) - - def 
forward(self, input: torch.Tensor) -> torch.Tensor: - # default all out of bounds values to `self.null_idx` that will then be mapped to 0 - # translate for [0, self.max_id - self.min_id[ - input = torch.where( - (self.min_id > input) | (input >= self.max_id), - self.null_idx, - input - self.min_id, - ) - out = super().forward(input) - torch.distributed.all_reduce(out, group=self.process_group) - return out diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e1a10cbf..2aeac7b5 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -78,7 +78,9 @@ class FlashCausalLMBatch(Batch): # Parse batch for r in pb.requests: - tokenized_input = tokenizer(r.inputs)["input_ids"] + tokenized_input = tokenizer( + r.inputs, truncation=True, max_length=r.truncate + )["input_ids"] input_length = len(tokenized_input) max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) @@ -333,6 +335,7 @@ class FlashCausalLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id_item] next_token_text = self.decode_token( + all_input_ids[-2], next_token_id_item, ) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 7d7e2cf5..3029ab89 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -11,8 +11,6 @@ from typing import Optional, Tuple, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, -) -from text_generation_server.models.custom_modeling.tensor_parallel import ( TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelColumnLinear, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 655b664e..e415a725 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -8,14 +8,12 @@ from transformers import AutoTokenizer, AutoConfig from typing import Optional, Tuple, List from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.tensor_parallel import ( +from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelColumnLinear, ) -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, -) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index a90a299e..f997ab1a 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -96,6 +96,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): input_lengths = [] # Parse batch + max_truncation = 0 max_sequence_length = 0 padding_right_offset = 0 for r in pb.requests: @@ -107,6 +108,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + max_truncation = max(max_truncation, r.truncate) max_sequence_length = max(max_sequence_length, r.input_length) padding_right_offset = max( padding_right_offset, 
stopping_criteria.max_new_tokens @@ -118,6 +120,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): return_tensors="pt", padding=True, return_token_type_ids=False, + truncation=True, + max_length=max_truncation, ).to(device) input_ids = tokenized_inputs["input_ids"] # Allocate maximum attention_mask diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index e0ce6686..9e519779 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,15 +15,6 @@ class Model(ABC): self.all_special_ids = set(tokenizer.all_special_ids) self.device = device - # see `decode_token` method - self.tokenizer.add_special_tokens( - {"additional_special_tokens": [""]} - ) - self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids( - "" - ) - self.special_decode_token_length = len("") - @property @abstractmethod def batch_type(self) -> Type[B]: @@ -33,11 +24,12 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError - def decode_token(self, token_id: int) -> str: + def decode_token(self, previous_token_id: int, token_id: int) -> str: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" - # append token to special decode token and decode both - result = self.tokenizer.decode( - [self.special_decode_token_id, token_id], skip_special_tokens=False + # Decode previous token and previous token + token + results = self.tokenizer.batch_decode( + [[previous_token_id], [previous_token_id, token_id]], + skip_special_tokens=False, ) - # slice to remove special decode token - return result[self.special_decode_token_length :] + # slice to remove previous token + return results[1][len(results[0]) :] diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 0fe5c03f..72f694c3 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -73,6 +73,7 @@ class Seq2SeqLMBatch(Batch): decoder_input_lengths = [] # Parse batch + max_truncation = 0 padding_right_offset = 0 for r in pb.requests: inputs.append(r.inputs) @@ -84,6 +85,7 @@ class Seq2SeqLMBatch(Batch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + max_truncation = max(max_truncation, r.truncate) padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -94,6 +96,8 @@ class Seq2SeqLMBatch(Batch): return_tensors="pt", padding=True, return_token_type_ids=False, + truncation=True, + max_length=max_truncation, ).to(device) input_lengths = tokenized_inputs["attention_mask"].sum(1) @@ -463,6 +467,7 @@ class Seq2SeqLM(Model): next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() next_token_text = self.decode_token( + decoder_input_ids[-2], next_token_id_squeezed, )
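
The core idea of this patch is that the Rust router only validates and truncates inputs when a fast tokenizer is available; when it is not, the router forwards a per-request `truncate` value and the Python server performs truncation with the Hugging Face tokenizer (`truncation=True, max_length=max_truncation`). A minimal sketch of that server-side behaviour follows; it is not the repository's code, and `gpt2` plus the example prompts are placeholders.

# Sketch of the truncation added to the batching code in causal_lm.py,
# seq2seq_lm.py and galactica.py. Assumes the `transformers` library.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder model
tokenizer.pad_token = tokenizer.eos_token              # gpt2 has no pad token; needed for padding=True

requests = [
    {"inputs": "one two three four five six seven eight", "truncate": 4},
    {"inputs": "a short prompt", "truncate": 8},
]

# Mirror of the parsing loop: keep the largest per-request truncate value
# and let the tokenizer enforce it on the whole batch.
max_truncation = max(r["truncate"] for r in requests)

tokenized_inputs = tokenizer(
    [r["inputs"] for r in requests],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_truncation,
)
print(tokenized_inputs["input_ids"].shape)  # second dimension is at most max_truncation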
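
The `decode_token` rewrite in models/model.py drops the old trick of prepending a dedicated special token and slicing it off; instead it decodes the previous token alone and the previous token together with the new token, then returns the suffix. That is why callers now pass the second-to-last input id (`all_input_ids[-2]`). A sketch under the assumption of a Hugging Face tokenizer (`gpt2` is only an example):

# Sketch of the incremental decoding trick; handles tokenizers where a
# token's surface form depends on what precedes it (e.g. leading spaces).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder model

def decode_token(previous_token_id: int, token_id: int) -> str:
    # Decode the previous token alone and together with the new token,
    # then keep only the characters the new token contributed.
    results = tokenizer.batch_decode(
        [[previous_token_id], [previous_token_id, token_id]],
        skip_special_tokens=False,
    )
    return results[1][len(results[0]):]

prev_id, new_id = tokenizer.encode(" hello world")[-2:]
print(repr(decode_token(prev_id, new_id)))  # ' world' -- the leading space survives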
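
`FastLinear`, now inlined into both flash modeling files, stores the weight transposed after `transpose_weight()` so the forward pass is a single `torch.addmm`/`torch.matmul` rather than `F.linear`, which transposes internally. A short equivalence check (a sketch, not the class itself):

# FastLinear layout sketch: pre-transposed weight gives the same result.
import torch
from torch import nn

linear = nn.Linear(4, 3)
x = torch.randn(2, 4)
reference = linear(x)                       # F.linear computes x @ weight.T + bias

weight_t = linear.weight.data.T             # what transpose_weight() keeps around
fast = torch.addmm(linear.bias.data, x, weight_t)
print(torch.allclose(reference, fast, atol=1e-6))  # True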
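
`TensorParallelEmbedding` splits the vocabulary into contiguous blocks, one per rank; ids outside a rank's block are remapped to an extra all-zero row (`null_idx`), and the partial lookups are summed with `all_reduce` so exactly one rank contributes each embedding. The following single-process sketch simulates two shards with toy sizes and replaces `torch.distributed.all_reduce` with an explicit sum:

# Simulation of the null-idx masking scheme from the duplicated
# TensorParallelEmbedding in flash_llama_modeling.py / flash_neox_modeling.py.
import torch
from torch.nn import functional as F

vocab_size, dim, world_size = 8, 4, 2
full_weight = torch.randn(vocab_size, dim)
block_size = vocab_size // world_size

def shard_lookup(rank: int, input_ids: torch.Tensor) -> torch.Tensor:
    min_id, max_id = rank * block_size, (rank + 1) * block_size
    # This rank's slice of the table plus one trailing zero row used for masking.
    weight = F.pad(full_weight[min_id:max_id], (0, 0, 0, 1))
    null_idx = block_size
    # Out-of-range ids point at the zero row; in-range ids are shifted locally.
    local = torch.where(
        (input_ids < min_id) | (input_ids >= max_id),
        torch.tensor(null_idx),
        input_ids - min_id,
    )
    return F.embedding(local, weight)

ids = torch.tensor([0, 3, 5, 7])
summed = sum(shard_lookup(r, ids) for r in range(world_size))  # stand-in for all_reduce
print(torch.allclose(summed, F.embedding(ids, full_weight)))   # True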
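
`PositionRotaryEmbedding` caches cos/sin tables up to the longest sequence seen, rebuilding them when the sequence length, dtype or device changes, and `get_cos_sin` gathers per-position rows with `index_select`. The sketch below reproduces only the cache math under the usual rotary inverse-frequency definition; the real class inherits its setup from `flash_attn.layers.rotary.RotaryEmbedding` and applies the rotation with the `rotary_emb` CUDA kernel, which is not reproduced here.

# Sketch of the cos/sin cache used by PositionRotaryEmbedding.
import torch

class CosSinCache:
    def __init__(self, rotary_dim: int, base: float = 10000.0):
        # Standard rotary inverse frequencies (assumption, matching common usage).
        self.inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update(self, seqlen: int, dtype: torch.dtype, device: torch.device):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # torch.outer keeps fp32 precision; einsum could downcast to fp16.
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        self._update(max_s, dtype, position_ids.device)
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        # unsqueeze so cos/sin broadcast over the head dimension of qkv
        return cos.unsqueeze(1), sin.unsqueeze(1)

cache = CosSinCache(rotary_dim=64)
cos, sin = cache.get_cos_sin(torch.tensor([0, 1, 2, 5]), max_s=8, dtype=torch.float32)
print(cos.shape, sin.shape)  # torch.Size([4, 1, 32]) each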