Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

Add a stupid batch mechanism

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

Parent: e07835c5b5
Commit: d6ded897a8
@@ -202,7 +202,7 @@ impl Llamacpp {
         Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
     }
 
-    fn batch_clear_logits(&mut self) {
+    fn _batch_clear_logits(&mut self) {
         for n in 0..self.batch.n_tokens as usize{
             unsafe {
                 *self.batch.logits.add(n) = 0 as i8;
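Only the leading underscore changes here: the helper's sole call site was already commented out (visible in the large hunk below), and the new decode loop resets batch.n_tokens and repushes tokens instead of clearing logits flags in place, so the `_` prefix presumably just silences Rust's dead_code lint while keeping the function around. The logits array it writes is llama.cpp's per-slot output mask: a non-zero byte at slot n asks llama_decode to return logits for that position.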
@@ -214,24 +214,15 @@ impl Llamacpp {
         &mut self,
         token: bindings::llama_token,
         pos: bindings::llama_pos,
-        seq_ids: &[bindings::llama_seq_id],
+        seq_id: bindings::llama_seq_id,
         logits: bool,
     ) {
-        debug!("push {token} {pos} {logits}");
-        // TODO check evertyhing..
         let n = self.batch.n_tokens as usize;
-
         unsafe {
             *self.batch.token.add(n) = token;
             *self.batch.pos.add(n) = pos;
-            *self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
-        }
-        for (i, &seq_id) in seq_ids.iter().enumerate() {
-            unsafe {
-                *(*self.batch.seq_id.add(n)).add(i) = seq_id;
-            }
-        }
-        unsafe {
+            *self.batch.n_seq_id.add(n) = 1;
+            *(*self.batch.seq_id.add(n)).add(0) = seq_id;
             *self.batch.logits.add(n) = logits as i8;
         }
         self.batch.n_tokens += 1;
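For readers who haven't met llama.cpp's llama_batch: it is a struct of parallel arrays (token id, position, sequence ids, and a per-slot logits flag), with n_tokens counting the filled slots; batch_push writes one slot through raw pointers, and after this commit each slot belongs to exactly one sequence. A minimal safe sketch of the same push logic, using owned Vecs so it runs standalone (this Batch type is an illustrative stand-in, not the real bindings::llama_batch):

    // Illustrative stand-in for bindings::llama_batch: the same
    // struct-of-arrays layout, but with owned Vecs instead of raw C pointers.
    struct Batch {
        token: Vec<i32>,  // token id per slot
        pos: Vec<i32>,    // position of that token within its sequence
        seq_id: Vec<i32>, // owning sequence per slot (n_seq_id is always 1 here)
        logits: Vec<i8>,  // non-zero asks the decoder for logits at this slot
    }

    impl Batch {
        fn new() -> Self {
            Batch { token: vec![], pos: vec![], seq_id: vec![], logits: vec![] }
        }

        // Mirrors batch_push() after this commit: one token, one sequence id.
        fn push(&mut self, token: i32, pos: i32, seq_id: i32, logits: bool) -> usize {
            self.token.push(token);
            self.pos.push(pos);
            self.seq_id.push(seq_id);
            self.logits.push(logits as i8);
            self.token.len() - 1 // index of the slot just filled
        }

        fn n_tokens(&self) -> usize {
            self.token.len()
        }
    }

    fn main() {
        let mut batch = Batch::new();
        // Two prompts decoded together: seq 0 = [10, 11], seq 1 = [20].
        for (pos, &tok) in [10, 11].iter().enumerate() {
            batch.push(tok, pos as i32, 0, true);
        }
        batch.push(20, 0, 1, true);
        assert_eq!(batch.n_tokens(), 3);
    }

The real llama_batch also carries an n_seq_id array because a slot may be shared by several sequences (for instance a common prompt prefix); the commit deliberately drops that generality and hard-codes n_seq_id = 1.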
@@ -375,6 +366,17 @@ impl Drop for LlamacppSampler {
     }
 }
 
+struct LlamacppSeq {
+    id: usize,
+    batch_pos: usize,
+    token: bindings::llama_token,
+    pos: bindings::llama_pos,
+    sampler: LlamacppSampler,
+    text: String,
+    n_new_tokens: usize,
+    running: bool,
+}
+
 static INIT: Once = Once::new();
 
 impl LlamacppBackend {
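Each LlamacppSeq ties one queued request to its slice of the shared batch for the lifetime of a generation: id indexes back into the requests vector (response channel, seed, max_new_tokens), and batch_pos names the batch slot holding the sequence's latest token, which is the logits row its sampler reads. token and pos remember the last sampled token and its next position so the follow-up single-token batch can be rebuilt, text accumulates the detokenized output for the final End message, n_new_tokens feeds the length-based finish check, and running is cleared on completion or error so the sequence drops out of subsequent batches.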
@@ -397,7 +399,7 @@ impl LlamacppBackend {
 
         spawn(async move {
             let mut n_tokens = 0;
-            let mut requests = Vec::new();
+            let mut requests = Vec::with_capacity(conf.max_batch_size);
 
             loop {
                 match timeout(conf.batch_timeout, rx.recv()).await {
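The with_capacity change pre-allocates for the largest batch the accumulator may build. The surrounding loop (visible in the context lines above) gathers incoming requests with timeout(conf.batch_timeout, rx.recv()) until the batch is full or the timeout fires. A runnable sketch of that accumulate-then-flush pattern, under assumed placeholder names (Request, batcher, flush, and the constants are illustrative, not the backend's real types):

    use tokio::sync::mpsc;
    use tokio::time::{timeout, Duration};

    #[derive(Debug)]
    struct Request(u32); // placeholder for the backend's real request type

    // Collect requests until the batch is full or the timeout fires, then flush.
    async fn batcher(mut rx: mpsc::UnboundedReceiver<Request>) {
        let max_batch_size = 4;
        let batch_timeout = Duration::from_millis(5);
        let mut requests = Vec::with_capacity(max_batch_size);

        loop {
            match timeout(batch_timeout, rx.recv()).await {
                // A request arrived in time: buffer it, flush once the batch is full.
                Ok(Some(req)) => {
                    requests.push(req);
                    if requests.len() >= max_batch_size {
                        flush(std::mem::take(&mut requests));
                    }
                }
                // Channel closed: flush whatever is pending and stop.
                Ok(None) => {
                    if !requests.is_empty() {
                        flush(std::mem::take(&mut requests));
                    }
                    break;
                }
                // Timeout: don't keep early requests waiting for a full batch.
                Err(_) => {
                    if !requests.is_empty() {
                        flush(std::mem::take(&mut requests));
                    }
                }
            }
        }
    }

    fn flush(batch: Vec<Request>) {
        // Stand-in for sync_tx.send(batch) handing the batch to the decode thread.
        println!("flushing a batch of {} request(s)", batch.len());
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = mpsc::unbounded_channel();
        for i in 0..6 {
            let _ = tx.send(Request(i));
        }
        drop(tx); // close the channel so batcher() terminates
        batcher(rx).await;
    }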
@@ -442,120 +444,131 @@ impl LlamacppBackend {
             let _ = status_tx.send(true);
 
             while let Ok(requests) = sync_rx.recv() {
-                // TODO: do a real batch
-                for (_seq_id, request) in requests.iter().enumerate() {
+                let start_time = Instant::now();
+                let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
+                llamacpp.batch.n_tokens = 0;
+
+                for (seq_id, request) in requests.iter().enumerate() {
                     debug!("Request: {:?}", request);
-                    let start_time = Instant::now();
-                    llamacpp.batch.n_tokens = 0;
-
+                    let sampler = match LlamacppSampler::new(&request) {
+                        Some(sampler) => sampler,
+                        _ => {
+                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
+                            continue;
+                        },
+                    };
                     for (pos, &token_id) in request.input_ids.iter().enumerate() {
                         llamacpp.batch_push(
                             token_id as bindings::llama_token,
                             pos as bindings::llama_pos,
-                            &[/* seq_id */ 0 as bindings::llama_seq_id],
-                            true,
+                            seq_id as bindings::llama_seq_id,
+                            true, // TODO
                         );
                     }
-                    let mut pos = request.input_ids.len();
-
-                    // TODO: close this loop :)
-
-                    // TODO: move up for perf ?
-                    let sampler = match LlamacppSampler::new(&request) {
-                        Some(sampler) => sampler,
-                        _ => {
-                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            continue;
-                        },
-                    };
-                    let mut text = String::with_capacity(1024);
-                    let mut n_tokens: usize = 0;
-                    let mut n_new_tokens: usize = 0;
-
-                    loop {
-                        match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
-                            0 => { },
-                            1 => {
-                                unsafe {
-                                    // TODO: seq_rm & seq_add if model is compatible
-                                    bindings::llama_kv_cache_clear(llamacpp.ctx);
-                                }
-                                let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                                break;
-                            },
-                            _ => {
-                                debug!("decode return <0");
-                                let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                                break;
-                            },
-                        }
-                        let idx = llamacpp.batch.n_tokens as usize - 1;
-                        let (next, logprob) = sampler.sample(&mut llamacpp, idx);
-                        n_new_tokens += 1;
-
-                        debug!("tokens: {n_tokens} new: {n_new_tokens}");
-
-                        let kv_cache_used_cells = unsafe {
-                            bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
-                        };
-                        let piece = match tokenizer.decode(&[next as u32], false) {
-                            Ok(piece) => piece,
-                            Err(e) => {
-                                error!("Failed to decode token: {e}");
-                                let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                                break;
-                            },
-                        };
-                        let special = vocab.is_special_token(&piece);
-
-                        if !special {
-                            text.push_str(&piece);
-                        }
-                        let token = Token {
-                            id: next as _,
-                            text: piece,
-                            logprob: logprob,
-                            special: special,
-                        };
-                        let finish: Option<FinishReason> = {
-                            if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
-                                Some(FinishReason::EndOfSequenceToken)
-                            } else if n_new_tokens == request.max_new_tokens {
-                                Some(FinishReason::Length)
-                            } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
-                                Some(FinishReason::Length) // TODO: check
-                            } else {
-                                None
-                            }
-                        };
-                        if let Some(reason) = finish {
-                            let _ = request.tx.send(Ok(InferStreamResponse::End {
-                                token: token,
-                                top_tokens: vec![],
-                                generated_text: GeneratedText {
-                                    text: text,
-                                    generated_tokens: n_new_tokens as _,
-                                    finish_reason: reason,
-                                    seed: Some(request.seed as _),
-                                },
-                                start: start_time,
-                                queued: request.time,
-                            }));
-                            break;
-                        }
-                        let _ = request.tx.send(Ok(InferStreamResponse::Intermediate {
-                            token: token,
-                            top_tokens: vec![],
-                        }));
-                        llamacpp.batch.n_tokens = 0;
-                        // llamacpp.batch_clear_logits();
-                        llamacpp.batch_push(next, pos as _, &[0], true);
-                        pos += 1;
+                    seqs.push(LlamacppSeq {
+                        id: seq_id,
+                        batch_pos: llamacpp.batch.n_tokens as usize - 1,
+                        token: -1,
+                        pos: request.input_ids.len() as _,
+                        sampler: sampler,
+                        text: String::with_capacity(1024),
+                        n_new_tokens: 0,
+                        running: true,
+                    });
+                }
+                loop {
+                    if llamacpp.batch.n_tokens == 0 {
+                        break;
+                    }
+                    let decode = unsafe {
+                        bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
+                    };
+                    if decode != 0 {
+                        error!("Failed to decode batch: {decode}");
+
+                        if decode == 1 {
+                            unsafe {
+                                bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO
+                            }
+                        }
+                        for seq in seqs.iter_mut() {
+                            let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                            seq.running = false;
+                        }
+                        break;
+                    }
+                    let kv_cache_used_cells = unsafe {
+                        bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
+                    };
+                    for seq in seqs.iter_mut() {
+                        let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
+                        seq.n_new_tokens += 1;
+                        seq.token = next;
+
+                        let piece = match tokenizer.decode(&[next as u32], false) {
+                            Ok(piece) => piece,
+                            Err(e) => {
+                                error!("Failed to decode token: {e}");
+                                let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                                seq.running = false;
+                                break;
+                            },
+                        };
+                        let special = vocab.is_special_token(&piece);
+
+                        if !special {
+                            seq.text.push_str(&piece);
+                        }
+                        let token = Token {
+                            id: next as _,
+                            text: piece,
+                            logprob: logprob,
+                            special: special,
+                        };
+                        let finish: Option<FinishReason> = {
+                            if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
+                                Some(FinishReason::EndOfSequenceToken)
+                            } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
+                                Some(FinishReason::Length)
+                            } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
+                                Some(FinishReason::Length) // TODO: check
+                            } else {
+                                None
+                            }
+                        };
+                        if let Some(reason) = finish {
+                            let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
+                                token: token,
+                                top_tokens: vec![],
+                                generated_text: GeneratedText {
+                                    text: seq.text.clone(),
+                                    generated_tokens: seq.n_new_tokens as _,
+                                    finish_reason: reason,
+                                    seed: Some(requests[seq.id].seed as _),
+                                },
+                                start: start_time,
+                                queued: requests[seq.id].time,
+                            }));
+                            seq.running = false;
+                            break;
+                        }
+                        let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
+                            token: token,
+                            top_tokens: vec![],
+                        }));
+                    }
+                    // generate a new batch
+                    llamacpp.batch.n_tokens = 0;
+
+                    for seq in seqs.iter_mut() {
+                        if seq.running {
+                            llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+                            seq.batch_pos = 0;
+                            seq.pos += 1;
+                        }
                     }
                 }
             } // TODO remove this
         });
         (Self{tx, status: status_rx}, ok_rx)
     }
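Taken together, the rewritten loop performs one llama_decode per generation step for all live sequences, samples each sequence at its own batch_pos, then rebuilds the batch with exactly one new token per still-running sequence, and stops once the batch comes up empty. Below is a runnable sketch of just that scheduling skeleton with the FFI stubbed out; Seq, sample_next, and the tuple batch are illustrative stand-ins, not the real types. One caveat it sidesteps: the committed rebuild sets seq.batch_pos = 0 for every sequence, which only points at the right logits row while a single sequence is still running; the sketch records each sequence's actual slot index instead.

    // Stand-in for LlamacppSeq: just enough state to drive the scheduling.
    struct Seq {
        id: usize,
        batch_pos: usize, // slot in the current batch whose logits we sample
        token: i32,       // last sampled token, fed back on the next step
        pos: i32,         // position the next token will occupy
        remaining: u32,   // fake budget standing in for max_new_tokens
        running: bool,
    }

    // Fake sampler: counts a budget down and emits -1 (playing the EOS role)
    // when the budget is exhausted.
    fn sample_next(seq: &mut Seq) -> i32 {
        seq.remaining -= 1;
        if seq.remaining == 0 { -1 } else { seq.token + 1 }
    }

    fn main() {
        let mut seqs = vec![
            Seq { id: 0, batch_pos: 0, token: 100, pos: 3, remaining: 2, running: true },
            Seq { id: 1, batch_pos: 1, token: 200, pos: 5, remaining: 4, running: true },
        ];
        // (token, pos, seq_id) triples; pretend the prompts were already pushed.
        let mut batch: Vec<(i32, i32, usize)> = vec![(100, 2, 0), (200, 4, 1)];

        loop {
            // Mirrors `if llamacpp.batch.n_tokens == 0 { break; }`.
            if batch.is_empty() {
                break;
            }
            // llama_decode(ctx, batch) would run here, producing logits.

            // Sample one token per running sequence, retiring finished ones.
            for seq in seqs.iter_mut().filter(|s| s.running) {
                let next = sample_next(seq);
                if next == -1 {
                    seq.running = false; // finished: excluded from future batches
                } else {
                    seq.token = next;
                }
            }
            // Rebuild the batch: one new token per still-running sequence.
            batch.clear();
            for (slot, seq) in seqs.iter_mut().filter(|s| s.running).enumerate() {
                batch.push((seq.token, seq.pos, seq.id));
                seq.batch_pos = slot; // its logits will land in this slot
                seq.pos += 1;
            }
            println!("next batch: {:?}", batch);
        }
    }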
@@ -198,7 +198,7 @@ async fn main() -> Result<(), RouterError> {
             flash_attention: args.flash_attention,
             max_batch_total_tokens: args.max_batch_total_tokens,
             max_batch_size: args.max_batch_size,
-            batch_timeout: tokio::time::Duration::from_millis(100),
+            batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
     );
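Finally, the commit cuts batch_timeout from 100 ms to 5 ms. The effect is to bound the latency that batching can add: a request that arrives alone now waits at most 5 ms for companions before decoding starts, instead of up to 100 ms, while requests landing within the same 5 ms window are still merged into a single batch (up to max_batch_size, as handled by the accumulation loop shown earlier).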