diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml index 10e248e3..fd22e395 100644 --- a/.github/workflows/load_test.yaml +++ b/.github/workflows/load_test.yaml @@ -15,11 +15,11 @@ jobs: name: Start self-hosted EC2 runner runs-on: ubuntu-latest env: - AWS_REGION: us-east-1 - EC2_AMI_ID: ami-03cfed9ea28f4b002 + AWS_REGION: eu-central-1 + EC2_AMI_ID: ami-0ab09c07cfd194259 EC2_INSTANCE_TYPE: g5.12xlarge - EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc - EC2_SECURITY_GROUP: sg-04d472c808f365022 + EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326 + EC2_SECURITY_GROUP: sg-072f92ae3082936c6 outputs: label: ${{ steps.start-ec2-runner.outputs.label }} ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} @@ -90,7 +90,7 @@ jobs: - load-tests runs-on: ubuntu-latest env: - AWS_REGION: us-east-1 + AWS_REGION: eu-central-1 if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs steps: - name: Configure AWS credentials @@ -105,4 +105,4 @@ jobs: mode: stop github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} \ No newline at end of file + ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/Cargo.lock b/Cargo.lock index b65045ff..1d030249 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2848,7 +2848,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "0.9.0" +version = "0.9.1" dependencies = [ "average", "clap", @@ -2868,7 +2868,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "0.9.0" +version = "0.9.1" dependencies = [ "futures", "grpc-metadata", @@ -2884,7 +2884,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.9.0" +version = "0.9.1" dependencies = [ "clap", "ctrlc", @@ -2900,7 +2900,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "0.9.0" +version = "0.9.1" dependencies = [ "async-stream", "axum", diff --git a/Cargo.toml b/Cargo.toml index ba7a920d..fdfb274e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.9.0" +version = "0.9.1" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/docs/openapi.json b/docs/openapi.json index b91729d0..aaa360b0 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "0.9.0" + "version": "0.9.1" }, "paths": { "/": { diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 5b5cb45e..a612eb6d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -197,6 +197,10 @@ struct Args { #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + /// The IP address to listen on + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + /// The port to listen on. 
#[clap(default_value = "3000", long, short, env)] port: u16, @@ -874,6 +878,8 @@ fn spawn_webserver( args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), args.max_waiting_tokens.to_string(), + "--hostname".to_string(), + args.hostname.to_string(), "--port".to_string(), args.port.to_string(), "--master-shard-uds-path".to_string(), diff --git a/router/src/main.rs b/router/src/main.rs index f782be09..e482d834 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use std::time::Duration; -use text_generation_client::ShardedClient; +use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::{server, HubModelInfo}; +use thiserror::Error; use tokenizers::{FromPretrainedParameters, Tokenizer}; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; @@ -40,6 +41,8 @@ struct Args { max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] @@ -68,7 +71,7 @@ struct Args { ngrok_password: Option, } -fn main() -> Result<(), std::io::Error> { +fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); // Pattern match configuration @@ -82,6 +85,7 @@ fn main() -> Result<(), std::io::Error> { max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, + hostname, port, master_shard_uds_path, tokenizer_name, @@ -146,8 +150,7 @@ fn main() -> Result<(), std::io::Error> { // Launch Tokio runtime tokio::runtime::Builder::new_multi_thread() .enable_all() - .build() - .unwrap() + .build()? 
.block_on(async { init_logging(otlp_endpoint, json_output); @@ -189,17 +192,14 @@ fn main() -> Result<(), std::io::Error> { // Instantiate sharded client from the master unix socket let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) .await - .expect("Could not connect to server"); + .map_err(RouterError::Connection)?; // Clear the cache; useful if the webserver rebooted sharded_client .clear_cache(None) .await - .expect("Unable to clear cache"); + .map_err(RouterError::Cache)?; // Get info from the shard - let shard_info = sharded_client - .info() - .await - .expect("Unable to get shard info"); + let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; // Warmup model tracing::info!("Warming up model"); @@ -210,11 +210,16 @@ fn main() -> Result<(), std::io::Error> { max_batch_total_tokens, ) .await - .expect("Unable to warmup model"); + .map_err(RouterError::Warmup)?; tracing::info!("Connected"); - // Binds on localhost - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; // Run server server::run( @@ -241,7 +246,7 @@ fn main() -> Result<(), std::io::Error> { ngrok_username, ngrok_password, ) - .await; + .await?; Ok(()) }) } @@ -323,3 +328,19 @@ pub async fn get_model_info( } None } + +#[derive(Debug, Error)] +enum RouterError { + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), + #[error("Axum webserver failed: {0}")] + Axum(#[from] axum::BoxError), +} diff --git a/router/src/server.rs b/router/src/server.rs index 54418f84..8ca463c2 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -527,7 +527,7 @@ pub async fn run( ngrok_domain: Option, ngrok_username: Option, ngrok_password: Option, -) { +) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] #[openapi( @@ -726,8 +726,7 @@ pub async fn run( .serve(app.into_make_service()) //Wait until all requests are finished to shut down .with_graceful_shutdown(shutdown_signal()) - .await - .unwrap(); + .await?; } #[cfg(not(feature = "ngrok"))] { @@ -744,9 +743,9 @@ pub async fn run( .serve(app.into_make_service()) // Wait until all requests are finished to shut down .with_graceful_shutdown(shutdown_signal()) - .await - .unwrap(); + .await?; } + Ok(()) } /// Shutdown signal handler diff --git a/server/pyproject.toml b/server/pyproject.toml index fd640e7f..a696d7be 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "0.9.0" +version = "0.9.1" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py index 7dfe6a1e..ba6c5702 100644 --- a/server/tests/utils/test_convert.py +++ b/server/tests/utils/test_convert.py @@ -14,7 +14,7 @@ def test_convert_files(): local_st_files = [ p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files ] - 
convert_files(local_pt_files, local_st_files) + convert_files(local_pt_files, local_st_files, discard_names=[]) found_st_files = weight_files(model_id) diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 79fcd3aa..4504733e 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -1,3 +1,5 @@ +import torch + from typing import Dict, Optional, TypeVar from text_generation_server.models.types import Batch @@ -20,6 +22,8 @@ class Cache: batch = self.pop(batch_id) if batch is not None: del batch + if torch.cuda.is_available(): + torch.cuda.empty_cache() def clear(self): keys = list(self.cache.keys()) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 3463049a..7a55e919 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -160,8 +160,26 @@ def download_weights( p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files ] + try: + from transformers import AutoConfig + import transformers + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + ) + architecture = config.architectures[0] + + class_ = getattr(transformers, architecture) + + # Name for this variable depends on transformers version. + discard_names = getattr(class_, "_tied_weights_keys", []) + discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) + + except Exception as e: + discard_names = [] # Convert pytorch weights to safetensors - utils.convert_files(local_pt_files, local_st_files) + utils.convert_files(local_pt_files, local_st_files, discard_names) @app.command() diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py index a3247d19..725105f3 100644 --- a/server/text_generation_server/interceptor.py +++ b/server/text_generation_server/interceptor.py @@ -1,3 +1,4 @@ +import torch import grpc from google.rpc import status_pb2, code_pb2 @@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor): method_name = method_name.split("/")[-1] logger.exception(f"Method {method_name} encountered an error.") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + await context.abort_with_status( rpc_status.to_status( status_pb2.Status(code=code_pb2.INTERNAL, message=str(err)) diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index e5e87645..047a1872 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -256,6 +256,11 @@ class BloomAttention(nn.Module): self.beta = 1.0 process_group = weights.process_group + if self.num_heads % process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {process_group.size()}" + ) self.num_heads = self.num_heads // process_group.size() self.query_key_value = TensorParallelColumnLinear.load( config=config, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d224a838..d9f3c7b8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -112,6 +112,11 @@ class
FlashLlamaAttention(torch.nn.Module): self.softmax_scale = self.head_size**-0.5 + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load_multi( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 23c5e4ff..b2dce226 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module): self.num_heads = num_heads self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.rotary_emb = PositionRotaryEmbedding.load( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index cd42bfc2..acac2744 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module): dim=self.head_size, base=10000.0, device=weights.device ) self.softmax_scale = self.head_size ** (-0.5) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d49254e1..28a25fd5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -20,6 +20,7 @@ from text_generation_server.utils.layers import ( FastLayerNorm, get_linear, ) +from safetensors import SafetensorError def load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size @@ -72,8 +73,17 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - bits = weights.get_tensor("gptq_bits").item() - groupsize = weights.get_tensor("gptq_groupsize").item() + try: + bits = weights.get_tensor("gptq_bits").item() + groupsize = weights.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e qweight = qweight.to(weights.device) qzeros = qzeros.to(weights.device) @@ -102,7 +112,6 @@ def _load_multi_mqa_gptq( def _load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): - if any("c_attn" in 
k for k in weights.routing.keys()): slice_ = weights._get_slice(f"{prefix}.c_attn.weight") shape = slice_.get_shape() @@ -211,7 +220,11 @@ class FlashMQAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads - assert self.num_heads % weights.process_group.size() == 0 + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 5ea204e1..e6057116 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module): if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = config.attn_config["attn_pdrop"] + + if self.n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.n_heads = self.n_heads // weights.process_group.size() self.Wqkv = load_col( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index bf2656d1..1951b171 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module): torch.tensor(self.head_size, dtype=torch.float32) ).to(torch.get_default_dtype()) - assert self.num_attention_heads % weights.process_group.size() == 0 + if self.num_attention_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_attention_heads` must be divisible by `num_shards` " + f"(got `num_attention_heads`: {self.num_attention_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_attention_heads = ( self.num_attention_heads // weights.process_group.size() ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 03fded50..aa052b08 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -147,7 +147,11 @@ class OPTAttention(nn.Module): self.is_decoder = is_decoder process_group = weights.process_group - assert self.num_heads % process_group.size() == 0 + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // process_group.size() self.embed_dim = self.embed_dim // process_group.size() diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 19deca86..793f3a66 100644 --- 
a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -19,6 +19,8 @@ import math import warnings from typing import Optional, Tuple, Union +from loguru import logger + import torch import torch.distributed from torch import nn @@ -246,6 +248,11 @@ class T5Attention(nn.Module): self.o = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o", weights=weights, bias=False ) + if self.n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.n_heads = self.n_heads // process_group.size() self.inner_dim = self.inner_dim // process_group.size() @@ -1001,12 +1008,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): super().__init__(config) self.model_dim = config.d_model - try: - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) - except RuntimeError: - self.shared = TensorParallelEmbedding( - prefix="encoder.embed_tokens", weights=weights - ) + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bf5f5bbe..4e5804f5 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -638,6 +638,7 @@ class FlashCausalLMBatch(Batch): # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: b.block_tables = None + del b return FlashCausalLMBatch( batch_id=batches[0].batch_id, @@ -725,12 +726,11 @@ class FlashCausalLM(Model): ) _, batch = self.generate_token(batch) except Exception as e: - logger.exception( + raise RuntimeError( f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " f"prefill tokens. 
" f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" - ) - raise e + ) from e del batch def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: @@ -775,16 +775,20 @@ class FlashCausalLM(Model): # Allocate blocks to this batch CACHE_MANAGER.allocate(batch) - out = self.forward( - batch.input_ids, - batch.position_ids, - batch.cu_seqlen_prefill, - batch.block_tables_tensor, - batch.slots[batch.slot_indices], - batch.input_lengths_tensor, - batch.max_seqlen, - batch.prefill_head_indices, - ) + try: + out = self.forward( + batch.input_ids, + batch.position_ids, + batch.cu_seqlen_prefill, + batch.block_tables_tensor, + batch.slots[batch.slot_indices], + batch.input_lengths_tensor, + batch.max_seqlen, + batch.prefill_head_indices, + ) + except Exception as e: + del batch + raise e if prefill: next_token_logits = ( diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 12b862d7..55d555fc 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]}, + ) config.quantize = quantize diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 1b7073af..133aafd8 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) model = T5ForConditionalGeneration(config, weights) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index c375330a..7bc62ce6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): filtered_batch = batch.filter(request.request_ids) self.cache.set(filtered_batch) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): @@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) self.model.warmup(batch, request.max_total_tokens) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return generate_pb2.WarmupResponse() async def Prefill(self, request, context): @@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if len(batches) > 1: batch = self.model.batch_type.concatenate(batches) + if torch.cuda.is_available(): + torch.cuda.empty_cache() else: batch = batches[0] diff --git 
a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index 0e4adaba..305263ba 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -4,11 +4,56 @@ import os from loguru import logger from pathlib import Path -from safetensors.torch import save_file, _remove_duplicate_names, load_file -from typing import List +from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete +from typing import List, Dict +from collections import defaultdict -def convert_file(pt_file: Path, sf_file: Path): +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set( + [name for name in shared if _is_complete(state_dict[name])] + ) + if not complete_names: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]): """ Convert a pytorch file to a safetensors file This will remove duplicate tensors from the file.
@@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path): loaded = torch.load(pt_file, map_location="cpu") if "state_dict" in loaded: loaded = loaded["state_dict"] - to_removes = _remove_duplicate_names(loaded) + to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) metadata = {"format": "pt"} for kept_name, to_remove_group in to_removes.items(): @@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path): raise RuntimeError(f"The output tensors do not match for key {k}") -def convert_files(pt_files: List[Path], sf_files: List[Path]): +def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]): assert len(pt_files) == len(sf_files) N = len(pt_files) @@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]): for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): start = datetime.datetime.now() - convert_file(pt_file, sf_file) + convert_file(pt_file, sf_file, discard_names) elapsed = datetime.datetime.now() - start logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index b3fa2abb..db392c4a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -178,13 +178,25 @@ class SuperLayer(nn.Module): class TensorParallelHead(SuperLayer): - def __init__(self, linear, process_group): + def __init__(self, linear, process_group, should_gather: bool): super().__init__(linear) self.process_group = process_group + self.should_gather = should_gather @staticmethod def load(config, prefix: str, weights): - weight = weights.get_sharded(f"{prefix}.weight", dim=0) + if weights.process_group.size() > 1: + try: + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. 
+ weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False # GPTQ doesn't quantize heads (nor embeddings) if config.quantize == "gptq": @@ -194,13 +206,14 @@ class TensorParallelHead(SuperLayer): return TensorParallelHead( get_linear(weight, bias=None, quantize=quantize), process_group=weights.process_group, + should_gather=should_gather, ) def forward(self, input: torch.Tensor) -> torch.Tensor: - world_size = self.process_group.size() - if world_size == 1: + if not self.should_gather: return super().forward(input) + world_size = self.process_group.size() if len(input.shape) == 2 and isinstance(self.linear, FastLinear): out_dim = self.linear.weight.shape[0] @@ -281,7 +294,7 @@ class TensorParallelRowLinear(SuperLayer): class TensorParallelEmbedding(nn.Module): def __init__(self, prefix: str, weights, reduce=True): super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=0) + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) num_embeddings = weights.get_shape(f"{prefix}.weight")[0] process_group = weights.process_group diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6ad085dc..ff18d656 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,8 +1,9 @@ from pathlib import Path from typing import List, Dict, Optional -from safetensors import safe_open +from safetensors import safe_open, SafetensorError import torch + class Weights: def __init__( self, @@ -68,7 +69,7 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int): filename, tensor_name = self.get_filename(tensor_name) world_size = self.process_group.size() rank = self.process_group.rank() @@ -80,10 +81,6 @@ class Weights: start = rank * block_size stop = (rank + 1) * block_size - assert ( - size % world_size == 0 - ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - if dim == 0: tensor = slice_[start:stop] elif dim == 1: @@ -97,29 +94,57 @@ class Weights: tensor = tensor.to(device=self.device) return tensor + def get_sharded(self, tensor_name: str, dim: int): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + size = slice_.get_shape()[dim] + assert ( + size % world_size == 0 + ), f"The chosen size {size} is not compatible with sharding on {world_size} shards" + return self.get_partial_sharded(tensor_name, dim) + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: - qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) - qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) - scales =
torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) return weight - def get_multi_weights_row(self, prefix: str, quantize: str): + def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_triton_kernel = False if self.process_group.size() > 1: @@ -155,8 +180,17 @@ class Weights: else: g_idx = None - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel) else: