diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index b8a9182c..dc3bcdde 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -123,9 +123,15 @@ impl ShardedClient {
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index ac916d94..6d4e207b 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -119,15 +119,19 @@ impl ShardedClient {
                 ))
             })
             .collect();
-        // Take the minimum value
         let results = join_all(futures)
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
 
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index dfc61fb8..3bba1cf2 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from typing import Optional
+from typing import Optional, Union
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -475,7 +475,9 @@ class Mamba(Model):
     def batch_type(self) -> Type[MambaBatch]:
         return MambaBatch
 
-    def warmup(self, batch) -> Optional[int]:
+    def warmup(
+        self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Union[Optional[int], Optional[int], Optional[int]]:
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
@@ -489,7 +491,12 @@ class Mamba(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        return None
+        if max_total_tokens is None:
+            max_total_tokens = min(self.tokenizer.model_max_length, 4096)
+
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def cuda_graph_warmup(self, batch_size: int):
         input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)