diff --git a/router/src/main.rs b/router/src/main.rs
index 4cd89d68..e482d834 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -70,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }
 
-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
@@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
             // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache(None)
                 .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
             // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
 
             // Warmup model
             tracing::info!("Warming up model");
@@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> {
                     max_batch_total_tokens,
                 )
                 .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
             tracing::info!("Connected");
 
             let addr = match hostname.parse() {
@@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
-            .await;
+            .await?;
             Ok(())
         })
 }
@@ -331,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}
diff --git a/router/src/server.rs b/router/src/server.rs
index 54418f84..8ca463c2 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
             .serve(app.into_make_service())
             //Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {
@@ -744,9 +743,9 @@ pub async fn run(
             .serve(app.into_make_service())
             // Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+    Ok(())
 }
 
 /// Shutdown signal handler
diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
index e5e87645..047a1872 100644
--- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
         self.beta = 1.0
 
         process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(
             config=config,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d224a838..d9f3c7b8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.softmax_scale = self.head_size**-0.5
 
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 23c5e4ff..b2dce226 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index cd42bfc2..acac2744 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
             dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 855a6e11..5b1c6e21 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.softmax_scale = self.head_size ** (-0.5)
 
diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
index 5ea204e1..e6057116 100644
--- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
+
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // weights.process_group.size()
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py
index bf2656d1..1951b171 100644
--- a/server/text_generation_server/models/custom_modeling/neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py
@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
             torch.tensor(self.head_size, dtype=torch.float32)
         ).to(torch.get_default_dtype())
 
-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_attention_heads = (
             self.num_attention_heads // weights.process_group.size()
         )
diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py
index 03fded50..aa052b08 100644
--- a/server/text_generation_server/models/custom_modeling/opt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py
@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
         self.is_decoder = is_decoder
 
         process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.embed_dim = self.embed_dim // process_group.size()
 
diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index 19deca86..1ea82802 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -246,6 +246,11 @@ class T5Attention(nn.Module):
         self.o = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o", weights=weights, bias=False
         )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // process_group.size()
         self.inner_dim = self.inner_dim // process_group.size()
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bebd3df5..5420556b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -727,12 +727,11 @@ class FlashCausalLM(Model):
             )
             _, batch = self.generate_token(batch)
         except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e
         del batch
         torch.cuda.empty_cache()
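
For context, a minimal standalone sketch of the error-handling pattern the router hunks adopt: a `thiserror`-derived error enum, with `map_err` and `?` replacing `expect`/`unwrap`. This assumes the `thiserror` crate is available; `ClientError` and `connect_uds` below are hypothetical stand-ins, not the real `text_generation_client` API.

// A minimal, self-contained sketch (assumed code, not part of this diff).
use thiserror::Error;

// Hypothetical stand-in for `text_generation_client::ClientError`.
#[derive(Debug, Error)]
#[error("client error: {0}")]
struct ClientError(String);

#[derive(Debug, Error)]
enum RouterError {
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
}

// Hypothetical stand-in for `ShardedClient::connect_uds`.
fn connect_uds(path: &str) -> Result<(), ClientError> {
    Err(ClientError(format!("no socket at {path}")))
}

fn main() -> Result<(), RouterError> {
    // `map_err` wraps the low-level client error in a router-level variant,
    // and `?` returns it to the caller instead of panicking.
    connect_uds("/tmp/text-generation-server-0").map_err(RouterError::Connection)?;
    Ok(())
}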