diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 07b7604d..5d376990 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
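
For context, a minimal standalone sketch of the capping logic this diff introduces: `max_total_tokens` is now bounded by the model's positional-embedding limit in addition to the KV-cache capacity and the tokenizer limit. The names mirror the diff; all numeric values below are illustrative, not taken from the PR.

```python
# Illustrative values only (hypothetical, not from the PR).
BLOCK_SIZE = 16                   # KV-cache block size
num_blocks = 8192                 # blocks that fit in free GPU memory
model_max_length = 131072         # tokenizer.model_max_length
max_position_embeddings = 32768   # config.max_position_embeddings

# Before the change: only KV-cache capacity and the tokenizer limit were considered.
old_max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)

# After the change: the positional-embedding limit also caps the value, so warmup
# no longer advertises a context longer than the model can actually attend to.
new_max_total_tokens = min(
    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
)

print(old_max_total_tokens, new_max_total_tokens)  # 131072 32768
```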