diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f1879c5c..45e4313a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -718,7 +718,7 @@ class FlashCausalLM(Model):
 
         torch.cuda.empty_cache()
         try:
-            if max_total_tokens is None:
+            if not max_total_tokens:
                 num_blocks = batch.blocks
             else:
                 num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE)
@@ -739,8 +739,8 @@ class FlashCausalLM(Model):
 
         torch.cuda.synchronize(self.device)
 
-        if max_total_tokens is not None:
-            return num_blocks
+        if max_total_tokens:
+            return int(num_blocks * BLOCK_SIZE)
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index f696ac41..8d25d719 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -58,7 +58,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )
-        logger.info(f"Warmup {(batch, request.max_total_tokens)}")
         max_supported_total_tokens = self.model.warmup(batch, request.max_total_tokens)
 
         return generate_pb2.WarmupResponse(
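
For reference, a minimal sketch of the arithmetic behind the changed return statement, assuming `BLOCK_SIZE` is the KV-cache block size from `cache_manager.py` (16 at the time of writing); `supported_tokens_for` is a hypothetical helper written only to illustrate the rounding, not a function in the codebase:

```python
import math

BLOCK_SIZE = 16  # assumption: TGI's KV-cache block size (see cache_manager.py)


def supported_tokens_for(max_total_tokens: int) -> int:
    """Hypothetical helper mirroring the warmup early-return path."""
    # Round the requested budget up to a whole number of KV-cache blocks...
    num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE)
    # ...and report the resulting capacity in tokens, not blocks.
    return int(num_blocks * BLOCK_SIZE)


# 1000 requested tokens -> 63 blocks -> 1008 tokens of reported capacity
assert supported_tokens_for(1000) == 1008
```

Returning `int(num_blocks * BLOCK_SIZE)` instead of `num_blocks` keeps the early return consistent with the free-memory path below it, which already reports `max_supported_total_tokens` as a token count rather than a block count.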