mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Cleanup
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent
d583f962f8
commit
ccc7b7ab8f
@ -718,7 +718,7 @@ class FlashCausalLM(Model):
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
if max_total_tokens is None:
|
||||
if not max_total_tokens:
|
||||
num_blocks = batch.blocks
|
||||
else:
|
||||
num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE)
|
||||
@ -739,8 +739,8 @@ class FlashCausalLM(Model):
|
||||
|
||||
torch.cuda.synchronize(self.device)
|
||||
|
||||
if max_total_tokens is not None:
|
||||
return num_blocks
|
||||
if max_total_tokens:
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the free memory
|
||||
|
@ -58,7 +58,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
logger.info(f"Warmup {(batch, request.max_total_tokens)}")
|
||||
max_supported_total_tokens = self.model.warmup(batch, request.max_total_tokens)
|
||||
|
||||
return generate_pb2.WarmupResponse(
|
||||
|
Loading…
Reference in New Issue
Block a user