From 6f42942772100d846c71dc15a4a37cb45b648b73 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 5 Jul 2023 18:28:45 +0200 Subject: [PATCH 01/13] feat(router): add argument for hostname in router (#545) (#550) # What does this PR do? In title. Adds argument `--hostname` in router to support something like `--hostname ::`. Tested with ```commandline cargo run -- --port 8080 --hostname :: curl -I -X GET 'http://[::1]:8080/health' # failed before this commit ``` Trigger CI --------- Co-authored-by: Phil Chen --- launcher/src/main.rs | 6 ++++++ router/src/main.rs | 12 ++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 5b5cb45e..a612eb6d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -197,6 +197,10 @@ struct Args { #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + /// The IP address to listen on + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + /// The port to listen on. #[clap(default_value = "3000", long, short, env)] port: u16, @@ -874,6 +878,8 @@ fn spawn_webserver( args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), args.max_waiting_tokens.to_string(), + "--hostname".to_string(), + args.hostname.to_string(), "--port".to_string(), args.port.to_string(), "--master-shard-uds-path".to_string(), diff --git a/router/src/main.rs b/router/src/main.rs index f782be09..4cd89d68 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -40,6 +40,8 @@ struct Args { max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] @@ -82,6 +84,7 @@ fn main() -> Result<(), std::io::Error> { max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, + hostname, port, master_shard_uds_path, tokenizer_name, @@ -213,8 +216,13 @@ fn main() -> Result<(), std::io::Error> { .expect("Unable to warmup model"); tracing::info!("Connected"); - // Binds on localhost - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; // Run server server::run( From c4bb5264acacf55f8f1e693ccd29f7a13d3bcba0 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 6 Jul 2023 14:28:33 +0200 Subject: [PATCH 02/13] fix(server): decrease memory fragmentation (#557) --- server/text_generation_server/cache.py | 4 +++ .../models/flash_causal_lm.py | 28 ++++++++++++------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 79fcd3aa..4504733e 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -1,3 +1,5 @@ +import torch + from typing import Dict, Optional, TypeVar from text_generation_server.models.types import Batch @@ -20,6 +22,8 @@ class Cache: batch = self.pop(batch_id) if batch is not None: del batch + if torch.cuda.is_available(): + torch.cuda.empty_cache() def clear(self): keys = list(self.cache.keys()) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bf5f5bbe..bebd3df5 100644 
--- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -638,6 +638,8 @@ class FlashCausalLMBatch(Batch): # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: b.block_tables = None + del b + torch.cuda.empty_cache() return FlashCausalLMBatch( batch_id=batches[0].batch_id, @@ -732,6 +734,7 @@ class FlashCausalLM(Model): ) raise e del batch + torch.cuda.empty_cache() def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( @@ -775,16 +778,21 @@ class FlashCausalLM(Model): # Allocate blocks to this batch CACHE_MANAGER.allocate(batch) - out = self.forward( - batch.input_ids, - batch.position_ids, - batch.cu_seqlen_prefill, - batch.block_tables_tensor, - batch.slots[batch.slot_indices], - batch.input_lengths_tensor, - batch.max_seqlen, - batch.prefill_head_indices, - ) + try: + out = self.forward( + batch.input_ids, + batch.position_ids, + batch.cu_seqlen_prefill, + batch.block_tables_tensor, + batch.slots[batch.slot_indices], + batch.input_lengths_tensor, + batch.max_seqlen, + batch.prefill_head_indices, + ) + except Exception as e: + del batch + torch.cuda.empty_cache() + raise e if prefill: next_token_logits = ( From 31b36cca21fcd0e6b7db477a7545063e1b860156 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 6 Jul 2023 16:05:42 +0200 Subject: [PATCH 03/13] v0.9.1 (#558) --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- docs/openapi.json | 2 +- server/pyproject.toml | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b65045ff..1d030249 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2848,7 +2848,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "0.9.0" +version = "0.9.1" dependencies = [ "average", "clap", @@ -2868,7 +2868,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "0.9.0" +version = "0.9.1" dependencies = [ "futures", "grpc-metadata", @@ -2884,7 +2884,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.9.0" +version = "0.9.1" dependencies = [ "clap", "ctrlc", @@ -2900,7 +2900,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "0.9.0" +version = "0.9.1" dependencies = [ "async-stream", "axum", diff --git a/Cargo.toml b/Cargo.toml index ba7a920d..fdfb274e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.9.0" +version = "0.9.1" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/docs/openapi.json b/docs/openapi.json index b91729d0..aaa360b0 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "0.9.0" + "version": "0.9.1" }, "paths": { "/": { diff --git a/server/pyproject.toml b/server/pyproject.toml index fd640e7f..a696d7be 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "0.9.0" +version = "0.9.1" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] From e943a294bca239e26828732dd6ab5b6f95dadd0a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 7 Jul 2023 14:50:12 +0200 Subject: [PATCH 04/13] fix(server): harden the weights choice to save on disk. 
(#561) - Look at `transformers` base class to check for `_key_to_ignore_on_load_missing` or `_tied_weights` which are the standard attributes to select the keys to NOT save on disk (since they are ignored) - Modified safetensors code (to be reflected in safetensors even if it's an internal function). - Will not work for trust_remote_code=True repos (like santacoder). Should help with : https://github.com/huggingface/text-generation-inference/issues/555 and : https://github.com/huggingface/text-generation-inference/pull/501 and https://github.com/huggingface/text-generation-inference/issues/556 and https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593 --- server/tests/utils/test_convert.py | 2 +- server/text_generation_server/cli.py | 20 ++++++- .../text_generation_server/utils/convert.py | 57 +++++++++++++++++-- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py index 7dfe6a1e..ba6c5702 100644 --- a/server/tests/utils/test_convert.py +++ b/server/tests/utils/test_convert.py @@ -14,7 +14,7 @@ def test_convert_files(): local_st_files = [ p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files ] - convert_files(local_pt_files, local_st_files) + convert_files(local_pt_files, local_st_files, discard_names=[]) found_st_files = weight_files(model_id) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 3463049a..7a55e919 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -160,8 +160,26 @@ def download_weights( p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files ] + try: + from transformers import AutoConfig + import transformers + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + ) + architecture = config.architectures[0] + + class_ = getattr(transformers, architecture) + + # Name for this varible depends on transformers version. 
+ discard_names = getattr(class_, "_tied_weights_keys", []) + discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) + + except Exception as e: + discard_names = [] # Convert pytorch weights to safetensors - utils.convert_files(local_pt_files, local_st_files) + utils.convert_files(local_pt_files, local_st_files, discard_names) @app.command() diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index 0e4adaba..305263ba 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -4,11 +4,56 @@ import os from loguru import logger from pathlib import Path -from safetensors.torch import save_file, _remove_duplicate_names, load_file -from typing import List +from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete +from typing import List, Dict +from collections import defaultdict -def convert_file(pt_file: Path, sf_file: Path): +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set( + [name for name in shared if _is_complete(state_dict[name])] + ) + if not complete_names: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mecanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]): """ Convert a pytorch file to a safetensors file This will remove duplicate tensors from the file. 
@@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path): loaded = torch.load(pt_file, map_location="cpu") if "state_dict" in loaded: loaded = loaded["state_dict"] - to_removes = _remove_duplicate_names(loaded) + to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) metadata = {"format": "pt"} for kept_name, to_remove_group in to_removes.items(): @@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path): raise RuntimeError(f"The output tensors do not match for key {k}") -def convert_files(pt_files: List[Path], sf_files: List[Path]): +def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]): assert len(pt_files) == len(sf_files) N = len(pt_files) @@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]): for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): start = datetime.datetime.now() - convert_file(pt_file, sf_file) + convert_file(pt_file, sf_file, discard_names) elapsed = datetime.datetime.now() - start logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") From b4024edd4549ab647b02b8619b4072e33e64f1f9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 10 Jul 2023 14:47:15 +0200 Subject: [PATCH 05/13] feat: better errors for warmup and TP (#575) Close #571 --- router/src/main.rs | 37 +++++++++++++------ router/src/server.rs | 9 ++--- .../models/custom_modeling/bloom_modeling.py | 5 +++ .../custom_modeling/flash_llama_modeling.py | 5 +++ .../custom_modeling/flash_neox_modeling.py | 6 +++ .../custom_modeling/flash_rw_modeling.py | 6 +++ .../flash_santacoder_modeling.py | 6 ++- .../models/custom_modeling/mpt_modeling.py | 6 +++ .../models/custom_modeling/neox_modeling.py | 7 +++- .../models/custom_modeling/opt_modeling.py | 6 ++- .../models/custom_modeling/t5_modeling.py | 5 +++ .../models/flash_causal_lm.py | 5 +-- 12 files changed, 80 insertions(+), 23 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 4cd89d68..e482d834 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use std::time::Duration; -use text_generation_client::ShardedClient; +use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::{server, HubModelInfo}; +use thiserror::Error; use tokenizers::{FromPretrainedParameters, Tokenizer}; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; @@ -70,7 +71,7 @@ struct Args { ngrok_password: Option, } -fn main() -> Result<(), std::io::Error> { +fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); // Pattern match configuration @@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> { // Launch Tokio runtime tokio::runtime::Builder::new_multi_thread() .enable_all() - .build() - .unwrap() + .build()? 
.block_on(async { init_logging(otlp_endpoint, json_output); @@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> { // Instantiate sharded client from the master unix socket let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) .await - .expect("Could not connect to server"); + .map_err(RouterError::Connection)?; // Clear the cache; useful if the webserver rebooted sharded_client .clear_cache(None) .await - .expect("Unable to clear cache"); + .map_err(RouterError::Cache)?; // Get info from the shard - let shard_info = sharded_client - .info() - .await - .expect("Unable to get shard info"); + let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; // Warmup model tracing::info!("Warming up model"); @@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> { max_batch_total_tokens, ) .await - .expect("Unable to warmup model"); + .map_err(RouterError::Warmup)?; tracing::info!("Connected"); let addr = match hostname.parse() { @@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> { ngrok_username, ngrok_password, ) - .await; + .await?; Ok(()) }) } @@ -331,3 +328,19 @@ pub async fn get_model_info( } None } + +#[derive(Debug, Error)] +enum RouterError { + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), + #[error("Axum webserver failed: {0}")] + Axum(#[from] axum::BoxError), +} diff --git a/router/src/server.rs b/router/src/server.rs index 54418f84..8ca463c2 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -527,7 +527,7 @@ pub async fn run( ngrok_domain: Option, ngrok_username: Option, ngrok_password: Option, -) { +) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] #[openapi( @@ -726,8 +726,7 @@ pub async fn run( .serve(app.into_make_service()) //Wait until all requests are finished to shut down .with_graceful_shutdown(shutdown_signal()) - .await - .unwrap(); + .await?; } #[cfg(not(feature = "ngrok"))] { @@ -744,9 +743,9 @@ pub async fn run( .serve(app.into_make_service()) // Wait until all requests are finished to shut down .with_graceful_shutdown(shutdown_signal()) - .await - .unwrap(); + .await?; } + Ok(()) } /// Shutdown signal handler diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index e5e87645..047a1872 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -256,6 +256,11 @@ class BloomAttention(nn.Module): self.beta = 1.0 process_group = weights.process_group + if self.num_heads % process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {process_group.size()}" + ) self.num_heads = self.num_heads // process_group.size() self.query_key_value = TensorParallelColumnLinear.load( config=config, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d224a838..d9f3c7b8 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module): self.softmax_scale = self.head_size**-0.5 + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load_multi( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 23c5e4ff..b2dce226 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module): self.num_heads = num_heads self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.rotary_emb = PositionRotaryEmbedding.load( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index cd42bfc2..acac2744 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module): dim=self.head_size, base=10000.0, device=weights.device ) self.softmax_scale = self.head_size ** (-0.5) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 855a6e11..5b1c6e21 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads - assert self.num_heads % weights.process_group.size() == 0 + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 5ea204e1..e6057116 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module): if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = config.attn_config["attn_pdrop"] + + if self.n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.n_heads = self.n_heads // weights.process_group.size() self.Wqkv = load_col( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index bf2656d1..1951b171 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module): torch.tensor(self.head_size, dtype=torch.float32) ).to(torch.get_default_dtype()) - assert self.num_attention_heads % weights.process_group.size() == 0 + if self.num_attention_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_attention_heads` must be divisible by `num_shards` " + f"(got `num_attention_heads`: {self.num_attention_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_attention_heads = ( self.num_attention_heads // weights.process_group.size() ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 03fded50..aa052b08 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -147,7 +147,11 @@ class OPTAttention(nn.Module): self.is_decoder = is_decoder process_group = weights.process_group - assert self.num_heads % process_group.size() == 0 + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // process_group.size() self.embed_dim = self.embed_dim // process_group.size() diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 19deca86..1ea82802 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -246,6 +246,11 @@ class T5Attention(nn.Module): self.o = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o", weights=weights, bias=False ) + if self.n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.n_heads = self.n_heads // process_group.size() self.inner_dim = self.inner_dim // process_group.size() diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bebd3df5..5420556b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -727,12 +727,11 @@ class FlashCausalLM(Model): ) _, batch = 
self.generate_token(batch) except Exception as e: - logger.exception( + raise RuntimeError( f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " f"prefill tokens. " f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" - ) - raise e + ) from e del batch torch.cuda.empty_cache() From f0181436f41c730aad068029f54a3f86f354442d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 12 Jul 2023 09:51:34 +0200 Subject: [PATCH 06/13] fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579) Fixes #555 --- server/text_generation_server/models/flash_rw.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 12b862d7..55d555fc 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]}, + ) config.quantize = quantize From 5bd2ab65839af4f54ba59e60ef3dd262004de456 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 12 Jul 2023 10:00:02 +0200 Subject: [PATCH 07/13] feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580) # What does this PR do? Some models are already converted and do not have those values in the file; this enables users to use them with less friction. Went for a purely env-based approach because adding flags would end up (imo) very tedious to maintain. There's a lot of sanitization to do: those flags would be errors if not used in conjunction with `--quantize gptq`. Then the flags need to exist in the launcher and the server, passing them all throughout all function calls. This PR is intended as an easy escape hatch, not the de facto method to use gptq in TGI.
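As a rough sketch of the fallback order this introduces (illustrative only; `load_gptq_params` and the `weights` handle are placeholder names, not the exact server code), the checkpoint metadata wins and the environment variables are only consulted when it is missing:

```python
import os


def load_gptq_params(weights):
    """Sketch: prefer GPTQ metadata baked into the safetensors checkpoint, else fall back to env."""
    try:
        # Checkpoints converted with recent tooling store these as tiny tensors.
        bits = weights.get_tensor("gptq_bits").item()
        groupsize = weights.get_tensor("gptq_groupsize").item()
    except Exception:
        # Older conversions do not, hence the escape hatch,
        # e.g. GPTQ_BITS=4 GPTQ_GROUPSIZE=128 in the server environment.
        bits = int(os.environ["GPTQ_BITS"])
        groupsize = int(os.environ["GPTQ_GROUPSIZE"])
    return bits, groupsize
```

Assuming the two variables are visible to the Python shard processes, already-converted repos that lack the metadata can then be served without reconverting.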
Fixes #500 --- .../custom_modeling/flash_santacoder_modeling.py | 16 ++++++++++++---- server/text_generation_server/utils/weights.py | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 5b1c6e21..a19623a5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -20,12 +20,12 @@ from text_generation_server.utils.layers import ( FastLayerNorm, get_linear, ) +from safetensors import SafetensorError def load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): - if config.quantize == "gptq": return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size @@ -74,8 +74,17 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - bits = weights.get_tensor("gptq_bits").item() - groupsize = weights.get_tensor("gptq_groupsize").item() + try: + bits = weights.get_tensor("gptq_bits").item() + groupsize = weights.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) @@ -99,7 +108,6 @@ def _load_multi_mqa_gptq( def _load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): - if any("c_attn" in k for k in weights.routing.keys()): slice_ = weights._get_slice(f"{prefix}.c_attn.weight") shape = slice_.get_shape() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 83d9df68..39f66862 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,6 +1,6 @@ from pathlib import Path from typing import List, Dict, Optional -from safetensors import safe_open +from safetensors import safe_open, SafetensorError import torch @@ -120,8 +120,17 @@ class Weights: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GTPQ_BITS")) + groupsize = int(os.getenv("GTPQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] From f063ebde103a78ee1fb6eafb66cb462526c0ce44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 12 Jul 2023 10:01:01 +0200 Subject: [PATCH 08/13] chore: migrate ci region for more availability. 
(#581) --- .github/workflows/load_test.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml index 10e248e3..fd22e395 100644 --- a/.github/workflows/load_test.yaml +++ b/.github/workflows/load_test.yaml @@ -15,11 +15,11 @@ jobs: name: Start self-hosted EC2 runner runs-on: ubuntu-latest env: - AWS_REGION: us-east-1 - EC2_AMI_ID: ami-03cfed9ea28f4b002 + AWS_REGION: eu-central-1 + EC2_AMI_ID: ami-0ab09c07cfd194259 EC2_INSTANCE_TYPE: g5.12xlarge - EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc - EC2_SECURITY_GROUP: sg-04d472c808f365022 + EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326 + EC2_SECURITY_GROUP: sg-072f92ae3082936c6 outputs: label: ${{ steps.start-ec2-runner.outputs.label }} ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} @@ -90,7 +90,7 @@ jobs: - load-tests runs-on: ubuntu-latest env: - AWS_REGION: us-east-1 + AWS_REGION: eu-central-1 if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs steps: - name: Configure AWS credentials @@ -105,4 +105,4 @@ jobs: mode: stop github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} \ No newline at end of file + ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} From db4efbf4bcb851a185879099ac01fdc61e34a062 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 12 Jul 2023 10:01:42 +0200 Subject: [PATCH 09/13] fix(server): T5 weights names. (#582) Fixes #541 --- .../models/custom_modeling/t5_modeling.py | 7 +------ server/text_generation_server/models/t5.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 1ea82802..5779a274 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): super().__init__(config) self.model_dim = config.d_model - try: - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) - except RuntimeError: - self.shared = TensorParallelEmbedding( - prefix="encoder.embed_tokens", weights=weights - ) + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 1b7073af..133aafd8 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) model = T5ForConditionalGeneration(config, weights) From 7f9072228afa5c6dd9a827d72bba818c3ec088c0 Mon Sep 17 00:00:00 2001 From: Adam Kowalski 
<49454006+akowalsk@users.noreply.github.com> Date: Wed, 12 Jul 2023 03:40:32 -0500 Subject: [PATCH 10/13] fix(server): Adding logger import to t5_modeling.py (#585) Logger is referenced during the apex importing but is not imported, causing a NameError --- .../models/custom_modeling/t5_modeling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 5779a274..793f3a66 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -19,6 +19,8 @@ import math import warnings from typing import Optional, Tuple, Union +from loguru import logger + import torch import torch.distributed from torch import nn From 2c4bf882689fda969c84dfc86c4158032de10e1b Mon Sep 17 00:00:00 2001 From: ssmi153 <129111316+ssmi153@users.noreply.github.com> Date: Wed, 12 Jul 2023 20:17:35 +0800 Subject: [PATCH 11/13] fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590) # What does this PR do? This fixes a typo and extends the GPTP_BITS environment variables through to the second method which requires the same logic. Please let me know if there's anything I've misunderstood in this change. Thanks @Narsil for the original fix. --- server/text_generation_server/utils/weights.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 39f66862..4f300fe7 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -127,8 +127,8 @@ class Weights: try: import os - bits = int(os.getenv("GTPQ_BITS")) - groupsize = int(os.getenv("GTPQ_GROUPSIZE")) + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) except Exception: raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) @@ -149,8 +149,17 @@ class Weights: scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: From 67347950b7518efeb64c7f99ee360af685b53934 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 12 Jul 2023 16:43:31 +0200 Subject: [PATCH 12/13] feat(server): Implements sharding for non divisible `vocab_size`. (#583) - The code is relatively easy (just disable the checks on Embedding and Head) This cannot be done in the same easy fashion for hidden_dim/head_dim. It's relatively easy on some models (classic MHA) but it would make the other models (MQA) much more complex, and GPTQ quantization another quite hairy piece of code. 
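To make the "disable the checks on Embedding and Head" part concrete, here is a tiny illustration with made-up numbers (not code from this repo) of why the old divisibility assert fired and what the relaxed path does instead:

```python
# Made-up numbers, purely illustrative.
vocab_size, num_shards = 32003, 4
rows_per_shard, leftover = divmod(vocab_size, num_shards)  # 8000 rows each, 3 left over

if leftover != 0:
    # Before: `get_sharded` asserted on the non-divisible size and loading failed.
    # After: TensorParallelHead catches that case, loads the full head weight on every
    # shard and skips the gather, while TensorParallelEmbedding takes the partial shard.
    print(f"{vocab_size} rows cannot be split evenly across {num_shards} shards "
          f"({rows_per_shard} each leaves {leftover} uncovered)")
```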
--- server/text_generation_server/utils/layers.py | 23 +++++++++++++++---- .../text_generation_server/utils/weights.py | 17 ++++++++++---- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 8e0362b8..4f65446e 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -174,13 +174,25 @@ class SuperLayer(nn.Module): class TensorParallelHead(SuperLayer): - def __init__(self, linear, process_group): + def __init__(self, linear, process_group, should_gather: bool): super().__init__(linear) self.process_group = process_group + self.should_gather = should_gather @staticmethod def load(config, prefix: str, weights): - weight = weights.get_sharded(f"{prefix}.weight", dim=0) + if weights.process_group.size() > 1: + try: + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False # GPTQ doesn't quantize heads (nor embeddings) if config.quantize == "gptq": @@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer): return TensorParallelHead( get_linear(weight, bias=None, quantize=quantize), process_group=weights.process_group, + should_gather=should_gather, ) def forward(self, input: torch.Tensor) -> torch.Tensor: - world_size = self.process_group.size() - if world_size == 1: + if not self.should_gather: return super().forward(input) + world_size = self.process_group.size() if len(input.shape) == 2 and isinstance(self.linear, FastLinear): out_dim = self.linear.weight.shape[0] @@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer): class TensorParallelEmbedding(nn.Module): def __init__(self, prefix: str, weights, reduce=True): super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=0) + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) num_embeddings = weights.get_shape(f"{prefix}.weight")[0] process_group = weights.process_group diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 4f300fe7..afcbb9c3 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -69,7 +69,7 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int): filename, tensor_name = self.get_filename(tensor_name) world_size = self.process_group.size() rank = self.process_group.rank() @@ -81,10 +81,6 @@ class Weights: start = rank * block_size stop = (rank + 1) * block_size - assert ( - size % world_size == 0 - ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - if dim == 0: tensor = slice_[start:stop] elif dim == 1: @@ -98,6 +94,17 @@ class Weights: tensor = tensor.to(device=self.device) return tensor + def get_sharded(self, tensor_name: str, dim: int): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + size = slice_.get_shape()[dim] + assert ( + size % world_size == 0 + ), f"The choosen size {size} is not compatible with sharding on {world_size} 
shards" + return self.get_partial_sharded(tensor_name, dim) + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: From f2f0289fb99c7caab0c3749fdf211e4d5ab2938b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 12 Jul 2023 17:05:50 +0200 Subject: [PATCH 13/13] feat(server): empty cache on errors --- server/text_generation_server/interceptor.py | 4 ++++ server/text_generation_server/models/flash_causal_lm.py | 3 --- server/text_generation_server/server.py | 9 +++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py index a3247d19..725105f3 100644 --- a/server/text_generation_server/interceptor.py +++ b/server/text_generation_server/interceptor.py @@ -1,3 +1,4 @@ +import torch import grpc from google.rpc import status_pb2, code_pb2 @@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor): method_name = method_name.split("/")[-1] logger.exception(f"Method {method_name} encountered an error.") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + await context.abort_with_status( rpc_status.to_status( status_pb2.Status(code=code_pb2.INTERNAL, message=str(err)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5420556b..4e5804f5 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -639,7 +639,6 @@ class FlashCausalLMBatch(Batch): for b in batches: b.block_tables = None del b - torch.cuda.empty_cache() return FlashCausalLMBatch( batch_id=batches[0].batch_id, @@ -733,7 +732,6 @@ class FlashCausalLM(Model): f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" ) from e del batch - torch.cuda.empty_cache() def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( @@ -790,7 +788,6 @@ class FlashCausalLM(Model): ) except Exception as e: del batch - torch.cuda.empty_cache() raise e if prefill: diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index c375330a..7bc62ce6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): filtered_batch = batch.filter(request.request_ids) self.cache.set(filtered_batch) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): @@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) self.model.warmup(batch, request.max_total_tokens) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return generate_pb2.WarmupResponse() async def Prefill(self, request, context): @@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if len(batches) > 1: batch = self.model.batch_type.concatenate(batches) + if torch.cuda.is_available(): + torch.cuda.empty_cache() else: batch = batches[0]