diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2b2dd940..6eb914f8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1398,22 +1398,32 @@ class FlashCausalLM(Model):
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

         if max_total_tokens is None:
-            model_max_length = self.tokenizer.model_max_length
-            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            spare_blocks = (
-                # Leave 5% for some wiggle room
-                int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
-                + batch.num_blocks
-            )
-            spare_blocks = small_power_of_2(spare_blocks)
+            if get_support_chunking():
+                model_max_length = self.tokenizer.model_max_length
+                free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+                spare_blocks = (
+                    # Leave 5% for some wiggle room
+                    int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+                    + batch.num_blocks
+                )
+                spare_blocks = small_power_of_2(spare_blocks)

-            available_blocks = min(model_max_length, spare_blocks)
-            batch.num_blocks = available_blocks
-            batch.max_blocks = available_blocks
-            max_input_tokens = (
-                available_blocks - 1 if max_input_tokens is None else max_input_tokens
-            )
-            max_total_tokens = available_blocks
+                available_blocks = min(model_max_length, spare_blocks)
+                batch.num_blocks = available_blocks
+                batch.max_blocks = available_blocks
+                max_input_tokens = (
+                    available_blocks - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
+                max_total_tokens = available_blocks
+            else:
+                max_total_tokens = batch.num_blocks
+                max_input_tokens = (
+                    batch.num_blocks - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )

         try:
             self.init_kv_cache(
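
A minimal standalone sketch of the branching this hunk introduces, for readers who want the logic outside the diff context. The helper name resolve_token_limits is hypothetical, and the sketch omits the small_power_of_2 rounding and the batch.num_blocks / batch.max_blocks bookkeeping that the real code also performs; it only shows how max_total_tokens and max_input_tokens are chosen depending on chunking support.

# Hypothetical, simplified sketch of the new branch; not part of the diff.
def resolve_token_limits(
    num_warmup_blocks: int,       # corresponds to batch.num_blocks measured at warmup
    model_max_length: int,        # corresponds to tokenizer.model_max_length
    free_memory: int,             # bytes reported free on the device
    total_cache_size: int,        # bytes of KV cache per block
    wiggle_room: float = 0.95,    # leave ~5% of free memory unused
    support_chunking: bool = True,
    max_input_tokens: int | None = None,
):
    """Return (max_input_tokens, max_total_tokens) when max_total_tokens is unset."""
    if support_chunking:
        # Size the token budget from free memory, capped by the model's context length.
        spare_blocks = int((free_memory * wiggle_room) // total_cache_size) + num_warmup_blocks
        max_total_tokens = min(model_max_length, spare_blocks)
    else:
        # Without chunking support, fall back to the block count measured during warmup.
        max_total_tokens = num_warmup_blocks

    if max_input_tokens is None:
        # Keep at least one slot free for the first generated token.
        max_input_tokens = max_total_tokens - 1
    return max_input_tokens, max_total_tokens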