Merge branch 'huggingface:main' into add_sealion_mpt_support

This commit is contained in:
David Ong Tat-Wee 2024-01-23 10:52:18 +08:00 committed by GitHub
commit 10fce5bffd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 35 additions and 15 deletions

View File

@ -6,6 +6,7 @@ members = [
"router/grpc-metadata", "router/grpc-metadata",
"launcher" "launcher"
] ]
resolver = "2"
[workspace.package] [workspace.package]
version = "1.3.4" version = "1.3.4"

View File

@ -466,7 +466,7 @@ fn latency_paragraph<'a>(latency: &mut Vec<f64>, name: &'static str) -> Paragrap
let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]); let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);
// Latency p50/p90/p99 texts // Latency p50/p90/p99 texts
let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed]; let colors = [Color::LightGreen, Color::LightYellow, Color::LightRed];
for (i, (name, value)) in latency_percentiles.iter().enumerate() { for (i, (name, value)) in latency_percentiles.iter().enumerate() {
let span = Line::from(vec![Span::styled( let span = Line::from(vec![Span::styled(
format!("{name}: {value:.2} ms"), format!("{name}: {value:.2} ms"),

View File

@ -53,6 +53,8 @@ impl std::fmt::Display for Quantization {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To stay in sync with `server`. // To stay in sync with `server`.
match self { match self {
#[allow(deprecated)]
// Use `eetq` instead, which provides better latencies overall and is a drop-in replacement in most cases
Quantization::Bitsandbytes => { Quantization::Bitsandbytes => {
write!(f, "bitsandbytes") write!(f, "bitsandbytes")
} }

View File

@ -224,7 +224,7 @@ pub struct DecodeTimings {
impl DecodeTimings { impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self { Self {
concat: concat_ns.map(|v| Duration::from_nanos(v)), concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns), forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns), decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns), total: Duration::from_nanos(total_ns),

View File

@ -127,6 +127,7 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone()))) .map(|client| Box::pin(client.prefill(batch.clone())))
.collect(); .collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect(); join_all(futures).await.into_iter().collect();
let mut results = results?; let mut results = results?;
@ -159,6 +160,7 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.decode(batches.clone()))) .map(|client| Box::pin(client.decode(batches.clone())))
.collect(); .collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> = let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect(); join_all(futures).await.into_iter().collect();
let mut results = results?; let mut results = results?;

View File

@ -71,6 +71,8 @@ struct Args {
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
chat_enabled_api: bool,
} }
#[tokio::main] #[tokio::main]
@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> {
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
chat_enabled_api,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
tokenizer_config, tokenizer_config,
chat_enabled_api,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -708,6 +708,7 @@ pub async fn run(
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
chat_enabled_api: bool,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -856,25 +857,32 @@ pub async fn run(
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),
}; };
// Create router // Configure Swagger UI
let app = Router::new() let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes // Define base and health routes
let base_routes = Router::new()
.route("/", post(compat_generate)) .route("/", post(compat_generate))
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
.route("/health", get(health)) .route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health)) .route("/ping", get(health))
// Prometheus metrics route .route("/metrics", get(metrics));
.route("/metrics", get(metrics))
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if chat_enabled_api {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' when `chat_enabled_api` is set
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
// Combine routes and layers
let app = Router::new()
.merge(swagger_ui)
.merge(base_routes)
.merge(aws_sagemaker_route)
.layer(Extension(info)) .layer(Extension(info))
.layer(Extension(health_ext.clone())) .layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))

View File

@ -1,3 +1,6 @@
[toolchain] [toolchain]
channel = "1.70.0" # Released on: 28 December, 2023
# Branched from master on: 10 November, 2023
# https://releases.rs/docs/1.75.0/
channel = "1.75.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]