Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

Commit d6ded897a8 (parent e07835c5b5)

    Add a stupid batch mechanism

    Signed-off-by: Adrien Gallouët <angt@huggingface.co>
@@ -202,7 +202,7 @@ impl Llamacpp {
         Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
     }
 
-    fn batch_clear_logits(&mut self) {
+    fn _batch_clear_logits(&mut self) {
         for n in 0..self.batch.n_tokens as usize{
             unsafe {
                 *self.batch.logits.add(n) = 0 as i8;
@@ -214,24 +214,15 @@ impl Llamacpp {
         &mut self,
         token: bindings::llama_token,
         pos: bindings::llama_pos,
-        seq_ids: &[bindings::llama_seq_id],
+        seq_id: bindings::llama_seq_id,
         logits: bool,
     ) {
-        debug!("push {token} {pos} {logits}");
-        // TODO check evertyhing..
         let n = self.batch.n_tokens as usize;
-
         unsafe {
             *self.batch.token.add(n) = token;
             *self.batch.pos.add(n) = pos;
-            *self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
-        }
-        for (i, &seq_id) in seq_ids.iter().enumerate() {
-            unsafe {
-                *(*self.batch.seq_id.add(n)).add(i) = seq_id;
-            }
-        }
-        unsafe {
+            *self.batch.n_seq_id.add(n) = 1;
+            *(*self.batch.seq_id.add(n)).add(0) = seq_id;
             *self.batch.logits.add(n) = logits as i8;
         }
         self.batch.n_tokens += 1;
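After this change, `batch_push` records exactly one sequence id per pushed token instead of accepting a slice. Below is a minimal, self-contained sketch of the same bookkeeping in safe Rust (Vec-backed, with hypothetical names such as `ToyBatch`); the real method writes through the raw `llama_batch` pointers shown above.

```rust
// Illustrative sketch only: mirrors the single-sequence batch_push bookkeeping
// with Vecs instead of llama.cpp's raw pointers. Names here are hypothetical.
#[derive(Default)]
struct ToyBatch {
    token: Vec<i32>,    // llama_token
    pos: Vec<i32>,      // llama_pos
    n_seq_id: Vec<i32>, // always 1: one sequence per pushed token
    seq_id: Vec<i32>,   // the single sequence id for that token
    logits: Vec<i8>,    // whether logits are requested at this position
}

impl ToyBatch {
    /// Analogue of batch_push(token, pos, seq_id, logits).
    fn push(&mut self, token: i32, pos: i32, seq_id: i32, logits: bool) {
        self.token.push(token);
        self.pos.push(pos);
        self.n_seq_id.push(1);
        self.seq_id.push(seq_id);
        self.logits.push(logits as i8);
    }

    fn n_tokens(&self) -> usize {
        self.token.len()
    }
}

fn main() {
    // Prefill two prompts as two independent sequences in the same batch.
    let prompts = [vec![1, 42, 7], vec![1, 99]];
    let mut batch = ToyBatch::default();
    for (seq_id, prompt) in prompts.iter().enumerate() {
        for (pos, &tok) in prompt.iter().enumerate() {
            batch.push(tok, pos as i32, seq_id as i32, true);
        }
    }
    assert_eq!(batch.n_tokens(), 5);
}
```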
@@ -375,6 +366,17 @@ impl Drop for LlamacppSampler {
     }
 }
 
+struct LlamacppSeq {
+    id: usize,
+    batch_pos: usize,
+    token: bindings::llama_token,
+    pos: bindings::llama_pos,
+    sampler: LlamacppSampler,
+    text: String,
+    n_new_tokens: usize,
+    running: bool,
+}
+
 static INIT: Once = Once::new();
 
 impl LlamacppBackend {
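`LlamacppSeq` carries the per-request decoding state (sampler, last token, position, accumulated text) so that several requests can share one `llama_batch`. A rough sketch of how one such entry evolves, under the assumption of one entry per request created right after its prompt is pushed (`ToySeq` and the literal values are hypothetical):

```rust
// Illustrative sketch only: per-request state analogous to LlamacppSeq,
// with plain integer types standing in for the llama.cpp bindings.
struct ToySeq {
    id: usize,        // index of the request this sequence belongs to
    batch_pos: usize, // index of this request's last pushed token in the shared batch
    token: i32,       // last sampled token (-1 until the first decode)
    pos: i32,         // next position to use in this sequence
    text: String,     // decoded text accumulated so far
    n_new_tokens: usize,
    running: bool,
}

fn main() {
    // Single request with a 3-token prompt, so its last token is at index 2.
    let prompt_len = 3usize;
    let mut seq = ToySeq {
        id: 0,
        batch_pos: prompt_len - 1,
        token: -1,
        pos: prompt_len as i32,
        text: String::with_capacity(1024),
        n_new_tokens: 0,
        running: true,
    };

    // After each decode step, one token is sampled for this sequence:
    seq.token = 42;
    seq.n_new_tokens += 1;
    seq.text.push_str("…");
    seq.pos += 1;
    assert!(seq.running);
}
```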
@@ -397,7 +399,7 @@ impl LlamacppBackend {
 
         spawn(async move {
             let mut n_tokens = 0;
-            let mut requests = Vec::new();
+            let mut requests = Vec::with_capacity(conf.max_batch_size);
 
             loop {
                 match timeout(conf.batch_timeout, rx.recv()).await {
@@ -442,120 +444,131 @@ impl LlamacppBackend {
             let _ = status_tx.send(true);
 
             while let Ok(requests) = sync_rx.recv() {
-                // TODO: do a real batch
-                for (_seq_id, request) in requests.iter().enumerate() {
+                let start_time = Instant::now();
+                let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
+                llamacpp.batch.n_tokens = 0;
 
+                for (seq_id, request) in requests.iter().enumerate() {
                     debug!("Request: {:?}", request);
-                    let start_time = Instant::now();
-                    llamacpp.batch.n_tokens = 0;
+                    let sampler = match LlamacppSampler::new(&request) {
+                        Some(sampler) => sampler,
+                        _ => {
+                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
+                            continue;
+                        },
+                    };
                     for (pos, &token_id) in request.input_ids.iter().enumerate() {
                         llamacpp.batch_push(
                             token_id as bindings::llama_token,
                             pos as bindings::llama_pos,
-                            &[/* seq_id */ 0 as bindings::llama_seq_id],
-                            true,
+                            seq_id as bindings::llama_seq_id,
+                            true, // TODO
                         );
                     }
-                    let mut pos = request.input_ids.len();
-
-                    // TODO: close this loop :)
-
-                    // TODO: move up for perf ?
-                    let sampler = match LlamacppSampler::new(&request) {
-                        Some(sampler) => sampler,
-                        _ => {
-                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            continue;
-                        },
-                    };
-                    let mut text = String::with_capacity(1024);
-                    let mut n_tokens: usize = 0;
-                    let mut n_new_tokens: usize = 0;
-
-                    loop {
-                        match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
-                            0 => { },
-                            1 => {
-                                unsafe {
-                                    // TODO: seq_rm & seq_add if model is compatible
-                                    bindings::llama_kv_cache_clear(llamacpp.ctx);
-                                }
-                                let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                                break;
-                            },
-                            _ => {
-                                debug!("decode return <0");
-                                let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                                break;
-                            },
-                        };
-                        let idx = llamacpp.batch.n_tokens as usize - 1;
-                        let (next, logprob) = sampler.sample(&mut llamacpp, idx);
-                        n_new_tokens += 1;
-
-                        debug!("tokens: {n_tokens} new: {n_new_tokens}");
-
-                        let kv_cache_used_cells = unsafe {
-                            bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
-                        };
+                    seqs.push(LlamacppSeq {
+                        id: seq_id,
+                        batch_pos: llamacpp.batch.n_tokens as usize - 1,
+                        token: -1,
+                        pos: request.input_ids.len() as _,
+                        sampler: sampler,
+                        text: String::with_capacity(1024),
+                        n_new_tokens: 0,
+                        running: true,
+                    });
+                }
+                loop {
+                    if llamacpp.batch.n_tokens == 0 {
+                        break;
+                    }
+                    let decode = unsafe {
+                        bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
+                    };
+                    if decode != 0 {
+                        error!("Failed to decode batch: {decode}");
+
+                        if decode == 1 {
+                            unsafe {
+                                bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO
+                            }
+                        }
+                        for seq in seqs.iter_mut() {
+                            let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                            seq.running = false;
+                        }
+                        break;
+                    }
+                    let kv_cache_used_cells = unsafe {
+                        bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
+                    };
+                    for seq in seqs.iter_mut() {
+                        let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
+                        seq.n_new_tokens += 1;
+                        seq.token = next;
+
                         let piece = match tokenizer.decode(&[next as u32], false) {
                             Ok(piece) => piece,
                             Err(e) => {
                                 error!("Failed to decode token: {e}");
-                                let _ = request.tx.send(Err(InferError::IncompleteGeneration));
+                                let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                                seq.running = false;
                                 break;
                             },
                         };
                         let special = vocab.is_special_token(&piece);
 
                         if !special {
-                            text.push_str(&piece);
+                            seq.text.push_str(&piece);
                         }
                         let token = Token {
                             id: next as _,
                             text: piece,
                             logprob: logprob,
                             special: special,
                         };
                         let finish: Option<FinishReason> = {
                             if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
                                 Some(FinishReason::EndOfSequenceToken)
-                            } else if n_new_tokens == request.max_new_tokens {
+                            } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
                                 Some(FinishReason::Length)
                             } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
                                 Some(FinishReason::Length) // TODO: check
                             } else {
                                 None
                             }
                         };
                         if let Some(reason) = finish {
-                            let _ = request.tx.send(Ok(InferStreamResponse::End {
+                            let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
                                 token: token,
                                 top_tokens: vec![],
                                 generated_text: GeneratedText {
-                                    text: text,
-                                    generated_tokens: n_new_tokens as _,
+                                    text: seq.text.clone(),
+                                    generated_tokens: seq.n_new_tokens as _,
                                     finish_reason: reason,
-                                    seed: Some(request.seed as _),
+                                    seed: Some(requests[seq.id].seed as _),
                                 },
                                 start: start_time,
-                                queued: request.time,
+                                queued: requests[seq.id].time,
                             }));
+                            seq.running = false;
                             break;
                         }
-                        let _ = request.tx.send(Ok(InferStreamResponse::Intermediate {
+                        let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
                             token: token,
                             top_tokens: vec![],
                         }));
-                        llamacpp.batch.n_tokens = 0;
-                        // llamacpp.batch_clear_logits();
-                        llamacpp.batch_push(next, pos as _, &[0], true);
-                        pos += 1;
                     }
+                    // generate a new batch
+                    llamacpp.batch.n_tokens = 0;
+
+                    for seq in seqs.iter_mut() {
+                        if seq.running {
+                            llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+                            seq.batch_pos = 0;
+                            seq.pos += 1;
+                        }
+                    }
                 }
-            } // TODO remove this
+            }
         });
         (Self{tx, status: status_rx}, ok_rx)
     }
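Taken together, the new loop prefills every queued request into a single llama_batch under its own sequence id, decodes once, samples one token per live sequence at its recorded batch position, and then rebuilds a batch containing one token per still-running sequence. A minimal, self-contained sketch of that control flow (the `Seq` struct and the `decode`/`sample`/`is_eog` helpers are hypothetical stand-ins, not the backend's API):

```rust
// Illustrative sketch only: the overall shape of the batching loop above.
struct Seq {
    id: usize,
    batch_pos: usize,
    token: i32,
    pos: i32,
    running: bool,
}

fn decode(_batch: &[(i32, i32, i32)]) -> i32 { 0 } // stand-in for llama_decode
fn sample(_batch_pos: usize) -> i32 { 2 }          // stand-in for the per-seq sampler
fn is_eog(token: i32) -> bool { token == 2 }       // stand-in for llama_vocab_is_eog

fn main() {
    let requests: Vec<Vec<i32>> = vec![vec![1, 5, 9], vec![1, 7]];
    let mut batch: Vec<(i32, i32, i32)> = Vec::new(); // (token, pos, seq_id)
    let mut seqs: Vec<Seq> = Vec::new();

    // Prefill: all prompts share one batch, one sequence id per request.
    for (seq_id, prompt) in requests.iter().enumerate() {
        for (pos, &tok) in prompt.iter().enumerate() {
            batch.push((tok, pos as i32, seq_id as i32));
        }
        seqs.push(Seq {
            id: seq_id,
            batch_pos: batch.len() - 1, // logits of the last prompt token
            token: -1,
            pos: prompt.len() as i32,
            running: true,
        });
    }

    loop {
        if batch.is_empty() {
            break;
        }
        if decode(&batch) != 0 {
            break; // on a decode error every sequence is failed
        }
        // Sample one token per sequence at its recorded batch position.
        for seq in seqs.iter_mut() {
            seq.token = sample(seq.batch_pos);
            if is_eog(seq.token) {
                seq.running = false; // or max_new_tokens / context size reached
            }
        }
        // Next iteration: exactly one token per still-running sequence.
        batch.clear();
        for seq in seqs.iter_mut() {
            if seq.running {
                batch.push((seq.token, seq.pos, seq.id as i32));
                seq.batch_pos = batch.len() - 1;
                seq.pos += 1;
            }
        }
    }
}
```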
|
@@ -198,7 +198,7 @@ async fn main() -> Result<(), RouterError> {
             flash_attention: args.flash_attention,
             max_batch_total_tokens: args.max_batch_total_tokens,
             max_batch_size: args.max_batch_size,
-            batch_timeout: tokio::time::Duration::from_millis(100),
+            batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
     );
|