diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index bc264138d..cfe87f98f 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -104,6 +104,10 @@ impl Backend for BackendV2 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 7ae794a09..736301b33 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -111,6 +111,10 @@ impl Backend for BackendV3 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 6d4e207bb..4701c5600 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -217,8 +217,8 @@ impl Health for ShardedClient {
             input_chunks: Some(Input {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
-            truncate: 10,
-            add_special_tokens: true,
+            truncate: 1,
+            add_special_tokens: false,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -241,7 +241,7 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: vec![0],
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 866066438..6497d8578 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -33,6 +33,13 @@ pub trait Backend {
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
 
     async fn health(&self, current_health: bool) -> bool;
+
+    /// The state of the health on startup
+    /// Typically false, or true if the backend includes
+    /// a warmup phase.
+    fn start_health(&self) -> bool {
+        false
+    }
 }
 
 /// Inference struct
@@ -75,7 +82,7 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
         // Backend health
-        let backend_health = Arc::new(AtomicBool::new(false));
+        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
 
         Self {
             validation,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 2dcd1bf30..fbf1a597c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -78,6 +78,7 @@ class RWConfig(PretrainedConfig):
         self.alibi = False
         self.rotary = True
         self.rope_theta = rope_theta
+        self.max_position_embeddings = 2048
 
         self.vocab_size = vocab_size
         # Backward compatibility with n_embed kwarg
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 07b7604d6..5d3769907 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
 
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
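
For illustration only, not part of the patch: a minimal, self-contained sketch of how a backend could override the new `start_health` default so the router considers it healthy right after its warmup phase, mirroring how `Infer::new` now seeds the shared health flag. The `WarmedUpBackend` type is hypothetical, and the trait below is a simplified stand-in for the real `Backend` trait in `router/src/infer/mod.rs` (it drops `schedule` and the `async` on `health` to stay self-contained).

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Simplified stand-in for the `Backend` trait touched in this patch:
// only the health-related methods, without async or `schedule`.
trait Backend {
    fn health(&self, current_health: bool) -> bool;

    // Default startup state: unhealthy until the first successful check.
    fn start_health(&self) -> bool {
        false
    }
}

// Hypothetical backend that warms up before serving, so it can report
// healthy from the start, like BackendV2/BackendV3 in this patch.
struct WarmedUpBackend;

impl Backend for WarmedUpBackend {
    fn health(&self, _current_health: bool) -> bool {
        true
    }

    fn start_health(&self) -> bool {
        true
    }
}

fn main() {
    let backend = WarmedUpBackend;
    // Mirrors the change in `Infer::new`: the shared health flag is seeded
    // from the backend's startup state instead of a hard-coded `false`.
    let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
    assert!(backend_health.load(Ordering::SeqCst));
}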