Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Fix fmt
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
parent 5367d94f34
commit 809e288b5a
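All hunks below are mechanical formatting fixes: single-expression `unsafe { ... }` blocks are collapsed onto one line, statements that exceed the line-width limit are split across lines, and terse literals such as `Llamacpp{model, ctx, ...}` are expanded one field per line. A minimal sketch of the most common rewrite, assuming a hypothetical `ffi` module as a stand-in for the real `llamacpp` bindings:

// Sketch only: `ffi::vocab_n_tokens` is a hypothetical placeholder for the
// real `llamacpp` FFI binding touched by this commit.
mod ffi {
    pub unsafe fn vocab_n_tokens(vocab: i32) -> i32 {
        vocab // placeholder body
    }
}

fn main() {
    let vocab = 32_000;

    // Style before this commit: a multi-line `unsafe` block around one call.
    let n_tokens = unsafe {
        ffi::vocab_n_tokens(vocab)
    };

    // Style after formatting: the same expression collapsed onto one line.
    let n_tokens_fmt = unsafe { ffi::vocab_n_tokens(vocab) };

    assert_eq!(n_tokens, n_tokens_fmt);
}

Presumably the commit was produced by re-running `cargo fmt` on the crate; a `cargo fmt -- --check` step is the usual way CI flags this kind of drift.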
@@ -262,15 +262,11 @@ impl Llamacpp {
         if ctx.is_null() {
             return Err(BackendError::Llamacpp("Failed to init context".to_string()));
         }
-        let vocab = unsafe {
-            llamacpp::model_get_vocab(model)
-        };
+        let vocab = unsafe { llamacpp::model_get_vocab(model) };
         if vocab.is_null() {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
-        let n_tokens = unsafe {
-            llamacpp::vocab_n_tokens(vocab)
-        };
+        let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
         let mut logprobs = Vec::with_capacity(n_tokens as usize);
 
         for token in 0..n_tokens {
@@ -280,16 +276,18 @@ impl Llamacpp {
                 p: 0.0,
             });
         }
-        let batch = unsafe {
-            llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
-        };
-        Ok(Llamacpp{model, ctx, vocab, logprobs, batch})
+        let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
+        Ok(Llamacpp {
+            model,
+            ctx,
+            vocab,
+            logprobs,
+            batch,
+        })
     }
 
     fn decode(&mut self) -> i32 {
-        unsafe {
-            llamacpp::decode(self.ctx, self.batch)
-        }
+        unsafe { llamacpp::decode(self.ctx, self.batch) }
     }
 
     fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
@@ -344,18 +342,10 @@ impl LlamacppSampler {
             error!("Failed to init sampler");
             return None;
         }
-        let top_k = unsafe {
-            llamacpp::sampler_init_top_k(req.top_k)
-        };
-        let top_p = unsafe {
-            llamacpp::sampler_init_top_p(req.top_p, req.min_keep)
-        };
-        let typical_p = unsafe {
-            llamacpp::sampler_init_typical(req.typical_p, req.min_keep)
-        };
-        let temp = unsafe {
-            llamacpp::sampler_init_temp(req.temp)
-        };
+        let top_k = unsafe { llamacpp::sampler_init_top_k(req.top_k) };
+        let top_p = unsafe { llamacpp::sampler_init_top_p(req.top_p, req.min_keep) };
+        let typical_p = unsafe { llamacpp::sampler_init_typical(req.typical_p, req.min_keep) };
+        let temp = unsafe { llamacpp::sampler_init_temp(req.temp) };
         let penalties = unsafe {
             llamacpp::sampler_init_penalties(
                 req.penalty_last_n,
@@ -364,9 +354,7 @@ impl LlamacppSampler {
                 req.penalty_present,
             )
         };
-        let dist = unsafe {
-            llamacpp::sampler_init_dist(req.seed)
-        };
+        let dist = unsafe { llamacpp::sampler_init_dist(req.seed) };
         let all = &[
             ("top_k", top_k),
             ("top_p", top_p),
@@ -389,14 +377,12 @@ impl LlamacppSampler {
             unsafe { llamacpp::sampler_free(chain) };
             None
         } else {
-            Some(LlamacppSampler{chain})
+            Some(LlamacppSampler { chain })
         }
     }
 
     fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
-        let logits = unsafe {
-            llamacpp::get_logits_ith(llamacpp.ctx, idx as _)
-        };
+        let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
         for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
             *logprob = llamacpp::llama_token_data {
                 id: token as _,
@@ -474,7 +460,8 @@ impl LlamacppBackend {
 
             let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
                 if !requests.is_empty() {
-                    let _ = sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
+                    let _ =
+                        sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
                     *n_tokens = 0;
                 }
             };
@@ -559,7 +546,9 @@ impl LlamacppBackend {
                         warn!("llama_decode failed, clearing kv cache");
                         llamacpp.clear_kv_cache(-1);
                         for seq in seqs.iter_mut() {
-                            let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                            let _ = requests[seq.id]
+                                .tx
+                                .send(Err(InferError::IncompleteGeneration));
                             seq.running = false;
                         }
                         break;
@@ -576,7 +565,9 @@ impl LlamacppBackend {
                             Ok(piece) => piece,
                             Err(e) => {
                                 error!("Failed to decode token: {e}");
-                                let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                                let _ = requests[seq.id]
+                                    .tx
+                                    .send(Err(InferError::IncompleteGeneration));
                                 seq.running = false;
                                 continue;
                             }
@@ -617,7 +608,9 @@ impl LlamacppBackend {
                             seq.running = false;
                             continue;
                         }
-                        let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
+                        let _ = requests[seq.id]
+                            .tx
+                            .send(Ok(InferStreamResponse::Intermediate {
                             token,
                             top_tokens: vec![],
                         }));
@@ -627,7 +620,8 @@ impl LlamacppBackend {
 
                     for seq in seqs.iter_mut() {
                         if seq.running {
-                            seq.batch_pos = llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+                            seq.batch_pos =
+                                llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
                             seq.pos += 1;
                         } else {
                             llamacpp.clear_kv_cache(seq.id as _);
@@ -636,7 +630,14 @@ impl LlamacppBackend {
                     }
                 }
             });
-        (Self{tx, status: status_rx}, ok_rx, shutdown_tx)
+        (
+            Self {
+                tx,
+                status: status_rx,
+            },
+            ok_rx,
+            shutdown_tx,
+        )
     }
 }
 