Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Antoni Baum 2023-07-22 17:12:37 -07:00
parent d583f962f8
commit ccc7b7ab8f
2 changed files with 3 additions and 4 deletions


@@ -718,7 +718,7 @@ class FlashCausalLM(Model):
         torch.cuda.empty_cache()
         try:
-            if max_total_tokens is None:
+            if not max_total_tokens:
                 num_blocks = batch.blocks
             else:
                 num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE)
@@ -739,8 +739,8 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize(self.device)
-        if max_total_tokens is not None:
-            return num_blocks
+        if max_total_tokens:
+            return int(num_blocks * BLOCK_SIZE)
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
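Taken together, the two hunks above change FlashCausalLM.warmup so that a caller-supplied max_total_tokens is rounded up to whole KV-cache blocks and reported back as a token count (int(num_blocks * BLOCK_SIZE)) rather than as a raw block count; the check is also relaxed from "is None" to truthiness, so a value of 0 falls through to the free-memory estimate. A minimal sketch of the arithmetic, assuming BLOCK_SIZE = 16 purely for illustration (the real constant is defined elsewhere in the file and is not part of this diff):

import math

BLOCK_SIZE = 16  # assumed block size, for illustration only

def supported_tokens(max_total_tokens: int) -> int:
    # Mirrors the changed lines: round the requested budget up to whole blocks,
    # then report the capacity of those blocks in tokens.
    num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE)
    return int(num_blocks * BLOCK_SIZE)

# ceil(1000 / 16) = 63 blocks -> 63 * 16 = 1008 supported tokens
assert supported_tokens(1000) == 1008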


@@ -58,7 +58,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        logger.info(f"Warmup {(batch, request.max_total_tokens)}")
         max_supported_total_tokens = self.model.warmup(batch, request.max_total_tokens)
         return generate_pb2.WarmupResponse(