From ccc7b7ab8f4fa930568582efe28b4d9bafbd88a1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Sat, 22 Jul 2023 17:12:37 -0700 Subject: [PATCH] Cleanup Signed-off-by: Antoni Baum --- server/text_generation_server/models/flash_causal_lm.py | 6 +++--- server/text_generation_server/server.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f1879c5c..45e4313a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -718,7 +718,7 @@ class FlashCausalLM(Model): torch.cuda.empty_cache() try: - if max_total_tokens is None: + if not max_total_tokens: num_blocks = batch.blocks else: num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE) @@ -739,8 +739,8 @@ class FlashCausalLM(Model): torch.cuda.synchronize(self.device) - if max_total_tokens is not None: - return num_blocks + if max_total_tokens: + return int(num_blocks * BLOCK_SIZE) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index f696ac41..8d25d719 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -58,7 +58,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - logger.info(f"Warmup {(batch, request.max_total_tokens)}") max_supported_total_tokens = self.model.warmup(batch, request.max_total_tokens) return generate_pb2.WarmupResponse(