Cargo fmt

This commit is contained in:
Nicolas Patry 2023-12-04 14:56:38 +00:00
parent d99f281050
commit 269792094b
3 changed files with 34 additions and 32 deletions

View File

@ -162,7 +162,6 @@ struct Args {
#[clap(long, env)]
speculate: Option<usize>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,

View File

@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
Tokens, Request, StoppingCriteriaParameters,
Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -9,7 +9,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, Tokens, ShardedClient,
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
@ -524,45 +524,48 @@ fn send_responses(
}
// Create last Token
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens{
tokens_.ids.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}).collect()
}else{
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
})
.collect()
} else {
vec![]
};
// generation.top_tokens
let mut top_tokens = Vec::new();
for top_tokens_ in generation.top_tokens{
for top_tokens_ in generation.top_tokens {
let mut local_top_tokens = Vec::new();
local_top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}),
);
local_top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}),
);
top_tokens.push(local_top_tokens);
}
// Force top_tokens to be the same size as tokens, both are going to be
// Force top_tokens to be the same size as tokens, both are going to be
// zipped later
if top_tokens.len() != tokens.len(){
if top_tokens.len() != tokens.len() {
top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
}