From a172430d8b3bbacc6a0772688a7ed77c900cd292 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 29 Nov 2022 13:22:25 -0800
Subject: [PATCH] fix: Some small fixes

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in router generate method
- Keep attention mask tensors as integers
---
 router/src/batcher.rs                       | 2 +-
 router/src/server.rs                        | 6 +-----
 server/text_generation/models/causal_lm.py  | 2 +-
 server/text_generation/models/seq2seq_lm.py | 4 ++--
 4 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index aacc9634..0b7e0d5f 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -105,7 +105,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
         let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
             waiting_tokens += 1;
 
diff --git a/router/src/server.rs b/router/src/server.rs
index 72b720ef..9f4a75c9 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index ca8ea575..2c55508b 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -381,7 +381,7 @@ class CausalLM(Model):
             next_batch_attention_mask = torch.cat(
                 [
                     next_batch_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index cb1291ab..f63a8849 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
@@ -499,7 +499,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
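
Notes on the changes:

The batcher fix: with `if let`, the task services at most one batch per
wake-up, so a batch queued in the DB while the previous one was being
processed could sit there until the next notification arrives, which in
theory may be never. `while let` keeps pulling batches until `next_batch`
returns `None` before the task goes back to waiting. A minimal Python
analogy of the control flow (a toy queue, not the router's actual Db type):

    from collections import deque

    queue = deque(["batch-1", "batch-2", "batch-3"])

    def next_batch(q):
        """Return the next queued batch, or None when the queue is empty."""
        return q.popleft() if q else None

    # An `if` here would handle only "batch-1" per wake-up; the `while`
    # drains the queue before the task sleeps again, so nothing queued
    # behind the first batch is stranded.
    while (batch := next_batch(queue)) is not None:
        print("processing", batch)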
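The router change relies on axum's `Json<T>` extractor being a tuple struct:
`req.0` moves the deserialized `GenerateRequest` out of the wrapper by value,
so `validate` can take ownership and the two field clones (and the FIXME)
are no longer needed.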
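The `new_ones` change is what keeps the attention masks integer: HF
tokenizers emit int64 masks, `torch.ones(...)` defaults to float32, and
`torch.cat`'s type promotion then silently upcasts the whole mask to float.
A minimal standalone sketch of the behavior (plain PyTorch, not the repo
code):

    import torch

    mask = torch.ones((2, 3), dtype=torch.int64)  # tokenizer-style integer mask

    # Old pattern: a float32 column, so torch.cat promotes the result to float32.
    promoted = torch.cat([mask, torch.ones((2, 1))], dim=1)
    print(promoted.dtype)  # torch.float32

    # New pattern: new_ones inherits dtype and device from the source tensor,
    # so the mask stays int64 and the extra .to(device) copy disappears.
    kept = torch.cat([mask, mask.new_ones(2, 1)], dim=1)
    print(kept.dtype)  # torch.int64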
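The `decoder_input_ids` change is the same idea applied to device placement:
passing `device=` to `torch.tensor` constructs the tensor on the target
device directly, skipping the user-visible intermediate CPU tensor and the
separate `.to(device)` copy. A minimal sketch, with illustrative token ids
standing in for the real `decoder_input_ids` list:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ids = [0, 0, 0]  # one start-token id per sequence (illustrative values)

    # Old: intermediate CPU tensor, then a separate device copy.
    old = torch.tensor(ids).to(device).unsqueeze(-1)

    # New: constructed directly on the target device.
    new = torch.tensor(ids, device=device).unsqueeze(-1)

    assert old.shape == new.shape == (3, 1)  # [batch_size, 1]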