Cargo fmt

This commit is contained in:
Nicolas Patry 2023-12-04 14:56:38 +00:00
parent d99f281050
commit 269792094b
3 changed files with 34 additions and 32 deletions

View File

@ -162,7 +162,6 @@ struct Args {
#[clap(long, env)]
speculate: Option<usize>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,

View File

@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
Tokens, Request, StoppingCriteriaParameters,
Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -9,7 +9,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, Tokens, ShardedClient,
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
@ -524,8 +524,10 @@ fn send_responses(
}
// Create last Token
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens{
tokens_.ids.into_iter()
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
@ -534,15 +536,16 @@ fn send_responses(
text,
logprob,
special,
}).collect()
}else{
})
.collect()
} else {
vec![]
};
// generation.top_tokens
let mut top_tokens = Vec::new();
for top_tokens_ in generation.top_tokens{
for top_tokens_ in generation.top_tokens {
let mut local_top_tokens = Vec::new();
local_top_tokens.extend(
top_tokens_
@ -562,7 +565,7 @@ fn send_responses(
}
// Force top_tokens to be the same size as tokens, both are going to be
// zipped later
if top_tokens.len() != tokens.len(){
if top_tokens.len() != tokens.len() {
top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
}