max input length

This commit is contained in:
OlivierDehaene 2024-10-09 19:39:14 +02:00
parent 57f55fe834
commit d73c5c634d
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@@ -1618,7 +1618,7 @@ class FlashCausalLM(Model):
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
-                max_q=max_s,
+                max_q=batch.max_input_length,
                 max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
@@ -2236,8 +2236,6 @@ class FlashCausalLM(Model):
             use_prefill_with_paged_kv_state,
         )
-        # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths)
         if cu_seqlen_prefill is not None:
             return use_prefill_with_paged_kv_state(
                 state=(