Fix seq iterations

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-01 17:55:00 +00:00
parent 96434a1e7e
commit 27534d8ee4
No known key found for this signature in database

View File

@@ -470,6 +470,7 @@ impl LlamacppBackend {
for (seq_id, request) in requests.iter().enumerate() { for (seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request); debug!("Request: {:?}", request);
// TODO remove this
let sampler = match LlamacppSampler::new(&request) { let sampler = match LlamacppSampler::new(&request) {
Some(sampler) => sampler, Some(sampler) => sampler,
_ => { _ => {
@@ -506,11 +507,9 @@ impl LlamacppBackend {
bindings::llama_decode(llamacpp.ctx, llamacpp.batch) bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
}; };
if decode != 0 { if decode != 0 {
error!("Failed to decode batch: {decode}");
if decode == 1 { if decode == 1 {
unsafe { unsafe {
bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO: remove this ?
} }
} }
for seq in seqs.iter_mut() { for seq in seqs.iter_mut() {
@@ -523,6 +522,9 @@ impl LlamacppBackend {
bindings::llama_get_kv_cache_used_cells(llamacpp.ctx) bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
}; };
for seq in seqs.iter_mut() { for seq in seqs.iter_mut() {
if !seq.running {
continue;
}
let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos); let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
seq.n_new_tokens += 1; seq.n_new_tokens += 1;
seq.token = next; seq.token = next;
@@ -533,7 +535,7 @@ impl LlamacppBackend {
error!("Failed to decode token: {e}"); error!("Failed to decode token: {e}");
let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration)); let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
seq.running = false; seq.running = false;
break; continue;
}, },
}; };
let special = vocab.is_special_token(&piece); let special = vocab.is_special_token(&piece);
@@ -572,7 +574,7 @@ impl LlamacppBackend {
queued: requests[seq.id].time, queued: requests[seq.id].time,
})); }));
seq.running = false; seq.running = false;
break; continue;
} }
let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate { let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
token: token, token: token,