Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Merge branch 'huggingface:main' into add_sealion_mpt_support
Commit 10fce5bffd
@@ -6,6 +6,7 @@ members = [
+    "router/grpc-metadata",
     "launcher"
 ]
 resolver = "2"

 [workspace.package]
 version = "1.3.4"
@@ -466,7 +466,7 @@ fn latency_paragraph<'a>(latency: &mut Vec<f64>, name: &'static str) -> Paragrap
     let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);

     // Latency p50/p90/p99 texts
-    let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed];
+    let colors = [Color::LightGreen, Color::LightYellow, Color::LightRed];
     for (i, (name, value)) in latency_percentiles.iter().enumerate() {
         let span = Line::from(vec![Span::styled(
             format!("{name}: {value:.2} ms"),
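A side note on the `colors` change: the set of percentile colors is fixed, so a stack-allocated array works exactly like the `Vec` did (this is what clippy's `useless_vec` lint suggests). A minimal standalone sketch of the pattern, with string placeholders instead of ratatui's `Color`:

    fn main() {
        // A fixed array indexes and iterates just like the Vec it replaces.
        let colors = ["LightGreen", "LightYellow", "LightRed"];
        for (i, name) in ["p50", "p90", "p99"].iter().enumerate() {
            println!("{name} rendered in {}", colors[i]);
        }
    }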
@@ -53,6 +53,8 @@ impl std::fmt::Display for Quantization {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         // To keep in track with `server`.
         match self {
+            #[allow(deprecated)]
+            // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases
             Quantization::Bitsandbytes => {
                 write!(f, "bitsandbytes")
             }
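The added `#[allow(deprecated)]` exists because the `Bitsandbytes` variant is itself marked deprecated, so even the launcher's own `match` on it would warn. A standalone sketch of that interaction, using a simplified enum rather than the launcher's real one:

    // Sketch: matching a #[deprecated] variant warns unless explicitly allowed.
    enum Quantization {
        #[deprecated(note = "use `Eetq` instead")]
        Bitsandbytes,
        Eetq,
    }

    impl std::fmt::Display for Quantization {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                #[allow(deprecated)] // silence the deprecation warning on this arm only
                Quantization::Bitsandbytes => write!(f, "bitsandbytes"),
                Quantization::Eetq => write!(f, "eetq"),
            }
        }
    }

    fn main() {
        println!("{}", Quantization::Eetq);
    }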
@@ -224,7 +224,7 @@ pub struct DecodeTimings {
 impl DecodeTimings {
     fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
         Self {
-            concat: concat_ns.map(|v| Duration::from_nanos(v)),
+            concat: concat_ns.map(Duration::from_nanos),
             forward: Duration::from_nanos(forward_ns),
             decode: Duration::from_nanos(decode_ns),
             total: Duration::from_nanos(total_ns),
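The `concat` line drops a closure that only forwards its argument to `Duration::from_nanos`; passing the function path directly is equivalent and is what clippy's `redundant_closure` lint recommends. A quick standalone check:

    use std::time::Duration;

    fn main() {
        let concat_ns: Option<u64> = Some(1_500_000);
        // Both forms produce the same Option<Duration>; the second skips the no-op closure.
        let via_closure = concat_ns.map(|v| Duration::from_nanos(v));
        let via_path = concat_ns.map(Duration::from_nanos);
        assert_eq!(via_closure, via_path);
    }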
@@ -127,6 +127,7 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
+        #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
@@ -159,6 +160,7 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
+        #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
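Both `#[allow(clippy::type_complexity)]` additions above target the same lint, which fires on deeply nested generics like the `Result<Vec<(Vec<Generation>, Option<CachedBatch>, ...Timings)>>` collected from the shard futures; the usual alternative is a type alias. A standalone sketch of the pattern, with placeholder aliases standing in for the router's real types:

    // `Gen`, `Cached`, and `Timings` are placeholders, not the router's types.
    type Gen = String;
    type Cached = u32;
    type Timings = u64;

    fn main() {
        // One Result per shard, as join_all would yield.
        let shard_results = vec![Ok((vec![Gen::from("token")], Some(7u32), 42u64))];

        // The nested generic type below is what trips clippy::type_complexity.
        #[allow(clippy::type_complexity)]
        let results: Result<Vec<(Vec<Gen>, Option<Cached>, Timings)>, String> =
            shard_results.into_iter().collect();

        println!("shards ok: {:?}", results.map(|r| r.len()));
    }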
@@ -71,6 +71,8 @@ struct Args {
     ngrok_authtoken: Option<String>,
     #[clap(long, env)]
     ngrok_edge: Option<String>,
+    #[clap(long, env, default_value_t = false)]
+    chat_enabled_api: bool,
 }

 #[tokio::main]
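The new flag reuses the clap derive pattern visible in the surrounding fields: `long` generates `--chat-enabled-api`, `env` lets an environment variable supply it, and `default_value_t = false` keeps it opt-in. A minimal, self-contained sketch (assuming clap 4 with the `derive` and `env` features enabled):

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Same shape as the router's new flag: off unless --chat-enabled-api
        /// (or the matching environment variable) is set.
        #[clap(long, env, default_value_t = false)]
        chat_enabled_api: bool,
    }

    fn main() {
        let args = Args::parse();
        println!("chat_enabled_api = {}", args.chat_enabled_api);
    }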
@@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
+        chat_enabled_api,
     } = args;

     // Launch Tokio runtime
@@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
+        chat_enabled_api,
     )
     .await?;
     Ok(())
@@ -708,6 +708,7 @@ pub async fn run(
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
+    chat_enabled_api: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -856,25 +857,32 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };

-    // Create router
-    let app = Router::new()
-        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
-        // Base routes
+    // Configure Swagger UI
+    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
+
+    // Define base and health routes
+    let base_routes = Router::new()
         .route("/", post(compat_generate))
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
-        // AWS Sagemaker route
-        .route("/invocations", post(compat_generate))
         // Base Health route
         .route("/health", get(health))
         // Inference API health route
         .route("/", get(health))
         // AWS Sagemaker health route
         .route("/ping", get(health))
         // Prometheus metrics route
-        .route("/metrics", get(metrics))
+        .route("/metrics", get(metrics));
+
+    // Conditional AWS Sagemaker route
+    let aws_sagemaker_route = if chat_enabled_api {
+        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
+    } else {
+        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
+    };
+
+    // Combine routes and layers
+    let app = Router::new()
+        .merge(swagger_ui)
+        .merge(base_routes)
+        .merge(aws_sagemaker_route)
         .layer(Extension(info))
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
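The restructuring above composes the final app from three routers, so the `/invocations` handler can be chosen once at startup. A standalone axum sketch of the same merge pattern (handler bodies and the serve call are illustrative; it assumes axum 0.7's `axum::serve` and a tokio runtime, not necessarily the versions this PR was built against):

    use axum::{routing::post, Router};

    // Stand-in handlers; the real router wires chat_completions / compat_generate.
    async fn chat() -> &'static str { "chat" }
    async fn compat() -> &'static str { "compat" }

    fn build_app(chat_enabled_api: bool) -> Router {
        let base_routes = Router::new().route("/generate", post(compat));

        // Select the /invocations handler once, exactly as the diff does.
        let aws_sagemaker_route = if chat_enabled_api {
            Router::new().route("/invocations", post(chat))
        } else {
            Router::new().route("/invocations", post(compat))
        };

        Router::new().merge(base_routes).merge(aws_sagemaker_route)
    }

    #[tokio::main]
    async fn main() {
        let app = build_app(true);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
        axum::serve(listener, app).await.unwrap();
    }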
@@ -1,3 +1,6 @@
 [toolchain]
-channel = "1.70.0"
+# Released on: 28 December, 2023
+# Branched from master on: 10 November, 2023
+# https://releases.rs/docs/1.75.0/
+channel = "1.75.0"
 components = ["rustfmt", "clippy"]