diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml
index d3badcd8..f2a82935 100644
--- a/benchmark/Cargo.toml
+++ b/benchmark/Cargo.toml
@@ -27,7 +27,8 @@ serde = {version = "1.0.142", features = ["derive"]}
 serde_json = "1.0"
 text-generation-client = { path = "../router/client" }
 thiserror = "1.0.38"
-tokenizers = "0.13.2"
+#tokenizers = "0.13.2"
+tokenizers = { git = "https://github.com/huggingface/tokenizers.git" }
 tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
 tracing = "0.1.37"
diff --git a/proto/generate.proto b/proto/generate.proto
index 86393026..cc14cbf8 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -63,10 +63,12 @@ message Request {
     uint64 id = 1;
     /// The generation context
     string inputs = 2;
+    /// Context truncation
+    uint32 truncate = 3;
     /// Next Token Chooser Parameters
-    NextTokenChooserParameters parameters = 3;
+    NextTokenChooserParameters parameters = 4;
     /// Stopping Criteria Parameters
-    StoppingCriteriaParameters stopping_parameters = 4;
+    StoppingCriteriaParameters stopping_parameters = 5;
 }
 
 message Batch {
diff --git a/router/src/main.rs b/router/src/main.rs
index bad3df93..3ff72cde 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -94,11 +94,11 @@ fn main() -> Result<(), std::io::Error> {
         if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
         {
             // Load local tokenizer
-            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+            Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
         } else {
             // Download and instantiate tokenizer
             // We need to download it outside of the Tokio runtime
-            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
         };
 
     // Launch Tokio runtime
@@ -109,6 +109,13 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
+            if tokenizer.is_none() {
+                tracing::warn!(
+                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
+                );
+                tracing::warn!("Rust input length validation and truncation is disabled");
+            }
+
             // Get pipeline tag
             let model_info = reqwest::get(format!(
                 "https://huggingface.co/api/models/{tokenizer_name}"
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 2899ccd4..77f8461b 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -174,6 +174,7 @@ impl State {
                 batch_requests.push(Request {
                     id,
                     inputs: entry.request.inputs.clone(),
+                    truncate: entry.request.truncate,
                     parameters: Some(entry.request.parameters.clone()),
                     stopping_parameters: Some(entry.request.stopping_parameters.clone()),
                 });
@@ -226,6 +227,7 @@ mod tests {
         Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
+                truncate: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,
diff --git a/router/src/server.rs b/router/src/server.rs
index f7850053..88c40565 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -455,7 +455,7 @@ pub async fn run(
     max_batch_size: usize,
     max_waiting_tokens: usize,
     client: ShardedClient,
-    tokenizer: Tokenizer,
+    tokenizer: Option<Tokenizer>,
     validation_workers: usize,
     addr: SocketAddr,
     allow_origin: Option<AllowOrigin>,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index ec67cefd..a0b8b98e 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -23,7 +23,7 @@ pub struct Validation {
 impl Validation {
     pub(crate) fn new(
         workers: usize,
-        tokenizer: Tokenizer,
+        tokenizer: Option<Tokenizer>,
         max_best_of: usize,
         max_stop_sequences: usize,
         max_input_length: usize,
@@ -85,7 +85,7 @@ impl Validation {
 /// Load balance the validation requests between multiple validation workers
 async fn validation_task(
     workers: usize,
-    tokenizer: Tokenizer,
+    tokenizer: Option<Tokenizer>,
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
@@ -95,7 +95,7 @@ async fn validation_task(
 
     // Create workers
     for _ in 0..workers {
-        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
+        let tokenizer_clone: Option<Tokenizer> = tokenizer.clone().into();
         // Create channel to communicate with worker
         let (worker_sender, worker_receiver) = mpsc::channel(workers);
         workers_senders.push(worker_sender);
@@ -127,7 +127,7 @@ async fn validation_task(
 /// Check the parameters inside the payload and get the number of tokens inside the input using
 /// the tokenizer
 fn validation_worker(
-    tokenizer: Tokenizer,
+    tokenizer: Option<Tokenizer>,
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
@@ -143,7 +143,7 @@ fn validation_worker(
                 .send(
                     validate(
                         request,
-                        &tokenizer,
+                        tokenizer.as_ref(),
                         max_stop_sequences,
                         max_input_length,
                         max_total_tokens,
@@ -162,7 +162,7 @@ fn validation_worker(
 
 fn validate(
     request: GenerateRequest,
-    tokenizer: &Tokenizer,
+    tokenizer: Option<&Tokenizer>,
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
@@ -272,35 +272,43 @@ fn validate(
         })
         .unwrap_or(Ok(None))?;
 
-    // Get the number of tokens in the input
-    let mut encoding = tokenizer
-        .encode(request.inputs.clone(), true)
-        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-
-    let (inputs, input_length) = if let Some(truncate) = truncate {
-        // truncate encoding and decode new inputs
-        encoding.truncate(truncate, 0, TruncationDirection::Left);
-        let inputs = tokenizer
-            .decode(Vec::from(encoding.get_ids()), false)
+    // If we have a fast tokenizer
+    let inputs = if let Some(tokenizer) = tokenizer {
+        // Get the number of tokens in the input
+        let mut encoding = tokenizer
+            .encode(request.inputs.clone(), true)
             .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-        (inputs, encoding.len())
+
+        let (inputs, input_length) = if let Some(truncate) = truncate {
+            // truncate encoding and decode new inputs
+            encoding.truncate(truncate, 0, TruncationDirection::Left);
+            let inputs = tokenizer
+                .decode(Vec::from(encoding.get_ids()), false)
+                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+            (inputs, encoding.len())
+        } else {
+            (request.inputs, encoding.len())
+        };
+
+        if input_length > max_input_length {
+            return Err(ValidationError::InputLength(max_input_length, input_length));
+        }
+
+        let total_tokens = input_length + max_new_tokens as usize;
+        if total_tokens > max_total_tokens {
+            return Err(ValidationError::MaxTotalTokens(
+                max_total_tokens,
+                input_length,
+                max_new_tokens,
+            ));
+        }
+
+        metrics::histogram!("tgi_request_input_length", input_length as f64);
+        inputs
     } else {
-        (request.inputs, encoding.len())
+        request.inputs
     };
 
-    if input_length > max_input_length {
-        return Err(ValidationError::InputLength(max_input_length, input_length));
-    }
-
-    let total_tokens = input_length + max_new_tokens as usize;
-    if total_tokens > max_total_tokens {
-        return Err(ValidationError::MaxTotalTokens(
-            max_total_tokens,
-            input_length,
-            max_new_tokens,
-        ));
-    }
-
     // Return ValidGenerateRequest
     let parameters = NextTokenChooserParameters {
         temperature,
@@ -318,11 +326,11 @@ fn validate(
         ignore_eos_token: false,
     };
 
-    metrics::histogram!("tgi_request_input_length", input_length as f64);
     metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
 
     Ok(ValidGenerateRequest {
         inputs,
+        truncate: truncate.unwrap_or(max_input_length) as u32,
         parameters,
         stopping_parameters,
     })
@@ -337,6 +345,7 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
+    pub truncate: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index c2ad0587..7cd49239 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
         stopping_criterias = []
 
         # Parse batch
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -74,6 +75,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -83,6 +85,8 @@ class CausalLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
@@ -388,6 +392,7 @@ class CausalLM(Model):
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
             next_token_text = self.decode_token(
+                all_input_ids[-2, 0],
                 next_token_id_squeezed,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index ee977948..6e2fd4b7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -21,21 +21,18 @@
 import torch
 import torch.distributed
 
+from torch.nn import functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
 
-from text_generation_server.models.custom_modeling.tensor_parallel import (
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
-)
-from text_generation_server.models.custom_modeling.linear import FastLinear
-from text_generation_server.models.custom_modeling.rotary import PositionRotaryEmbedding
-
 # Flash attention imports
+import rotary_emb
 import flash_attn_cuda
 import dropout_layer_norm
 
+from flash_attn.layers.rotary import RotaryEmbedding
+
 
 class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -87,6 +84,184 @@ class LlamaRMSNorm(nn.Module):
             return normed_hidden_states, res
 
 
+class FastLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+
+    def transpose_weight(self):
+        self.weight = nn.Parameter(self.weight.T)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.bias is not None:
+            return torch.addmm(self.bias, input, self.weight)
+        return torch.matmul(input, self.weight)
+
+
+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
+class PositionRotaryEmbedding(RotaryEmbedding):
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+        ):
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            # Don't do einsum, it converts fp32 to fp16
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+
+    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
+        """
+        Return cos and sin for the asked position ids
+        """
+
+        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+
+        cos = torch.index_select(self._cos_cached, 0, position_ids)
+        sin = torch.index_select(self._sin_cached, 0, position_ids)
+        return cos.unsqueeze(1), sin.unsqueeze(1)
+
+    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        rotary_dim = cos.shape[-1]
+        q1 = qkv[:, 0, :, :rotary_dim]
+        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
+        k1 = qkv[:, 1, :, :rotary_dim]
+        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+
+        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        return qkv
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 93f1b0ca..f3517c47 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -21,23 +21,20 @@
 import torch
 import torch.distributed
 
+from torch.nn import functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 
-from text_generation_server.models.custom_modeling.tensor_parallel import (
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
-)
-from text_generation_server.models.custom_modeling.linear import FastLinear
-from text_generation_server.models.custom_modeling.rotary import PositionRotaryEmbedding
-
 # Flash attention imports
+import rotary_emb
 import flash_attn_cuda
 import dropout_layer_norm
 
+from flash_attn.layers.rotary import RotaryEmbedding
+
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
@@ -75,6 +72,184 @@ class FastLayerNorm(nn.LayerNorm):
             return normed_hidden_states, residual
 
 
+class FastLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+
+    def transpose_weight(self):
+        self.weight = nn.Parameter(self.weight.T)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.bias is not None:
+            return torch.addmm(self.bias, input, self.weight)
+        return torch.matmul(input, self.weight)
+
+
+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
+class PositionRotaryEmbedding(RotaryEmbedding):
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+        ):
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            # Don't do einsum, it converts fp32 to fp16
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+
+    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
+        """
+        Return cos and sin for the asked position ids
+        """
+
+        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+
+        cos = torch.index_select(self._cos_cached, 0, position_ids)
+        sin = torch.index_select(self._sin_cached, 0, position_ids)
+        return cos.unsqueeze(1), sin.unsqueeze(1)
+
+    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        rotary_dim = cos.shape[-1]
+        q1 = qkv[:, 0, :, :rotary_dim]
+        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
+        k1 = qkv[:, 1, :, :rotary_dim]
+        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+
+        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        return qkv
+
+
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
         self,
@@ -201,12 +376,7 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(
-                x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
-            )
+            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
         )
 
         if process_group is None:
diff --git a/server/text_generation_server/models/custom_modeling/linear.py b/server/text_generation_server/models/custom_modeling/linear.py
deleted file mode 100644
index 1ca8e3a9..00000000
--- a/server/text_generation_server/models/custom_modeling/linear.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import torch
-from torch import nn
-
-
-class FastLinear(nn.Linear):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-
-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
diff --git a/server/text_generation_server/models/custom_modeling/rotary.py b/server/text_generation_server/models/custom_modeling/rotary.py
deleted file mode 100644
index 69f97558..00000000
--- a/server/text_generation_server/models/custom_modeling/rotary.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-import rotary_emb
-
-from flash_attn.layers.rotary import RotaryEmbedding
-
-
-class PositionRotaryEmbedding(RotaryEmbedding):
-    def _update_cos_sin_cache(self, dtype, device, seqlen):
-        # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-        ):
-            self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
-
-    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
-        """
-        Return cos and sin for the asked position ids
-        """
-
-        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
-        cos = torch.index_select(self._cos_cached, 0, position_ids)
-        sin = torch.index_select(self._sin_cached, 0, position_ids)
-        return cos.unsqueeze(1), sin.unsqueeze(1)
-
-    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-        rotary_dim = cos.shape[-1]
-        q1 = qkv[:, 0, :, :rotary_dim]
-        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
-        k1 = qkv[:, 1, :, :rotary_dim]
-        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
-
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        return qkv
diff --git a/server/text_generation_server/models/custom_modeling/tensor_parallel.py b/server/text_generation_server/models/custom_modeling/tensor_parallel.py
deleted file mode 100644
index e01b8835..00000000
--- a/server/text_generation_server/models/custom_modeling/tensor_parallel.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import torch
-import torch.distributed
-from torch import nn
-from torch.nn import functional as F
-
-from text_generation_server.models.custom_modeling.linear import FastLinear
-
-
-class TensorParallelColumnLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        assert out_features % self.tp_world_size == 0
-        out_features = out_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-
-class TensorParallelRowLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        self.reduce = reduce
-        assert in_features % self.tp_world_size == 0
-        in_features = in_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        out = super(TensorParallelRowLinear, self).forward(input)
-        if self.reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
-class TensorParallelEmbedding(nn.Embedding):
-    def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-
-        self.original_num_embeddings = num_embeddings
-
-        assert num_embeddings % self.tp_world_size == 0
-        block_size = num_embeddings // self.tp_world_size
-        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
-        self.min_id = self.tp_rank * block_size
-        self.max_id = (self.tp_rank + 1) * block_size
-
-        # Additional entry that will map to zero
-        # Used for masking
-        self.null_idx = block_size
-
-        super().__init__(
-            block_size,
-            embedding_dim,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-            _weight=_weight,
-            device=device,
-            dtype=dtype,
-        )
-
-    def add_null_idx(self):
-        """Additional 0 entry used for masking"""
-        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
-        # translate for [0, self.max_id - self.min_id[
-        input = torch.where(
-            (self.min_id > input) | (input >= self.max_id),
-            self.null_idx,
-            input - self.min_id,
-        )
-        out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
-        return out
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e1a10cbf..2aeac7b5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -78,7 +78,9 @@ class FlashCausalLMBatch(Batch):
 
         # Parse batch
         for r in pb.requests:
-            tokenized_input = tokenizer(r.inputs)["input_ids"]
+            tokenized_input = tokenizer(
+                r.inputs, truncation=True, max_length=r.truncate
+            )["input_ids"]
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
@@ -333,6 +335,7 @@ class FlashCausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
             next_token_text = self.decode_token(
+                all_input_ids[-2],
                 next_token_id_item,
             )
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 7d7e2cf5..3029ab89 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -11,8 +11,6 @@ from typing import Optional, Tuple, List
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
-)
-from text_generation_server.models.custom_modeling.tensor_parallel import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 655b664e..e415a725 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -8,14 +8,12 @@ from transformers import AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, List
 
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.tensor_parallel import (
+from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+    FlashGPTNeoXForCausalLM,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
 )
-from text_generation_server.models.custom_modeling.flash_neox_modeling import (
-    FlashGPTNeoXForCausalLM,
-)
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a90a299e..f997ab1a 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -96,6 +96,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []
 
         # Parse batch
+        max_truncation = 0
         max_sequence_length = 0
         padding_right_offset = 0
         for r in pb.requests:
@@ -107,6 +108,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             max_sequence_length = max(max_sequence_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
@@ -118,6 +120,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index e0ce6686..9e519779 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -15,15 +15,6 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
 
-        # see `decode_token` method
-        self.tokenizer.add_special_tokens(
-            {"additional_special_tokens": ["<decode-token>"]}
-        )
-        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
-            "<decode-token>"
-        )
-        self.special_decode_token_length = len("<decode-token>")
-
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
@@ -33,11 +24,12 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
 
-    def decode_token(self, token_id: int) -> str:
+    def decode_token(self, previous_token_id: int, token_id: int) -> str:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        # append token to special decode token and decode both
-        result = self.tokenizer.decode(
-            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        # Decode previous token and previous token + token
+        results = self.tokenizer.batch_decode(
+            [[previous_token_id], [previous_token_id, token_id]],
+            skip_special_tokens=False,
         )
-        # slice to remove special decode token
-        return result[self.special_decode_token_length :]
+        # slice to remove previous token
+        return results[1][len(results[0]) :]
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 0fe5c03f..72f694c3 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -73,6 +73,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_lengths = []
 
         # Parse batch
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -84,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -94,6 +96,8 @@ class Seq2SeqLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
@@ -463,6 +467,7 @@ class Seq2SeqLM(Model):
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
             next_token_text = self.decode_token(
+                decoder_input_ids[-2],
                 next_token_id_squeezed,
             )