Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00.
Flash causal LM case

Commit 564c9e1cc0 (parent 2ba396c4c1)
```
@@ -1702,6 +1702,7 @@ class FlashCausalLM(Model):
                f"{dim}:{seq_len} "
                f"bypass:{bypass} "
                f"free_mem:{free_mem}"
                ", this may take a while..."
            )
            log_master(logger.info, msg)
```
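For context, log_master is TGI's helper for emitting a log line only once in a multi-process launch rather than once per worker. A minimal sketch of that pattern, assuming rank discovery via a RANK environment variable (the real helper may obtain the rank differently):

```python
import os
from loguru import logger

def log_master(log, msg: str):
    # Emit the message only from rank 0 so a multi-process launch
    # does not repeat the same warmup log once per worker.
    # Reading RANK from the environment is an assumption of this sketch.
    if int(os.getenv("RANK", "0")) == 0:
        log(msg)

# Assemble a warmup progress message like the one in the hunk,
# with invented values.
dim, seq_len, bypass, free_mem = "seq", 2048, False, "12.5GB"
msg = (
    f"warmup {dim}:{seq_len} "
    f"bypass:{bypass} "
    f"free_mem:{free_mem}"
    ", this may take a while..."
)
log_master(logger.info, msg)
```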
```
@@ -1753,6 +1754,10 @@ class FlashCausalLM(Model):
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = prompt_available_memory
        logger.info(
            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
        )
        for i, (batch_size, seq_len) in enumerate(buckets):
            if batch_size * seq_len > self.max_batch_prefill_tokens:
                continue
```
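The loop this hunk touches walks prefill warmup buckets of (batch_size, seq_len) pairs and skips any shape whose token count exceeds the prefill budget, since such a batch could never be scheduled. A standalone sketch of that filter, with invented bucket values and budget:

```python
from itertools import product

# Hypothetical warmup buckets: every (batch_size, seq_len) pair.
batch_sizes = [1, 2, 4, 8]
seq_lens = [128, 512, 1024, 2048]
max_batch_prefill_tokens = 4096  # illustrative budget

buckets = list(product(batch_sizes, seq_lens))
for batch_size, seq_len in buckets:
    if batch_size * seq_len > max_batch_prefill_tokens:
        # This shape can never occur at runtime, so warming it up
        # would only waste time and device memory.
        continue
    print(f"would warm up prefill shape bs={batch_size} seq={seq_len}")
```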
```
@@ -1779,6 +1784,8 @@ class FlashCausalLM(Model):
            total_mem += used_mem
            total_batch_seq += batch_seq

        logger.info("Prefill warmup successful.\n")

        def ordering_function_max_bs(b):
            return (-b[0], b[1])
```
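The accumulation lines keep a running total of measured memory per unit of batch_size * seq_len, and total_batch_seq is seeded with 0.001 rather than 0 so a ratio taken before the first measurement cannot divide by zero. A sketch of how such a running ratio can extrapolate whether the next shape still fits in available_mem; measure_memory and all numbers here are invented stand-ins, not the backend's real accounting:

```python
def measure_memory(batch_size: int, seq_len: int) -> float:
    # Stand-in for a real device memory query; assume usage grows
    # roughly linearly in batch_size * seq_len (illustrative only).
    return 0.5 * batch_size * seq_len / 1024  # MB

available_mem = 1024.0   # MB, illustrative
total_batch_seq = 0.001  # small seed avoids division by zero below
total_mem = 0.0

for batch_size, seq_len in [(1, 128), (2, 512), (8, 2048)]:
    batch_seq = batch_size * seq_len
    # Extrapolate from the shapes warmed up so far.
    mem_estimate = batch_seq / total_batch_seq * total_mem
    if mem_estimate > available_mem:
        print(f"skip bs={batch_size} seq={seq_len}: est {mem_estimate:.0f}MB")
        continue
    used_mem = measure_memory(batch_size, seq_len)
    total_mem += used_mem
    total_batch_seq += batch_seq
    print(f"warmed bs={batch_size} seq={seq_len}, used {used_mem:.2f}MB")
```

The adjacent ordering_function_max_bs is a plain sort key: negating b[0] makes an ascending sort visit the largest batch sizes first, breaking ties by the smaller second element, e.g. sorted(buckets, key=ordering_function_max_bs).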
```
@@ -1790,6 +1797,7 @@ class FlashCausalLM(Model):
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = free_mem - self.mem_reserved
        logger.info(f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
        for i, (batch_size, block_num) in enumerate(buckets):
            if batch_size > block_num:
                continue
```
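Decode buckets pair a batch size with a KV-cache block count, and the guard skips shapes where batch_size exceeds block_num; on the reasonable assumption that every running sequence needs at least one cache block, such shapes are unreachable at runtime. A small sketch of the guard with made-up buckets:

```python
# Hypothetical decode buckets: (batch_size, block_num) pairs.
decode_buckets = [(1, 4), (4, 2), (4, 16), (32, 8), (8, 64)]

for batch_size, block_num in decode_buckets:
    if batch_size > block_num:
        # Each running sequence needs at least one KV-cache block,
        # so a batch larger than the block count can never occur.
        print(f"skip bs={batch_size} blocks={block_num}")
        continue
    print(f"would warm up decode shape bs={batch_size} blocks={block_num}")
```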
```
@@ -1814,6 +1822,8 @@ class FlashCausalLM(Model):
            total_mem += used_mem
            total_batch_seq += batch_seq

        logger.info("Decode warmup successful.\n")

        log_master(
            logger.info,
            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
```
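The final message reports the total warmup wall time and the number of shapes warmed. A minimal sketch of that bookkeeping; the loop body is a placeholder, and only the variable names mirror the diff:

```python
import time

from loguru import logger

start_time = time.time()
warmup_shape_count = 0

# Stand-in warmup loop: each iteration represents one
# graph-compiled (batch, length) shape.
for _ in range(3):
    time.sleep(0.01)  # placeholder for the real warmup work
    warmup_shape_count += 1

logger.info(
    f"warmup hpu graph time {int(time.time() - start_time)}s "
    f"warmup shape count {warmup_shape_count}"
)
```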