diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index aacc9634e..a9b892cc7 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -104,10 +104,9 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            waiting_tokens += 1;
+            let mut waiting_tokens = 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -131,11 +130,11 @@ async fn batching_task(
                     if let Some((new_request_ids, new_batch)) =
                         db.next_batch(min_size, max_batch_size)
                     {
-                        // Reset waiting counter
-                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Reset waiting counter
+                        waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
diff --git a/router/src/server.rs b/router/src/server.rs
index 72b720efa..9f4a75c9e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());
diff --git a/router/src/validation.rs b/router/src/validation.rs
index a105ddf3e..0ddfe544c 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -155,7 +155,7 @@ type ValidationRequest = (
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
-    #[error("top_p must be >= 0.0 or < 1.0")]
+    #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 008288f83..2a7405d39 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )
 
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index ca8ea5757..4e66ae3a3 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -251,7 +251,6 @@ class CausalLM(Model):
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
 
@@ -358,7 +357,7 @@ class CausalLM(Model):
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
@@ -381,7 +380,7 @@ class CausalLM(Model):
         next_batch_attention_mask = torch.cat(
             [
                 next_batch_attention_mask,
-                torch.ones((next_batch_size, 1)).to(self.device),
+                next_batch_attention_mask.new_ones(next_batch_size, 1),
             ],
             dim=1,
         )
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index abc3c36c9..5de75ab45 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )
 
diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py
index 7fb8142c5..0331e1938 100644
--- a/server/text_generation/models/model.py
+++ b/server/text_generation/models/model.py
@@ -10,9 +10,8 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+    def __init__(self, tokenizer: Tokenizer, device: torch.device):
         self.tokenizer = tokenizer
-        self.num_heads = num_heads
         self.device = device
 
     @property
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index cb1291abb..e9c65596f 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
@@ -319,7 +319,6 @@ class Seq2SeqLM(Model):
 
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
 
@@ -499,7 +498,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )