diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index aacc9634..0b7e0d5f 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -105,7 +105,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
         let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
             waiting_tokens += 1;
 
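The switch from `if let` to `while let` means `batching_task` keeps pulling batches out of the waiting queue until `next_batch` returns `None`, instead of serving only the first batch after each wakeup. A minimal sketch of that draining pattern, using a made-up queue type rather than the router's actual `Db`:

```rust
use std::collections::VecDeque;

struct Queue {
    pending: VecDeque<u64>,
}

impl Queue {
    /// Returns up to `max_batch_size` request ids, or `None` once the queue is empty.
    fn next_batch(&mut self, max_batch_size: usize) -> Option<Vec<u64>> {
        if self.pending.is_empty() {
            return None;
        }
        let n = max_batch_size.min(self.pending.len());
        Some(self.pending.drain(..n).collect())
    }
}

fn main() {
    let mut queue = Queue {
        pending: (0..5).collect(),
    };
    // `while let` loops as long as a batch is available, so all five
    // requests are drained here in batches of at most two.
    while let Some(batch) = queue.next_batch(2) {
        println!("processing batch {:?}", batch);
    }
}
```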
diff --git a/router/src/server.rs b/router/src/server.rs
index 72b720ef..9f4a75c9 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());
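`req` here is presumably axum's `Json<GenerateRequest>` extractor, which is a tuple struct, so `req.0` moves the already-deserialized request out of the wrapper and hands ownership to `validate`; the field-by-field clone the FIXME complained about is no longer needed. A small sketch of that ownership transfer with stand-in types (not the router's actual API):

```rust
// Newtype wrapper around a deserialized payload, similar in shape to
// axum's `Json<T>` extractor.
struct Json<T>(T);

#[derive(Debug)]
struct GenerateRequest {
    inputs: String,
}

fn validate(req: GenerateRequest) -> Result<GenerateRequest, String> {
    if req.inputs.is_empty() {
        return Err("empty input".to_string());
    }
    Ok(req)
}

fn main() {
    let req = Json(GenerateRequest {
        inputs: "Hello".to_string(),
    });
    // `.0` moves the inner value out of the wrapper: ownership is
    // transferred to `validate`, so no clone of the fields is required.
    let validated = validate(req.0);
    println!("{:?}", validated);
}
```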
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index ca8ea575..2c55508b 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -381,7 +381,7 @@ class CausalLM(Model):
         next_batch_attention_mask = torch.cat(
             [
                 next_batch_attention_mask,
-                torch.ones((next_batch_size, 1)).to(self.device),
+                next_batch_attention_mask.new_ones(next_batch_size, 1),
             ],
             dim=1,
         )
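`Tensor.new_ones` allocates the extra attention-mask column with the same device (and dtype) as the existing mask, whereas `torch.ones(...)` first builds a default-dtype tensor on the CPU and then copies it over with `.to(self.device)`. A standalone comparison; the shapes, dtype, and device below are illustrative, not taken from the model code:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the batch's existing attention mask.
attention_mask = torch.ones((4, 7), dtype=torch.int64, device=device)

# Old pattern: tensor created on the CPU, then transferred to the device.
column_old = torch.ones((4, 1)).to(device)

# New pattern: device and dtype are inherited from the source tensor,
# so there is no host allocation and no host-to-device copy.
column_new = attention_mask.new_ones((4, 1))

print(column_old.device, column_old.dtype)  # device, torch.float32
print(column_new.device, column_new.dtype)  # device, torch.int64

# Concatenating along dim=1 extends the mask by one position.
extended = torch.cat([attention_mask, column_new], dim=1)
print(extended.shape)  # torch.Size([4, 8])
```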
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index cb1291ab..f63a8849 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
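Passing `device=` to `torch.tensor` builds the decoder input ids directly on the target device, instead of materializing them on the CPU and then copying with `.to(device)`. A minimal illustration; the ids are made up:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
decoder_input_ids = [0, 0, 0]  # e.g. one decoder start token id per sequence

# Old pattern: CPU tensor first, then a separate copy to the device.
ids_old = torch.tensor(decoder_input_ids).to(device)

# New pattern: allocated on the target device in one step.
ids_new = torch.tensor(decoder_input_ids, device=device)

# unsqueeze(-1) turns the [batch_size] vector into the
# [batch_size, 1] shape expected for decoder_input_ids.
print(ids_new.unsqueeze(-1).shape)  # torch.Size([3, 1])
```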
@@ -499,7 +499,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
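Same replacement as in causal_lm.py, here for the decoder attention mask. Besides skipping the extra device copy, `new_ones` keeps the new column in the mask's own dtype, while `torch.ones(...)` would contribute a default float32 column to the concatenation. A short check with an illustrative dtype:

```python
import torch

decoder_attention_mask = torch.ones((2, 5), dtype=torch.int64)

extended = torch.cat(
    [decoder_attention_mask, decoder_attention_mask.new_ones(2, 1)],
    dim=1,
)
# The extended mask keeps the original dtype.
print(extended.dtype)  # torch.int64
```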