From d69a0633bee6f8a665a1f7d258fceaa4475c102f Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 1 Jun 2023 11:41:35 +0200 Subject: [PATCH 01/22] fix(server): fix has_position_ids (#395) Fix #389 --- server/text_generation_server/models/causal_lm.py | 5 ----- server/text_generation_server/models/model.py | 7 +++++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index a20a6143..92622350 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -496,11 +496,6 @@ class CausalLM(Model): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.has_position_ids = ( - inspect.signature(model.forward).parameters.get("position_ids", None) - is not None - ) - super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 29bad321..6b8472a5 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -1,3 +1,4 @@ +import inspect import torch from abc import ABC, abstractmethod @@ -29,6 +30,12 @@ class Model(ABC): self.device = device self.rank = rank self.world_size = world_size + + self.has_position_ids = ( + inspect.signature(model.forward).parameters.get("position_ids", None) + is not None + ) + self.check_initialized() @property From c0928e6f26a537add41a237303fcabb8ac317a88 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 1 Jun 2023 12:07:41 +0200 Subject: [PATCH 02/22] feat(server): remove trust_remote_code requirement for falcon models (#396) --- .../text_generation_server/models/__init__.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4adf1381..78b68721 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,7 +1,7 @@ import torch from loguru import logger -from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from typing import Optional @@ -138,10 +138,8 @@ def get_model( trust_remote_code=trust_remote_code, ) - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - model_type = config.model_type + config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code) + model_type = config_dict["model_type"] if model_type == "gpt_bigcode": if sharded: @@ -201,9 +199,9 @@ def get_model( if model_type in ["RefinedWeb", "RefinedWebModel"]: if sharded: if FLASH_ATTENTION: - if config.alibi or ( - config.model_type == "RefinedWebModel" - and config.n_head_kv != config.n_head + if config_dict.get("alibi", False) or ( + model_type == "RefinedWebModel" + and config_dict.get("multi_query", True) ): raise NotImplementedError("sharded is not supported for this model") return FlashRWSharded( @@ -216,7 +214,7 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb") ) else: - if FLASH_ATTENTION and not config.alibi: + if FLASH_ATTENTION and not config_dict.get("alibi", False): return FlashRW( model_id, revision, @@ -250,7 +248,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if config.model_type == "opt": + if 
model_type == "opt": if sharded: return OPTSharded( model_id, @@ -294,7 +292,7 @@ def get_model( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code ) - auto_map = getattr(config, "auto_map", None) + auto_map = config_dict.get("auto_map", None) if trust_remote_code and auto_map is not None: if "AutoModelForCausalLM" in auto_map.keys(): return CausalLM( From 95d3546976e811ba462047ef8b555d9840efbfe5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 1 Jun 2023 12:10:35 +0200 Subject: [PATCH 03/22] feat(server): load santacoder/starcoder models with safetensors (#393) Fix #366 --- launcher/src/main.rs | 16 +- .../models/flash_santacoder.py | 166 ++++++++++-------- 2 files changed, 91 insertions(+), 91 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0810d979..7ee8bf1b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -546,11 +546,7 @@ enum LauncherError { WebserverCannotStart, } -fn download_convert_model( - args: &Args, - auto_convert: bool, - running: Arc, -) -> Result<(), LauncherError> { +fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { let mut download_argv = vec![ "text-generation-server".to_string(), "download-weights".to_string(), @@ -562,11 +558,6 @@ fn download_convert_model( "--json-output".to_string(), ]; - // Auto convert weights to safetensors - if auto_convert { - download_argv.push("--auto-convert".to_string()); - } - // Model optional revision if let Some(revision) = &args.revision { download_argv.push("--revision".to_string()); @@ -932,11 +923,8 @@ fn main() -> Result<(), LauncherError> { }) .expect("Error setting Ctrl-C handler"); - // auto_convert is only needed for sharded models as we do not require safetensors in - // single shard mode - let auto_convert = num_shard > 1; // Download and convert model weights - download_convert_model(&args, auto_convert, running.clone())?; + download_convert_model(&args, running.clone())?; // Shared shutdown bool let shutdown = Arc::new(Mutex::new(false)); diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 482e0f54..7907e2cc 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -54,12 +54,7 @@ class FlashSantacoder(FlashCausalLM): ) # We do not use from_pretrained as we modified the model internal module layout - try: - filenames = weight_files(model_id, revision, ".bin") - # Local files not found - except LocalEntryNotFoundError: - hub_files = weight_hub_files(model_id, revision, ".bin") - filenames = download_weights(hub_files, model_id, revision) + filenames = weight_files(model_id, revision, ".safetensors") with init_empty_weights(): model = FlashSantacoderForCausalLM(config) @@ -91,85 +86,100 @@ class FlashSantacoder(FlashCausalLM): transpose: bool, ): for filename in filenames: - state_dict = torch.load(filename, map_location="cpu") - for key, value in state_dict.items(): - value = value.to(device if quantize is None else "cpu").to(dtype) + with safe_open( + filename, framework="pt", device=str(device) if quantize is None else "cpu" + ) as f: + for key in f.keys(): + value = f.get_tensor(key) + value = value.to(device if quantize is None else "cpu").to(dtype) - layer_name = ".".join(key.split(".")[:4]) + layer_name = ".".join(key.split(".")[:4]) - # Fused qkv - if "q_attn.weight" in key or "kv_attn.weight" in key: - final_key = layer_name + ".c_attn.weight" - elif 
"q_attn.bias" in key or "kv_attn.bias" in key: - final_key = layer_name + ".c_attn.bias" + # Fused qkv + if "q_attn.weight" in key or "kv_attn.weight" in key: + final_key = layer_name + ".c_attn.weight" + elif "q_attn.bias" in key or "kv_attn.bias" in key: + final_key = layer_name + ".c_attn.bias" - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if transpose and ( - "c_fc.weight" in key - or "c_proj.weight" in key - or "q_attn.weight" in key - or "kv_attn.weight" in key - or "c_attn.weight" in key - ): - # Tranpose as we use nn.Linear instead of Conv1D - value = value.T - - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "c_attn.weight" in final_key: - module._parameters[param_name] = value.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2), - value.shape[1], - ) - ) - elif "c_attn.bias" in final_key: - module._parameters[param_name] = value.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2) - ) - ) - - # Copy to correct slice - if "q_attn.weight" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "q_attn.bias" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "kv_attn.weight" in key: - module._parameters[param_name][ - model.transformer.head_size * model.transformer.num_heads : - ] = value - elif "kv_attn.bias" in key: - module._parameters[param_name][ - model.transformer.head_size * model.transformer.num_heads : - ] = value else: - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - else: - module._buffers[param_name] = value + final_key = key - del value + module_name, param_name = final_key.rsplit(".", 1) + module = model.get_submodule(module_name) + + try: + current_parameter_tensor = module._parameters[param_name] + except KeyError: + current_parameter_tensor = None + + if current_parameter_tensor is not None: + if transpose and ( + "c_fc.weight" in key + or "c_proj.weight" in key + or "q_attn.weight" in key + or "kv_attn.weight" in key + or "c_attn.weight" in key + ): + # Tranpose as we use nn.Linear instead of Conv1D + value = value.T + + if current_parameter_tensor.device == torch.device("meta"): + # Init qkv + if "c_attn.weight" in final_key: + module._parameters[param_name] = value.new_empty( + ( + model.transformer.head_size + * (model.transformer.num_heads + 2), + value.shape[1], + ) + ) + elif "c_attn.bias" in final_key: + module._parameters[param_name] = value.new_empty( + ( + model.transformer.head_size + * (model.transformer.num_heads + 2) + ) + ) + + # Copy to correct slice + if "q_attn.weight" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "q_attn.bias" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "kv_attn.weight" in key: + module._parameters[param_name][ + model.transformer.head_size * model.transformer.num_heads : + ] = value + elif "kv_attn.bias" in key: + module._parameters[param_name][ + model.transformer.head_size * model.transformer.num_heads : + ] = value + else: + if current_parameter_tensor.shape != value.shape: + raise ValueError( + f"Name {final_key} -- Current 
{current_parameter_tensor.shape} and got {value.shape}" + ) + module._parameters[param_name] = value + else: + module._buffers[param_name] = value + + del value + + if model.lm_head.weight.device == torch.device("meta"): + model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) torch.cuda.empty_cache() model.post_load_weights(quantize) + uninitialized_parameters = [] + for n, p in model.named_parameters(): + if p.data.device == torch.device("meta"): + uninitialized_parameters.append(n) + if uninitialized_parameters: + raise RuntimeError( + f"found uninitialized parameters in model : {uninitialized_parameters}" + ) + def decode(self, generated_ids: List[int]) -> str: # Do not skip special tokens as they are used for custom parsing rules of the generated text return self.tokenizer.decode( @@ -389,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder): else: module._buffers[param_name] = tensor - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) + if model.lm_head.weight.device == torch.device("meta"): + model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) + torch.cuda.empty_cache() model.post_load_weights(quantize) From e7248fe90e27c7c8e39dd4cac5874eb9f96ab182 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:49:13 +0200 Subject: [PATCH 04/22] v0.8.2 --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- docs/openapi.json | 2 +- server/pyproject.toml | 2 +- server/text_generation_server/models/__init__.py | 4 +++- .../text_generation_server/models/flash_santacoder.py | 10 +++++++--- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 67ad8674..bd5994a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2557,7 +2557,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "0.8.1" +version = "0.8.2" dependencies = [ "average", "clap", @@ -2577,7 +2577,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "0.8.1" +version = "0.8.2" dependencies = [ "futures", "grpc-metadata", @@ -2593,7 +2593,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.8.1" +version = "0.8.2" dependencies = [ "clap", "ctrlc", @@ -2609,7 +2609,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "0.8.1" +version = "0.8.2" dependencies = [ "async-stream", "axum", diff --git a/Cargo.toml b/Cargo.toml index 3190b64c..b28286fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.8.1" +version = "0.8.2" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/docs/openapi.json b/docs/openapi.json index 040c8e8b..e5ef0e3c 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "0.8.1" + "version": "0.8.2" }, "paths": { "/": { diff --git a/server/pyproject.toml b/server/pyproject.toml index 7400a055..d381eac4 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "0.8.1" +version = "0.8.2" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 78b68721..fc92d03d 100644 --- 
a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -138,7 +138,9 @@ def get_model( trust_remote_code=trust_remote_code, ) - config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code) + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) model_type = config_dict["model_type"] if model_type == "gpt_bigcode": diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 7907e2cc..e1c893d0 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -87,7 +87,9 @@ class FlashSantacoder(FlashCausalLM): ): for filename in filenames: with safe_open( - filename, framework="pt", device=str(device) if quantize is None else "cpu" + filename, + framework="pt", + device=str(device) if quantize is None else "cpu", ) as f: for key in f.keys(): value = f.get_tensor(key) @@ -148,11 +150,13 @@ class FlashSantacoder(FlashCausalLM): module._parameters[param_name][: value.shape[0]] = value elif "kv_attn.weight" in key: module._parameters[param_name][ - model.transformer.head_size * model.transformer.num_heads : + model.transformer.head_size + * model.transformer.num_heads : ] = value elif "kv_attn.bias" in key: module._parameters[param_name][ - model.transformer.head_size * model.transformer.num_heads : + model.transformer.head_size + * model.transformer.num_heads : ] = value else: if current_parameter_tensor.shape != value.shape: From 62fc40103079bc27e97194ef69e9e34a180b0a85 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 2 Jun 2023 09:51:06 +0200 Subject: [PATCH 05/22] feat(sagemaker): add trust remote code to entrypoint (#394) --- sagemaker-entrypoint.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sagemaker-entrypoint.sh b/sagemaker-entrypoint.sh index 1340f55a..9ac47010 100755 --- a/sagemaker-entrypoint.sh +++ b/sagemaker-entrypoint.sh @@ -18,4 +18,8 @@ if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then export QUANTIZE="${HF_MODEL_QUANTIZE}" fi +if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then + export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}" +fi + text-generation-launcher --port 8080 From 83b84486ad9e35bff1d5f2d166db296b767542c7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 2 Jun 2023 14:17:27 +0200 Subject: [PATCH 06/22] feat(launcher): parse oom signal (#404) --- launcher/src/main.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 7ee8bf1b..f59ff685 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -410,9 +410,14 @@ fn shard_manager( let mut wait_time = Instant::now(); loop { // Process exited - if p.poll().is_some() { + if let Some(exit_status) = p.poll() { let mut err = String::new(); p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); + + if let ExitStatus::Signaled(signal) = exit_status { + tracing::error!("Shard process was signaled to shutdown with signal {signal}"); + } + status_sender .send(ShardStatus::Failed((rank, err))) .unwrap(); From 895c5f15628df870f7a2ced7151dedb84231a996 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 2 Jun 2023 17:12:30 +0200 Subject: [PATCH 07/22] feat(server): only compute prefill logprobs when asked (#406) Close #288 --- Makefile | 1 + benchmark/src/generation.rs | 1 + clients/python/README.md | 46 +++++- 
clients/python/pyproject.toml | 2 +- clients/python/tests/test_client.py | 22 +-- clients/python/text_generation/client.py | 10 ++ clients/python/text_generation/types.py | 14 +- integration-tests/conftest.py | 9 +- integration-tests/models/test_bloom_560m.py | 2 + .../models/test_bloom_560m_sharded.py | 1 + integration-tests/models/test_flash_falcon.py | 2 + integration-tests/models/test_flash_llama.py | 5 +- integration-tests/models/test_flash_neox.py | 1 + .../models/test_flash_neox_sharded.py | 1 + .../models/test_flash_santacoder.py | 4 +- .../models/test_flash_starcoder.py | 11 +- integration-tests/models/test_mt0_base.py | 2 + integration-tests/models/test_t5_sharded.py | 1 + integration-tests/requirements.txt | 2 +- proto/generate.proto | 2 + router/src/health.rs | 1 + router/src/lib.rs | 4 + router/src/queue.rs | 2 + router/src/server.rs | 19 ++- router/src/validation.rs | 5 + server/tests/models/test_bloom.py | 1 + server/tests/models/test_causal_lm.py | 1 + server/tests/models/test_santacoder.py | 2 + server/tests/models/test_seq2seq_lm.py | 1 + .../models/causal_lm.py | 4 +- .../custom_modeling/flash_llama_modeling.py | 3 + .../custom_modeling/flash_neox_modeling.py | 3 + .../custom_modeling/flash_rw_modeling.py | 3 + .../flash_santacoder_modeling.py | 3 + .../models/flash_causal_lm.py | 132 +++++++++++++----- .../models/seq2seq_lm.py | 2 +- 36 files changed, 252 insertions(+), 73 deletions(-) diff --git a/Makefile b/Makefile index 7309aaee..a33aba17 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ install-server: install-integration-tests: cd integration-tests && pip install -r requirements.txt + cd clients/python && pip install . install-router: cd router && cargo install --path . diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 17c72d26..b57c652b 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -136,6 +136,7 @@ async fn prefill( let requests = (0..batch_size) .map(|id| Request { id: id.into(), + prefill_logprobs: false, inputs: sequence.clone(), truncate: sequence_length, parameters: Some(parameters.clone()), diff --git a/clients/python/README.md b/clients/python/README.md index 99ff185a..4e0e564c 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -107,8 +107,42 @@ print(text) ### Types ```python -# Prompt tokens -class PrefillToken: +# Request Parameters +class Parameters: + # Activate logits sampling + do_sample: bool + # Maximum number of generated tokens + max_new_tokens: int + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] + # Whether to prepend the prompt to the generated text + return_full_text: bool + # Stop generating tokens if a member of `stop_sequences` is generated + stop: List[str] + # Random sampling seed + seed: Optional[int] + # The value used to module the logits distribution. + temperature: Optional[float] + # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: Optional[int] + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation. 
+ top_p: Optional[float] + # truncate inputs tokens to the given size + truncate: Optional[int] + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] + # Generate best_of sequences and return the one if the highest token logprobs + best_of: Optional[int] + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + watermark: bool + # Get decoder input token logprobs and ids + decoder_input_details: bool + +# Decoder input tokens +class InputToken: # Token ID from the model tokenizer id: int # Token text @@ -151,8 +185,8 @@ class BestOfSequence: generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] @@ -165,8 +199,8 @@ class Details: generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] # Additional sequences when using the `best_of` parameter diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 06d5f9cb..a52bdd81 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.5.2" +version = "0.6.0" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 32462f14..1e25e1b1 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -2,28 +2,30 @@ import pytest from text_generation import Client, AsyncClient from text_generation.errors import NotFoundError, ValidationError -from text_generation.types import FinishReason, PrefillToken, Token +from text_generation.types import FinishReason, InputToken def test_generate(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", max_new_tokens=1) + response = client.generate("test", max_new_tokens=1, decoder_input_details=True) assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) + assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == "" + assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special def test_generate_best_of(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True) + response = client.generate( + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True + ) assert response.details.seed is not None assert response.details.best_of_sequences is not None @@ -73,17 +75,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): @pytest.mark.asyncio async def 
test_generate_async(flan_t5_xxl_url, hf_headers): client = AsyncClient(flan_t5_xxl_url, hf_headers) - response = await client.generate("test", max_new_tokens=1) + response = await client.generate( + "test", max_new_tokens=1, decoder_input_details=True + ) assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) + assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == "" + assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special @@ -91,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers): async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers): client = AsyncClient(flan_t5_xxl_url, hf_headers) response = await client.generate( - "test", max_new_tokens=1, best_of=2, do_sample=True + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True ) assert response.details.seed is not None diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 8b8742fc..bf045d47 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -74,6 +74,7 @@ class Client: truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + decoder_input_details: bool = False, ) -> Response: """ Given a prompt, generate the following text @@ -110,6 +111,8 @@ class Client: See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + decoder_input_details (`bool`): + Return the decoder input token logprobs and ids Returns: Response: generated response @@ -130,6 +133,7 @@ class Client: truncate=truncate, typical_p=typical_p, watermark=watermark, + decoder_input_details=decoder_input_details, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -202,6 +206,7 @@ class Client: parameters = Parameters( best_of=None, details=True, + decoder_input_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, @@ -311,6 +316,7 @@ class AsyncClient: truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + decoder_input_details: bool = False, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -347,6 +353,8 @@ class AsyncClient: See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + decoder_input_details (`bool`): + Return the decoder input token logprobs and ids Returns: Response: generated response @@ -355,6 +363,7 @@ class AsyncClient: parameters = Parameters( best_of=best_of, details=True, + decoder_input_details=decoder_input_details, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, @@ -437,6 +446,7 @@ class AsyncClient: parameters = Parameters( best_of=None, details=True, + decoder_input_details=False, do_sample=do_sample, 
max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index ad3cd09b..548f0b63 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -37,6 +37,8 @@ class Parameters(BaseModel): watermark: bool = False # Get generation details details: bool = False + # Get decoder input token logprobs and ids + decoder_input_details: bool = False @validator("best_of") def valid_best_of(cls, field_value, values): @@ -129,8 +131,8 @@ class Request(BaseModel): return field_value -# Prompt tokens -class PrefillToken(BaseModel): +# Decoder input tokens +class InputToken(BaseModel): # Token ID from the model tokenizer id: int # Token text @@ -173,8 +175,8 @@ class BestOfSequence(BaseModel): generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] @@ -187,8 +189,8 @@ class Details(BaseModel): generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] # Additional sequences when using the `best_of` parameter diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 902a7158..82f1b719 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -16,7 +16,7 @@ from syrupy.extensions.json import JSONSnapshotExtension from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from text_generation import AsyncClient -from text_generation.types import Response, Details, PrefillToken, Token, BestOfSequence +from text_generation.types import Response, Details, InputToken, Token, BestOfSequence DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @@ -62,7 +62,7 @@ class ResponseComparator(JSONSnapshotExtension): and token.special == other.special ) - def eq_prefill_token(prefill_token: PrefillToken, other: PrefillToken) -> bool: + def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool: try: return ( prefill_token.id == other.id @@ -332,7 +332,10 @@ def generate_load(): client: AsyncClient, prompt: str, max_new_tokens: int, n: int ) -> List[Response]: futures = [ - client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n) + client.generate( + prompt, max_new_tokens=max_new_tokens, decoder_input_details=True + ) + for _ in range(n) ] return await asyncio.gather(*futures) diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index 809250cb..bdcbdc78 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot): "Pour déguster un ortolan, il faut tout d'abord", max_new_tokens=10, top_p=0.9, + decoder_input_details=True, seed=0, ) @@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index ee67250a..3995f9e5 
100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): "Pour déguster un ortolan, il faut tout d'abord", max_new_tokens=10, top_p=0.9, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py index e36a6a28..eac91984 100644 --- a/integration-tests/models/test_flash_falcon.py +++ b/integration-tests/models/test_flash_falcon.py @@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): response = await flash_falcon.generate( "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", max_new_tokens=10, + decoder_input_details=True, ) assert response.details.generated_tokens == 10 @@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index edc847c1..c69314ff 100644 --- a/integration-tests/models/test_flash_llama.py +++ b/integration-tests/models/test_flash_llama.py @@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle): @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama(flash_llama, response_snapshot): - response = await flash_llama.generate("Test request", max_new_tokens=10) + response = await flash_llama.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index daff7f0a..ff9b9763 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot): response = await flash_neox.generate( "<|USER|>What's your mood today?<|ASSISTANT|>", max_new_tokens=10, + decoder_input_details=True, ) assert response.details.generated_tokens == 10 diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py index a1aa0f07..8a491915 100644 --- a/integration-tests/models/test_flash_neox_sharded.py +++ b/integration-tests/models/test_flash_neox_sharded.py @@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): response = await flash_neox_sharded.generate( "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", max_new_tokens=10, + decoder_input_details=True, ) assert response.details.generated_tokens == 10 diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index a15a6439..0f005f15 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle): @pytest.mark.asyncio async def test_flash_santacoder(flash_santacoder, 
response_snapshot): - response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) + response = await flash_santacoder.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 72b298c9..64e8b27c 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle): @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder(flash_starcoder, response_snapshot): - response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) + response = await flash_starcoder.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot): @pytest.mark.private async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot): response = await flash_starcoder.generate( - "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0 + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, ) assert response.details.generated_tokens == 60 diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index 4ed95aad..12f23e4c 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot): "Why is the sky blue?", max_new_tokens=10, top_p=0.9, + decoder_input_details=True, seed=0, ) @@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py index a2d84330..7c288b23 100644 --- a/integration-tests/models/test_t5_sharded.py +++ b/integration-tests/models/test_t5_sharded.py @@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot): response = await t5_sharded.generate( "Please answer the following question. 
What is the boiling point of Nitrogen?", max_new_tokens=10, + decoder_input_details=True, ) assert response == response_snapshot diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt index 051730ff..2f36d5d6 100644 --- a/integration-tests/requirements.txt +++ b/integration-tests/requirements.txt @@ -1,5 +1,5 @@ syrupy -text-generation==0.5.2 +text-generation pytest pytest-asyncio==0.17.2 docker \ No newline at end of file diff --git a/proto/generate.proto b/proto/generate.proto index 0c40e5bb..a0f5a75e 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -87,6 +87,8 @@ message Request { NextTokenChooserParameters parameters = 4; /// Stopping Criteria Parameters StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; } message Batch { diff --git a/router/src/health.rs b/router/src/health.rs index 45f50e9d..a3cacdcd 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -34,6 +34,7 @@ impl Health { id: LIVENESS_ID, inputs: "liveness".to_string(), truncate: 10, + prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, top_k: 0, diff --git a/router/src/lib.rs b/router/src/lib.rs index 080dc4f4..67fff017 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters { #[schema(default = "true")] pub details: bool, #[serde(default)] + #[schema(default = "true")] + pub decoder_input_details: bool, + #[serde(default)] #[schema( exclusive_minimum = 0, nullable = true, @@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters { truncate: None, watermark: false, details: false, + decoder_input_details: false, seed: None, } } diff --git a/router/src/queue.rs b/router/src/queue.rs index 94851e1c..03807933 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -201,6 +201,7 @@ impl State { batch_requests.push(Request { id, + prefill_logprobs: entry.request.decoder_input_details, inputs: entry.request.inputs.clone(), truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), @@ -281,6 +282,7 @@ mod tests { inputs: "".to_string(), input_length: 0, truncate: 0, + decoder_input_details: false, parameters: NextTokenChooserParameters { temperature: 0.0, top_k: 0, diff --git a/router/src/server.rs b/router/src/server.rs index fd6a66bb..10c0ba3c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -160,7 +160,7 @@ async fn generate( add_prompt = Some(req.0.inputs.clone()); } - let details = req.0.parameters.details; + let details = req.0.parameters.details || req.0.parameters.decoder_input_details; // Inference let (response, best_of_responses) = match req.0.parameters.best_of { @@ -364,7 +364,17 @@ async fn generate_stream( let details = req.0.parameters.details; let best_of = req.0.parameters.best_of.unwrap_or(1); - if best_of == 1 { + if best_of != 1 { + let err = InferError::from(ValidationError::BestOfStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else if req.0.parameters.decoder_input_details { + let err = InferError::from(ValidationError::PrefillDetailsStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else { match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives Ok((_permit, mut 
response_stream)) => { @@ -474,11 +484,6 @@ async fn generate_stream( tracing::error!("{err}"); yield Ok(Event::from(err)); } - } else { - let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); } }; diff --git a/router/src/validation.rs b/router/src/validation.rs index cbb0d9cd..8843c6a8 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -145,6 +145,7 @@ impl Validation { truncate, seed, watermark, + decoder_input_details, .. } = request.parameters; @@ -261,6 +262,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, parameters, @@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest { pub inputs: String, pub input_length: u32, pub truncate: u32, + pub decoder_input_details: bool, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, } @@ -351,6 +354,8 @@ pub enum ValidationError { BestOfSeed, #[error("`best_of` != 1 is not supported when streaming tokens")] BestOfStream, + #[error("`decoder_input_details` == true is not supported when streaming tokens")] + PrefillDetailsStream, #[error("`temperature` must be strictly positive")] Temperature, #[error("`repetition_penalty` must be strictly positive")] diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 590ba557..338fe053 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 3f28f5b3..0f9dab2c 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index bef8db38..fceec560 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="def", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, @@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="defworld", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index a3199d02..299340f8 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, 
parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 92622350..ba0853f5 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -104,7 +104,7 @@ class CausalLMBatch(Batch): ).to(device) for _ in pb.requests: input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(0) + prefix_offsets.append(input_len - 5) read_offsets.append(input_len) input_lengths = tokenized_inputs["attention_mask"].sum(1) @@ -617,7 +617,7 @@ class CausalLM(Model): generated_text = None # Prefill - if stopping_criteria.current_tokens == 1: + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [float("nan")] + torch.log_softmax( logits, -1 diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2dcb6ed8..f4116937 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.model( input_ids, @@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.model.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 26e21753..b798750a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.gpt_neox( input_ids, @@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) if self.gpt_neox.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 545da26a..03487703 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.transformer( input_ids, @@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): past_key_values, pre_allocate_past_size, ) + if lm_head_indices 
is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.transformer.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 9bded805..b61ec873 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module): max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.transformer( input_ids, @@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module): past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.transformer.tp_embeddings: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 35cbe174..5ff951b3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch): past_key_values: Optional[torch.Tensor] max_seqlen: int + # Prefill metadata tensors to efficiently compute logprobs + prefill_head_indices: Optional[torch.Tensor] + prefill_next_token_indices: Optional[torch.tensor] + prefill_cu_outlens: Optional[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch): all_input_ids = [] requests_idx_mapping = {} + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_length = 0 + prefill_out_cumulative_length = 0 max_tokens = 0 max_length = 0 @@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch): max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) - prefix_offsets.append(0) + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) all_input_ids.append(tokenized_input) # Position ids - position_ids.append(np.arange(0, input_length)) + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs cu_seqlens.append(cumulative_length + input_length) @@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch): max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], dtype=torch.int32 + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length 
+ 1) + prefill_out_cumulative_length += 1 + # Update cumulative_length += input_length max_tokens += input_length + max_new_tokens @@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch): for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + # Create tensors on device - input_ids = torch.tensor( - np.concatenate(all_input_ids), dtype=torch.int64, device=device - ) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) - position_ids = torch.tensor( - np.concatenate(position_ids), dtype=torch.int32, device=device - ) + position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlens[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlens[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + return cls( batch_id=pb.id, requests=pb.requests, @@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch): cu_seqlens=cu_seqlens, cu_seqlens_q=None, max_seqlen=max_seqlen, + prefill_head_indices=prefill_head_indices, + prefill_next_token_indices=prefill_next_token_indices, + prefill_cu_outlens=prefill_cu_outlens, past_key_values=None, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch): cu_seqlens=cu_seqlens, cu_seqlens_q=cu_seqlens_q, max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, past_key_values=past_key_values, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch): cu_seqlens=cu_seqlens, cu_seqlens_q=cu_seqlens_q, max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, past_key_values=past_key_values, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -486,6 +545,7 @@ class FlashCausalLM(Model): max_s: int, past_key_values: Optional = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( @@ -496,6 +556,7 @@ class FlashCausalLM(Model): max_s=max_s, past_key_values=past_key_values, pre_allocate_past_size=pre_allocate_past_size, + lm_head_indices=lm_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -503,9 +564,10 @@ class FlashCausalLM(Model): self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None + prefill_logprobs = batch.prefill_next_token_indices is not None single_request = len(batch) == 1 - if prefill and len(batch) == 1: + if prefill and single_request: # Ask to pre-allocate kv to its max size # == number of tokens + max_new_tokens pre_allocate_past_size = ( @@ -522,11 +584,12 @@ class FlashCausalLM(Model): batch.max_seqlen, batch.past_key_values, 
pre_allocate_past_size, + batch.prefill_head_indices, ) if prefill: next_token_logits = ( - out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1] + out[batch.prefill_next_token_indices] if prefill_logprobs else out ) else: next_token_logits = out @@ -536,10 +599,10 @@ class FlashCausalLM(Model): ) if prefill: - if len(batch) > 1: + if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids)) + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) # Create batch.cu_seqlens_q for decode batch.cu_seqlens_q = torch.arange( @@ -600,7 +663,6 @@ class FlashCausalLM(Model): # Zipped iterator iterator = zip( batch.input_lengths, - batch.stopping_criterias, batch.all_input_ids, ) @@ -611,29 +673,33 @@ class FlashCausalLM(Model): # For each member of the batch for i, ( input_length, - stopping_criteria, all_input_ids, ) in enumerate(iterator): - # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length if prefill: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + out_length = out_end_index - out_start_index + # Initialize position_ids # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices - if len(batch) > 1: - prefill_tokens_indices[ - start_index : end_index - 1 - ] = batch.input_ids[start_index + 1 : end_index] - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : end_index - ] + if prefill_logprobs: + if len(batch) > 1: + prefill_tokens_indices[ + out_start_index : out_end_index - 1 + ] = batch.input_ids[start_index + 1 : start_index + out_length] + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = batch.input_ids[ + start_index + 1 : start_index + out_length + ] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] @@ -644,7 +710,7 @@ class FlashCausalLM(Model): batch.position_ids = next_position_ids + 1 batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q - if prefill: + if prefill and prefill_logprobs: # Get prefill logprobs prefill_logprobs_tensor = torch.log_softmax(out, -1) prefill_logprobs = torch.gather( @@ -657,8 +723,6 @@ class FlashCausalLM(Model): next_token_logprobs = next_token_logprobs.tolist() next_token_ids = batch.input_ids.tolist() - cumulative_length = 0 - # Zipped iterator iterator = zip( batch.requests, @@ -688,9 +752,6 @@ class FlashCausalLM(Model): next_token_id, next_token_logprob, ) in enumerate(iterator): - start_index = cumulative_length - end_index = cumulative_length + input_length - # Append next token to all tokens all_input_ids.append(next_token_id) @@ -728,10 +789,13 @@ class FlashCausalLM(Model): generated_text = None # Prefill - if prefill: + if prefill and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - start_index : end_index - 1 + out_start_index : out_end_index - 1 ] prefill_token_ids = all_input_ids[:-1] prefill_texts = 
self.tokenizer.batch_decode(
@@ -764,8 +828,10 @@ class FlashCausalLM(Model):
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
-            cumulative_length += input_length
 
+        batch.prefill_cu_outlens = None
+        batch.prefill_head_indices = None
+        batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1
 
         # No need to return a batch if we know that all requests stopped
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 68e59dc3..3ad5698c 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
             generated_text = None
 
             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],

From 6abec14a7eeb6e29a394557d64e2b527af1a89fb Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Mon, 5 Jun 2023 16:09:41 +0200
Subject: [PATCH 08/22] feat(server): batch tokenization for flash causal lm
 (#411)

---
 .../models/flash_causal_lm.py              | 18 ++++++++++++++----
 server/text_generation_server/utils/hub.py |  4 +++-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5ff951b3..a2ad2d5e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,13 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(
+            zip(pb.requests, batch_tokenized_inputs)
+        ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate :]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 134ac7cd..2ed7673c 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -134,7 +134,7 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""
 
-    def download_file(filename, tries=5):
+    def download_file(filename, tries=5, backoff: int = 5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
@@ -158,6 +158,8 @@ def download_weights(
             if i + 1 == tries:
                 raise e
             logger.error(e)
+            logger.info(f"Retrying in {backoff} seconds")
+            time.sleep(backoff)
             logger.info(f"Retry {i + 1}/{tries - 1}")
 
     # We do this instead of using tqdm because we want to parse the logs with the launcher
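
The patch above is mostly about doing one tokenizer call for the whole batch instead of one call per request inside the parsing loop. A minimal sketch of the idea outside the server plumbing — the `Request` dataclass and the `gpt2` tokenizer are stand-ins for the protobuf requests and the model's real tokenizer, not the PR's types:

```python
# Sketch of the batched-tokenization idea from the diff above.
from dataclasses import dataclass
from typing import List

from transformers import AutoTokenizer


@dataclass
class Request:
    inputs: str
    truncate: int


def batch_tokenize(tokenizer, requests: List[Request]) -> List[List[int]]:
    # One call for the whole batch, truncated to the largest per-request budget...
    max_truncation = max(r.truncate for r in requests)
    batch_token_ids = tokenizer(
        [r.inputs for r in requests], truncation=True, max_length=max_truncation
    )["input_ids"]
    # ...then trim each sequence to its own budget, keeping the suffix.
    return [ids[-r.truncate :] for r, ids in zip(requests, batch_token_ids)]


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id
print(batch_tokenize(tokenizer, [Request("Hello world!", 8), Request("Hi", 4)]))
```

The per-sequence `ids[-r.truncate :]` slice is still needed because the single batched call can only truncate to the batch-wide maximum; with a fast tokenizer, the batched call also parallelizes the encoding across sequences, which is where the speedup comes from.
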
From 19c41824cb11ba1a3b60a2a65274d8c074383de3 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 5 Jun 2023 18:16:08 +0200
Subject: [PATCH 09/22] chore: update openapi schema

---
 docs/openapi.json | 7 ++++++-
 router/src/lib.rs | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/docs/openapi.json b/docs/openapi.json
index e5ef0e3c..8c652946 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -375,7 +375,8 @@
           "$ref": "#/components/schemas/GenerateParameters"
         },
         "stream": {
-          "type": "boolean"
+          "type": "boolean",
+          "default": "false"
         }
       }
     },
@@ -459,6 +460,10 @@
           "minimum": 0.0,
           "exclusiveMinimum": 0.0
         },
+        "decoder_input_details": {
+          "type": "boolean",
+          "default": "true"
+        },
         "details": {
           "type": "boolean",
           "default": "true"
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 67fff017..7dff7a11 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -176,7 +176,7 @@ pub(crate) struct CompatGenerateRequest {
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
     #[serde(default)]
-    #[allow(dead_code)]
+    #[schema(default = "false")]
     pub stream: bool,
 }

From abd58ff82c37d5e4f131abdac3d298927a815604 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 8 Jun 2023 14:51:52 +0200
Subject: [PATCH 10/22] feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. The idea is to use cleaner loading code:
- Remove the need for `no_init_weights`
- Remove all the weird `bnb_linear`, `load_weights` and `post_load_weights` helpers.

New code layout (a rough sketch follows the list):
- New class `Weights` in charge of loading the weights from multiple files into the appropriate tensors (potentially sharded)
- TP layers are now "shells": they contain the code to know what kind of sharding we need + the eventual `all_reduce`. They do not inherit from Linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or, next, GPTQ Linear
- All modeling code is explicitly written for sharding; the process group is just a no-op for non-sharded code (removes a lot of test cases)
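
A rough, hypothetical sketch of that layout, with names following the description above rather than the exact code in this diff (`safetensors.safe_open` is the real loading primitive; everything else is illustrative):

```python
# Hypothetical sketch: `Weights` maps a tensor name to the file holding it and
# can return either the full tensor or only this rank's shard; the TP "shell"
# wraps a plain linear and owns the collective.
from typing import List

import torch
import torch.distributed
from safetensors import safe_open


class Weights:
    def __init__(self, filenames: List[str], device, dtype, process_group):
        self.routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pt") as f:
                for name in f.keys():
                    self.routing[name] = filename
        self.device = device
        self.dtype = dtype
        self.process_group = process_group

    def get_tensor(self, name: str) -> torch.Tensor:
        with safe_open(self.routing[name], framework="pt") as f:
            tensor = f.get_tensor(name)
        return tensor.to(dtype=self.dtype, device=self.device)

    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        # Only materialize this rank's slice of the sharded tensor.
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        with safe_open(self.routing[name], framework="pt") as f:
            slice_ = f.get_slice(name)
            size = slice_.get_shape()[dim]
            block_size = size // world_size
            start, stop = rank * block_size, (rank + 1) * block_size
            tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
        return tensor.to(dtype=self.dtype, device=self.device)


class TensorParallelRowLinear(torch.nn.Module):
    # A "shell": it does not inherit from Linear, it *contains* one, plus the
    # all_reduce that row sharding requires.
    def __init__(self, linear: torch.nn.Module, process_group):
        super().__init__()
        self.linear = linear
        self.process_group = process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        if self.process_group is not None and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
```

A column-parallel shell is the mirror image: it shards the weight on the output dimension and gathers or concatenates the partial outputs instead of reducing them.
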
![Screenshot from 2023-05-19 23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu
Co-authored-by: Ubuntu
Co-authored-by: OlivierDehaene
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
---
 .gitignore                                    |    1 +
 Dockerfile                                    |   17 +-
 Makefile                                      |    7 +-
 integration-tests/conftest.py                 |   14 +
 .../test_flash_starcoder_default_params.json  |   14 +-
 .../__snapshots__/test_neox/test_neox.json    |  113 ++
 .../test_neox/test_neox_load.json             |  454 +++++++
 .../test_neox_sharded/test_neox.json          |  163 +++
 .../test_neox_sharded/test_neox_load.json     |  654 +++++++++
 integration-tests/models/test_flash_neox.py   |    4 +-
 integration-tests/models/test_neox.py         |   48 +
 integration-tests/models/test_neox_sharded.py |   44 +
 integration-tests/pytest.ini                  |    1 +
 server/Makefile                               |    5 +-
 server/Makefile-transformers                  |   13 -
 .../custom_kernels/fused_attention_cuda.cu    |  250 ++++
 .../fused_bloom_attention_cuda.cu             |  250 ++++
 server/custom_kernels/setup.py                |   19 +
 server/pyproject.toml                         |    3 +-
 server/requirements.txt                       |    5 +-
 server/tests/models/test_bloom.py             |    9 +-
 .../text_generation_server/models/__init__.py |  143 +-
 server/text_generation_server/models/bloom.py |  203 +--
 .../models/custom_modeling/bloom_modeling.py  |  912 +++++++++++++
 .../custom_modeling/flash_llama_modeling.py   |  215 ++-
 .../custom_modeling/flash_neox_modeling.py    |  281 ++--
 .../custom_modeling/flash_rw_modeling.py      |  363 ++---
 .../flash_santacoder_modeling.py              |  308 +++--
 .../models/custom_modeling/neox_modeling.py   |  794 +++++++++++
 .../models/custom_modeling/opt_modeling.py    |  837 ++++++++++++
 .../models/custom_modeling/t5_modeling.py     | 1200 +++++++++++++++++
 .../models/flash_llama.py                     |  262 +---
 .../models/flash_neox.py                      |  122 +-
 .../text_generation_server/models/flash_rw.py |  194 +--
 .../models/flash_santacoder.py                |  369 +----
 .../models/galactica.py                       |  207 +--
 .../text_generation_server/models/gpt_neox.py |  207 +--
 server/text_generation_server/models/opt.py   |  189 +--
 server/text_generation_server/models/t5.py    |  205 +--
 .../text_generation_server/utils/__init__.py  |    2 +
 server/text_generation_server/utils/dist.py   |   54 +-
 server/text_generation_server/utils/layers.py |  367 +++--
 .../text_generation_server/utils/weights.py   |   77 ++
 43 files changed, 6806 insertions(+), 2793 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_neox/test_neox.json
 create mode 100644 integration-tests/models/__snapshots__/test_neox/test_neox_load.json
 create mode 100644 integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json
 create mode 100644 integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json
 create mode 100644 integration-tests/models/test_neox.py
 create mode 100644 integration-tests/models/test_neox_sharded.py
 delete mode 100644 server/Makefile-transformers
 create mode 100644 server/custom_kernels/custom_kernels/fused_attention_cuda.cu
 create mode 100644 server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu
 create mode 100644 server/custom_kernels/setup.py
 create mode 100644 server/text_generation_server/models/custom_modeling/bloom_modeling.py
 create mode 100644 server/text_generation_server/models/custom_modeling/neox_modeling.py
 create mode 100644 server/text_generation_server/models/custom_modeling/opt_modeling.py
 create mode 100644
server/text_generation_server/models/custom_modeling/t5_modeling.py
 create mode 100644 server/text_generation_server/utils/weights.py

diff --git a/.gitignore b/.gitignore
index 19604d42..20c9baee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea
 target
 router/tokenizer.json
+*__pycache__*
diff --git a/Dockerfile b/Dockerfile
index 483270a8..576dab8d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,8 @@ FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
 WORKDIR /usr/src
 
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
 FROM chef as planner
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -98,14 +100,14 @@ COPY server/Makefile-flash-att Makefile
 RUN make build-flash-attention
 
 # Build Transformers CUDA kernels
-FROM kernel-builder as transformers-builder
+FROM kernel-builder as custom-kernels-builder
 
 WORKDIR /usr/src
 
-COPY server/Makefile-transformers Makefile
+COPY server/custom_kernels/ .
 
 # Build specific version of transformers
-RUN BUILD_EXTENSIONS="True" make build-transformers
+RUN python setup.py build
 
 # Text Generation Inference base image
 FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@@ -136,11 +138,10 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
 # Copy build artifacts from transformers builder
-COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers
-COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
 
-# Install transformers dependencies
-RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir
+# Install flash-attention dependencies
+RUN pip install einops --no-cache-dir
 
 # Install server
 COPY proto proto
@@ -170,4 +171,4 @@ ENTRYPOINT ["./entrypoint.sh"]
 FROM base
 
 ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
\ No newline at end of file
+CMD ["--json-output"]
diff --git a/Makefile b/Makefile
index a33aba17..c7f649ec 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,9 @@
 install-server:
 	cd server && make install
 
+install-custom-kernels:
+	if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
+
 install-integration-tests:
 	cd integration-tests && pip install -r requirements.txt
 	cd clients/python && pip install .
@@ -14,7 +17,7 @@ install-launcher:
 install-benchmark:
 	cd benchmark && cargo install --path .
 
-install: install-server install-router install-launcher +install: install-server install-router install-launcher install-custom-kernels server-dev: cd server && make run-dev @@ -52,4 +55,4 @@ run-bloom: text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080 run-bloom-quantize: - text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 \ No newline at end of file + text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 82f1b719..8f59d75a 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -209,6 +209,7 @@ def launcher(event_loop): num_shard: Optional[int] = None, quantize: Optional[str] = None, trust_remote_code: bool = False, + use_flash_attention: bool = True, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -240,6 +241,9 @@ def launcher(event_loop): env = os.environ env["LOG_LEVEL"] = "info,text_generation_router=debug" + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + with subprocess.Popen( args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env ) as process: @@ -254,12 +258,16 @@ def launcher(event_loop): process.stdout.close() process.stderr.close() + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + @contextlib.contextmanager def docker_launcher( model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None, trust_remote_code: bool = False, + use_flash_attention: bool = True, ): port = random.randint(8000, 10_000) @@ -287,6 +295,9 @@ def launcher(event_loop): gpu_count = num_shard if num_shard is not None else 1 env = {"LOG_LEVEL": "info,text_generation_router=debug"} + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + if HUGGING_FACE_HUB_TOKEN is not None: env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN @@ -310,6 +321,9 @@ def launcher(event_loop): yield ContainerLauncherHandle(client, container.name, port) + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + try: container.stop() container.wait() diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index afd0b662..89e02c07 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -11,17 +11,17 @@ }, { "id": 1459, - "logprob": -5.6289062, + "logprob": -5.6328125, "text": " print" }, { "id": 81, - "logprob": -1.6005859, + "logprob": -1.6035156, "text": "_" }, { "id": 7656, - "logprob": -5.9921875, + "logprob": -5.9882812, "text": "hello" } ], @@ -59,19 +59,19 @@ }, { "id": 10896, - "logprob": -0.3659668, + "logprob": -0.38549805, "special": false, "text": " World" }, { "id": 657, - "logprob": -0.49804688, + "logprob": -0.5229492, "special": false, "text": "\")" }, { "id": 203, - "logprob": -0.11279297, + "logprob": -0.10632324, "special": false, "text": "\n" }, @@ -113,7 +113,7 @@ }, { "id": 426, - "logprob": -0.051635742, + "logprob": 0.0, "special": false, "text": "name" }, diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox.json b/integration-tests/models/__snapshots__/test_neox/test_neox.json new file mode 100644 index 00000000..2abc27e1 --- /dev/null +++ 
b/integration-tests/models/__snapshots__/test_neox/test_neox.json @@ -0,0 +1,113 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1992188, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8984375, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.0976562, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14562988, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26733398, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.86279297, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.94921875, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1835938, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.074035645, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.86376953, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.2070312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4365234, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.109375, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -0.93408203, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.8808594, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" +} diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json new file mode 100644 index 00000000..f37f0d8e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json @@ -0,0 +1,454 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + } +] diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json new file mode 100644 index 00000000..25cdf6d7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json @@ -0,0 +1,163 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4179688, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1542969, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.359375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.006038666, + "text": "e" + }, + { + "id": 13, + "logprob": -7.328125, + "text": "," + }, + { + "id": 285, + "logprob": -0.3173828, + "text": " and" + }, + { + "id": 
752, + "logprob": -2.0625, + "text": " what" + }, + { + "id": 434, + "logprob": -5.7734375, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.74072266, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.5898438, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.2949219, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.40625, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1113281, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008056641, + "text": "?" + }, + { + "id": 0, + "logprob": -2.3300781, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.28125, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.5878906, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5449219, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.05038452, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.002292633, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.3828278e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0010242462, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.090270996, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12719727, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.016571045, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.43432617, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" +} diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json new file mode 100644 index 00000000..0b38e701 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json @@ -0,0 +1,654 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4179688, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1542969, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.359375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.006038666, + "text": "e" + }, + { + "id": 13, + "logprob": -7.328125, + "text": "," + }, + { + "id": 285, + "logprob": -0.3173828, + "text": " and" + }, + { + "id": 752, + "logprob": -2.0625, + "text": " what" + }, + { + "id": 434, + "logprob": -5.7734375, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.74072266, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.5898438, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.2949219, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.40625, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1113281, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008056641, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.3300781, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.28125, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.5878906, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5498047, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.04815674, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.002313614, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.2636185e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0010147095, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.0859375, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12609863, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.016601562, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.38256836, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.04904175e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0009560585, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.08557129, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12084961, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.4025879, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + } +] diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index ff9b9763..1076126b 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): generated_texts = [r.generated_text for r in responses] assert len(generated_texts) == 4 - assert generated_texts, all( + assert all( [text == generated_texts[0] for text in generated_texts] - ) + ), generated_texts assert responses == response_snapshot diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py new file mode 100644 index 00000000..7b88f86a --- /dev/null +++ b/integration-tests/models/test_neox.py @@ -0,0 +1,48 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_handle(launcher): + with launcher( + "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox(neox_handle): + await neox_handle.health(300) + return neox_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox(neox, response_snapshot): + response = await neox.generate( + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox_load(neox, generate_load, response_snapshot): + responses = await generate_load( + neox, + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert generated_texts, all( + [text == generated_texts[0] for text in generated_texts] + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py new file mode 100644 index 00000000..8cee8765 --- /dev/null +++ b/integration-tests/models/test_neox_sharded.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_sharded_handle(launcher): + with launcher( + "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def 
neox_sharded(neox_sharded_handle): + await neox_sharded_handle.health(300) + return neox_sharded_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox(neox_sharded, response_snapshot): + response = await neox_sharded.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox_load(neox_sharded, generate_load, response_snapshot): + responses = await generate_load( + neox_sharded, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/pytest.ini b/integration-tests/pytest.ini index 485e6017..7dcae663 100644 --- a/integration-tests/pytest.ini +++ b/integration-tests/pytest.ini @@ -1,4 +1,5 @@ [pytest] +addopts = --snapshot-warn-unused asyncio_mode = auto markers = private: marks tests as requiring an admin hf token (deselect with '-m "not private"') \ No newline at end of file diff --git a/server/Makefile b/server/Makefile index 6eb56c75..17020c97 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,3 @@ -include Makefile-transformers include Makefile-flash-att unit-tests: @@ -17,7 +16,7 @@ install-torch: # Install specific version of torch pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir -install: gen-server install-torch install-transformers +install: gen-server install-torch pip install pip --upgrade pip install -r requirements.txt pip install -e ".[bnb, accelerate]" @@ -26,4 +25,4 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements.txt -E bnb --without-hashes \ No newline at end of file + poetry export -o requirements.txt -E bnb --without-hashes diff --git a/server/Makefile-transformers b/server/Makefile-transformers deleted file mode 100644 index 64d01672..00000000 --- a/server/Makefile-transformers +++ /dev/null @@ -1,13 +0,0 @@ -transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e - -transformers: - # Clone fork of transformers with custom CUDA kernels and sharding logic - pip install --upgrade setuptools - git clone https://github.com/OlivierDehaene/transformers.git - -build-transformers: transformers - cd transformers && git fetch && git checkout $(transformers_commit) && python setup.py build - -install-transformers: build-transformers - pip uninstall transformers -y || true - cd transformers && python setup.py install \ No newline at end of file diff --git a/server/custom_kernels/custom_kernels/fused_attention_cuda.cu b/server/custom_kernels/custom_kernels/fused_attention_cuda.cu new file mode 100644 index 00000000..60f9f028 --- /dev/null +++ b/server/custom_kernels/custom_kernels/fused_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at 
https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) 
{ + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor query, + const at::Tensor key, + const at::Tensor value, + const std::optional> layer_past, + const at::Tensor attention_mask, + const std::optional head_mask, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + auto query_layer = query; + auto key_layer = key; + auto value_layer = value; + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 2); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + const auto batch_size = query_layer.size(0); + const auto q_length = query_layer.size(2); + const auto attn_head_size = query_layer.size(3); + const auto batch_size_times_num_heads = batch_size * num_heads; + const auto kv_length = key_layer.size(2); + + const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size}); + auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2); + auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}); + + auto query_scaled = query_view * inv_norm_factor; + auto attention_scores = at::bmm(query_scaled, key_view); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) 
+ * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. 
Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_view); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "GPT-Neox attention mechanism forward (CUDA)" + ); +} diff --git a/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu b/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu new file mode 100644 index 00000000..4be547b1 --- /dev/null +++ b/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? 
kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) 
{ + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor fused_qkv, + const std::optional> layer_past, + const at::Tensor alibi, + const at::Tensor attention_mask, + const std::optional head_mask, + const float beta, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + const auto batch_size = fused_qkv.size(0); + const auto q_length = fused_qkv.size(1); + const auto three_times_hidden_size = fused_qkv.size(2); + const auto head_dim = three_times_hidden_size / (3 * num_heads); + const auto batch_size_times_num_heads = batch_size * num_heads; + + // `split_heads` + const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim}); + const auto tensor_list = fused_qkv_view.split(head_dim, -1); + const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length}); + auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 1); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + const auto kv_length = key_layer.size(2); + + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) 
+
+            /*
+             * We should split [batch_size_times_num_heads_block, q_length] in separate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
+             * with multiple threads as we need to `sync_threads` to run exponential sum.
+             * We maximise the usage of threads within a single block
+             */
+            // TODO @thomasw21 figure out everything warp related:
+            //  - why do they have to be power of 2
+            // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
+            const auto MAX_THREADS_PER_SM = 1024;
+            // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
+            const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
+            // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
+            const auto effective_kv_length = (kv_length - 1) / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
+            const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
+            const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
+
+            const dim3 gridDim(num_blocks); // Number of blocks that run
+            const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
+            const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
+
+            // 192 * 2 ** 10
+            // const auto MAX_L1_MEMORY = 196608;
+            // const auto MAX_SMs = 108;
+            // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
+            // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising because the required number of blocks is higher.");
+            // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising because the requested number of threads is higher.");
+
+            forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
+                attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+                attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
+                attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+                effective_kv_length,
+                blockDim,
+                rows_per_block,
+                kv_length,
+                batch_size_times_num_heads * q_length
+            );
+        });
+        attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
+    } else {
+        // Pytorch C++ API
+        auto input_dtype = attention_scores.scalar_type();
+        if (input_dtype == at::ScalarType::Half) {
+            attention_scores = attention_scores.to(at::ScalarType::Float);
+        };
+        // TODO @thomasw21 Figure out how to get minimum value
+        auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
+        attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
+    }
+
+    auto context_layer = attention_probs.bmm(value_layer);
+
+    // `_merge_heads`
+    context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
+    context_layer = context_layer.permute({0, 2, 1, 3});
+    context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
+
+    return std::make_tuple(context_layer, present, attention_probs);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward",
+        &forward,
+        "Bloom attention mechanism forward (CUDA)"
+    );
+}
\ No newline at end of file
diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py
new file mode 100644
index 00000000..43b8ee4e
--- /dev/null
+++ b/server/custom_kernels/setup.py
@@ -0,0 +1,19 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+    name="custom_kernels",
+    ext_modules=[
+        CUDAExtension(
+            name="custom_kernels.fused_bloom_attention_cuda",
+            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
+            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+        ),
+        CUDAExtension(
+            name="custom_kernels.fused_attention_cuda",
+            sources=["custom_kernels/fused_attention_cuda.cu"],
+            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+        ),
+    ],
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/server/pyproject.toml b/server/pyproject.toml
index d381eac4..f0ec25eb 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -25,7 +25,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "0.13.3"
-huggingface-hub = "0.14.0"
+huggingface-hub = "^0.14.1"
+transformers = "^4.29.2"
 
 [tool.poetry.extras]
 accelerate = ["accelerate"]
diff --git a/server/requirements.txt b/server/requirements.txt
index 50ba4e43..e8cee52b 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -13,8 +13,8 @@ grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
 grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
 grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
 hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
-huggingface-hub==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
-idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
+huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
+idna==3.4 ; python_version >= "3.9" and python_version < "4"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version
< "4.0" @@ -33,6 +33,7 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" +transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0" diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 338fe053..71013cb6 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -6,12 +6,17 @@ from transformers import AutoTokenizer from text_generation_server.pb import generate_pb2 from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM +from text_generation_server.utils import weight_hub_files, download_weights +from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded @pytest.fixture(scope="session") def default_bloom(): - return BLOOM("bigscience/bloom-560m") + model_id = "bigscience/bloom-560m" + revision = "main" + filenames = weight_hub_files(model_id, revision, ".safetensors") + download_weights(filenames, model_id, revision) + return BLOOMSharded(model_id) @pytest.fixture(scope="session") diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fc92d03d..f1b84a53 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,3 +1,4 @@ +import os import torch from loguru import logger @@ -8,17 +9,20 @@ from typing import Optional from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM -from text_generation_server.models.bloom import BLOOM, BLOOMSharded +from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.rw import RW -from text_generation_server.models.opt import OPT, OPTSharded -from text_generation_server.models.galactica import Galactica, GalacticaSharded +from text_generation_server.models.opt import OPTSharded +from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder -from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded +from text_generation_server.models.gpt_neox import GPTNeoxSharded try: - if torch.cuda.is_available(): + if ( + torch.cuda.is_available() + and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false" + ): major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 is_sm8x = major == 8 and minor >= 0 @@ -30,14 +34,12 @@ try: f"GPU with CUDA capability {major} {minor} is not supported" ) - from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded - from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded + from text_generation_server.models.flash_rw import FlashRWSharded + from text_generation_server.models.flash_neox import FlashNeoXSharded from 
text_generation_server.models.flash_llama import ( FlashLlama, - FlashLlamaSharded, ) from text_generation_server.models.flash_santacoder import ( - FlashSantacoder, FlashSantacoderSharded, ) @@ -52,30 +54,22 @@ except ImportError: __all__ = [ "Model", - "BLOOM", "BLOOMSharded", "CausalLM", "FlashCausalLM", - "Galactica", "GalacticaSharded", - "GPTNeoxSharded", "Seq2SeqLM", "SantaCoder", - "OPT", "OPTSharded", "T5Sharded", "get_model", ] if FLASH_ATTENTION: - __all__.append(FlashNeoX) __all__.append(FlashNeoXSharded) - __all__.append(FlashRW) __all__.append(FlashRWSharded) - __all__.append(FlashSantacoder) __all__.append(FlashSantacoderSharded) __all__.append(FlashLlama) - __all__.append(FlashLlamaSharded) FLASH_ATT_ERROR_MESSAGE = ( "{} requires Flash Attention CUDA kernels to be installed.\n" @@ -102,36 +96,24 @@ def get_model( trust_remote_code: bool, ) -> Model: if "facebook/galactica" in model_id: - if sharded: - return GalacticaSharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - else: - return Galactica( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) + return GalacticaSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) if model_id.startswith("bigcode/"): - if sharded: - if not FLASH_ATTENTION: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder") - ) + if FLASH_ATTENTION: return FlashSantacoderSharded( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") + ) else: - santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder - return santacoder_cls( + return SantaCoder( model_id, revision, quantize=quantize, @@ -144,20 +126,19 @@ def get_model( model_type = config_dict["model_type"] if model_type == "gpt_bigcode": - if sharded: - if not FLASH_ATTENTION: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder") - ) + if FLASH_ATTENTION: return FlashSantacoderSharded( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") + ) else: - santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder - return santacoder_cls( + return SantaCoder( model_id, revision, quantize=quantize, @@ -165,33 +146,45 @@ def get_model( ) if model_type == "bloom": - if sharded: - return BLOOMSharded( + return BLOOMSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) + + elif model_type == "gpt_neox": + if FLASH_ATTENTION: + return FlashNeoXSharded( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) + elif sharded: + return GPTNeoxSharded( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) else: - return BLOOM( + return CausalLM( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) - if model_type == "gpt_neox": - if sharded: - neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded - return neox_cls( + elif model_type == "llama": + if FLASH_ATTENTION: + return FlashLlama( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) else: - neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM - return neox_cls( + return 
CausalLM( model_id, revision, quantize=quantize, @@ -217,7 +210,7 @@ def get_model( ) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRW( + return FlashRWSharded( model_id, revision, quantize=quantize, @@ -231,42 +224,12 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "llama": - if sharded: - if FLASH_ATTENTION: - return FlashLlamaSharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")) - else: - llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM - return llama_cls( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) + elif model_type == "opt": + return OPTSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) - if model_type == "opt": - if sharded: - return OPTSharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - else: - return OPT( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - - if model_type == "t5": + elif model_type == "t5": if sharded: return T5Sharded( model_id, diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 45d7cd4c..50b3b76a 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -1,37 +1,26 @@ import torch import torch.distributed -from typing import List, Optional, Type +from typing import Optional, Type -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase, ) -from transformers.models.bloom.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - class BloomCausalLMBatch(CausalLMBatch): @classmethod @@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super(BloomCausalLMBatch, cls).from_pb( - pb=pb, tokenizer=tokenizer, dtype=dtype, device=device - ) + batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) batch.keys_head_dim_last = False return batch -class BLOOM(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - super(BLOOM, self).__init__( - model_id=model_id, - revision=revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch - - -class BLOOMSharded(BLOOM): +class BLOOMSharded(CausalLM): def __init__( self, model_id: str, @@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM): trust_remote_code=trust_remote_code, ) config.pad_token_id = 3 + config.quantize = quantize 
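For orientation, the `Weights` helper instantiated just below replaces the hand-rolled `load_weights` loop this commit deletes: it opens the safetensors files and hands each rank its slice of every tensor. A rough Python sketch of the slicing involved (the function name and the even-divisibility assumption are illustrative, not the actual API):

import torch

def shard_tensor(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Column-parallel layers are split along dim 0, row-parallel layers along dim 1.
    size = tensor.shape[dim]
    assert size % world_size == 0, "sketch assumes the dimension divides evenly"
    block = size // world_size
    return tensor.narrow(dim, rank * block, block).contiguous()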
torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = BloomForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - if name.startswith("transformer.") or name.startswith("lm_head."): - full_name = name - else: - full_name = f"transformer.{name}" - - module_name, param_name = full_name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[full_name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif ( - isinstance(module, TensorParallelEmbedding) - or name == "lm_head.weight" - ): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." 
- ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - module._parameters[param_name] = tensor - if name == "word_embeddings.weight": - model.lm_head._parameters["weight"] = tensor + @property + def batch_type(self) -> Type[CausalLMBatch]: + return BloomCausalLMBatch def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None @@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM): use_cache=True, ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - + logits = outputs.logits return logits, outputs.past_key_values diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py new file mode 100644 index 00000000..e5e87645 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -0,0 +1,912 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch BLOOM model.""" + +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.distributed +import torch.utils.checkpoint +from torch import nn +from torch.nn import LayerNorm +from torch.nn import functional as F + +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers import BloomConfig, PreTrainedModel + +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, +) + +CUSTOM_KERNELS_ENABLED = False +if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": + try: + from custom_kernels import fused_bloom_attention_cuda + + CUSTOM_KERNELS_ENABLED = True + except ImportError: + pass + +_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m" +_CONFIG_FOR_DOC = "BloomConfig" + +BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bigscience/bigscience-small-testing", + "bigscience/bloom-560m", + "bigscience/bloom-1b1", + "bigscience/bloom-1b7", + "bigscience/bloom-3b", + "bigscience/bloom-7b1", + "bigscience/bloom", +] + + +def _make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.ones( + (target_length, target_length + past_key_values_length), + dtype=torch.bool, + device=device, + ) + mask = mask.triu(1 + past_key_values_length) + + expanded_mask = mask.unsqueeze(0).expand( + batch_size, target_length, target_length + past_key_values_length + ) + return expanded_mask + + +def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, tgt_length, src_length) + + +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). 
+
+
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
+    """
+    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+    `softmax(l+a) = softmax(l)`. Based on
+    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+    Args:
+        attention_mask (`torch.Tensor`):
+            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+        num_heads (`int`, *required*):
+            number of heads
+
+    Returns:
+        Alibi tensor shaped (batch_size, num_heads, max_seq_len); the caller reshapes it to
+        (batch_size * num_heads, 1, max_seq_len).
+    """
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(
+        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
+        device=attention_mask.device,
+        dtype=torch.float32,
+    )
+    powers = torch.arange(
+        1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
+    )
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
+            device=attention_mask.device,
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(
+            1,
+            1 + 2 * num_remaining_heads,
+            2,
+            device=attention_mask.device,
+            dtype=torch.int32,
+        )
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcasted correctly
+    # This is more or less identical to T5's relative position bias:
+    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    return alibi
+
+
+# @torch.jit.script
+def dropout_add(
+    x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
+) -> torch.Tensor:
+    """
+    Dropout add function
+
+    Args:
+        x (`torch.tensor`, *required*):
+            input tensor
+        residual (`torch.tensor`, *required*):
+            residual tensor
+        prob (`float`, *required*):
+            dropout probability
+        training (`bool`, *required*):
+            training mode
+    """
+    out = F.dropout(x, p=prob, training=training)
+    out = residual + out
+    return out
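To illustrate `build_alibi_tensor` above: for 4 heads the slopes form the geometric sequence 1/4, 1/16, 1/64, 1/256, and the bias grows linearly with each unmasked position. A small check with made-up sizes:

import torch

attention_mask = torch.ones(2, 5)  # (batch_size, seq_length)
alibi = build_alibi_tensor(attention_mask, num_heads=4)
print(alibi.shape)   # torch.Size([2, 4, 5])
print(alibi[0, 0])   # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])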
+
+
+# @torch.jit.script  # this performs poorly for unknown reasons
+def _split_heads(
+    fused_qkv: torch.Tensor, num_heads: int, head_dim: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+    storage as `fused_qkv`
+
+    Args:
+        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+    Returns:
+        query: [batch_size * num_heads, seq_length, head_dim]
+        key: [batch_size * num_heads, head_dim, seq_length]
+        value: [batch_size * num_heads, seq_length, head_dim]
+    """
+    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
+    query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
+
+    query_layer = query_layer.transpose(1, 2).reshape(
+        batch_size * num_heads, seq_length, head_dim
+    )
+    key_layer = key_layer.permute(0, 2, 3, 1).reshape(
+        batch_size * num_heads, head_dim, seq_length
+    )
+    value_layer = value_layer.transpose(1, 2).reshape(
+        batch_size * num_heads, seq_length, head_dim
+    )
+
+    return query_layer, key_layer, value_layer
+
+
+# @torch.jit.script
+def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
+    """
+    Merge heads together over the last dimension
+
+    Args:
+        x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
+
+    Returns:
+        torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+    """
+    # What we want to achieve is:
+    # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+    batch_size_and_num_heads, seq_length, _ = x.shape
+    batch_size = batch_size_and_num_heads // num_heads
+
+    # First view to decompose the batch size
+    # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+    x = x.view(batch_size, num_heads, seq_length, head_dim)
+
+    # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
+    x = x.permute(0, 2, 1, 3)
+
+    # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+    return x.reshape(batch_size, seq_length, num_heads * head_dim)
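The two helpers above are shape inverses on the query/value path; a round-trip with illustrative sizes:

import torch

batch, seq, heads, head_dim = 2, 5, 4, 8
fused_qkv = torch.randn(batch, seq, heads * 3 * head_dim)
query, key, value = _split_heads(fused_qkv, num_heads=heads, head_dim=head_dim)
# query/value: (8, 5, 8) = (batch * heads, seq, head_dim); key comes back pre-transposed as (8, 8, 5)
merged = _merge_heads(query, num_heads=heads, head_dim=head_dim)
print(merged.shape)  # torch.Size([2, 5, 32])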
+
+
+class BloomAttention(nn.Module):
+    def __init__(self, prefix, config: BloomConfig, weights):
+        super().__init__()
+
+        self.pretraining_tp = config.pretraining_tp
+        self.slow_but_exact = config.slow_but_exact
+
+        self.process_group = weights.process_group
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.n_head
+        self.head_dim = self.hidden_size // self.num_heads
+        self.split_size = self.hidden_size
+        self.hidden_dropout = config.hidden_dropout
+
+        if self.head_dim * self.num_heads != self.hidden_size:
+            raise ValueError(
+                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.beta = 1.0
+
+        process_group = weights.process_group
+        self.num_heads = self.num_heads // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config=config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=True,
+        )
+        self.dense = TensorParallelRowLinear.load(
+            config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
+        )
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+    @staticmethod
+    def compute_attention(
+        fused_qkv: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor],
+        beta: float,
+        inv_norm_factor: float,
+        num_heads: int,
+        use_cache: bool,
+    ):
+        batch_size, q_length, three_times_hidden_size = fused_qkv.shape
+        head_dim = three_times_hidden_size // (3 * num_heads)
+
+        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
+        # 3 x [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = _split_heads(
+            fused_qkv, num_heads=num_heads, head_dim=head_dim
+        )
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
+            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.view(-1, *past_key.shape[-2:])
+            key_layer = torch.cat((past_key, key_layer), dim=2)
+            past_value = past_value.view(-1, *past_value.shape[-2:])
+            value_layer = torch.cat((past_value, value_layer), dim=1)
+
+        _, _, kv_length = key_layer.shape
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+        ###
+
+        # [batch_size * num_heads, q_length, kv_length]
+        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+        attention_scores = alibi.baddbmm(
+            batch1=query_layer,
+            batch2=key_layer,
+            beta=beta,
+            alpha=inv_norm_factor,
+        )
+
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+        input_dtype = attention_scores.dtype
+        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+        if input_dtype == torch.float16:
+            attention_scores = attention_scores.to(torch.float)
+        # torch.finfo not supported by torch.jit, we temporarily replace with `-1e34`
+        attn_weights = attention_scores.masked_fill_(
+            attention_mask, torch.finfo(attention_scores.dtype).min
+        )
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+            input_dtype
+        )
+
+        # # [batch_size, num_heads, q_length, kv_length]
+        # attention_probs = self.attention_dropout(attention_probs)
+
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        # matmul: [batch_size * num_heads, q_length, head_dim]
+        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
+
+        # change view [batch_size, num_heads, q_length, head_dim]
+        context_layer = _merge_heads(
+            context_layer, num_heads=num_heads, head_dim=head_dim
+        )
+
+        return context_layer, present, attention_probs
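`compute_attention` above relies on `Tensor.baddbmm` to fold the alibi bias into the QK^T product in a single call, computing `beta * alibi + alpha * (query @ key)`. A small equivalence check with made-up shapes (illustrative only):

import torch

bnh, q_len, head_dim, kv_len = 8, 3, 16, 5
query = torch.randn(bnh, q_len, head_dim)
key = torch.randn(bnh, head_dim, kv_len)
alibi = torch.randn(bnh, 1, kv_len)  # broadcasts over the query dimension
inv_norm_factor = head_dim ** -0.5

scores = alibi.baddbmm(batch1=query, batch2=key, beta=1.0, alpha=inv_norm_factor)
reference = alibi + inv_norm_factor * torch.bmm(query, key)
assert torch.allclose(scores, reference, atol=1e-5)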
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        fused_qkv = self.query_key_value(
+            hidden_states
+        )  # [batch_size, seq_length, 3 x hidden_size]
+        batch_size, q_length, _ = fused_qkv.shape
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            layer_past = (
+                past_key.view(-1, *past_key.shape[-2:]),
+                past_value.view(-1, *past_value.shape[-2:]),
+            )
+
+        if CUSTOM_KERNELS_ENABLED:
+            assert self.training is False, "Only forward pass was implemented"
+            assert (
+                attention_mask.shape[-1] < 4096
+            ), "Custom kernel supports only up to 4096 tokens"
+            (
+                context_layer,
+                present,
+                attention_probs,
+            ) = fused_bloom_attention_cuda.forward(
+                fused_qkv,
+                layer_past,
+                alibi,
+                attention_mask,
+                head_mask,
+                self.beta,
+                self.inv_norm_factor,
+                self.num_heads,
+                use_cache,
+            )
+        else:
+            context_layer, present, attention_probs = self.compute_attention(
+                fused_qkv=fused_qkv,
+                layer_past=layer_past,
+                alibi=alibi,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                beta=self.beta,
+                inv_norm_factor=self.inv_norm_factor,
+                num_heads=self.num_heads,
+                use_cache=use_cache,
+            )
+
+        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+        if self.pretraining_tp > 1 and self.slow_but_exact:
+            slices = self.hidden_size / self.pretraining_tp
+            output_tensor = torch.zeros_like(context_layer)
+            for i in range(self.pretraining_tp):
+                output_tensor = output_tensor + F.linear(
+                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+                )
+        else:
+            output_tensor = self.dense(context_layer)
+
+        # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+        output_tensor += residual
+
+        outputs = (output_tensor, present)
+        if output_attentions:
+            outputs += (attention_probs,)
+
+        return outputs
+
+
+class BloomMLP(nn.Module):
+    def __init__(self, prefix, config: BloomConfig, weights):
+        super().__init__()
+
+        self.pretraining_tp = config.pretraining_tp
+        self.slow_but_exact = config.slow_but_exact
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(
+            config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
+        )
+        self.dense_4h_to_h = TensorParallelRowLinear.load(
+            config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
+        )
+        self.gelu_impl = torch.nn.GELU(approximate="tanh")
+        self.hidden_dropout = config.hidden_dropout
+
+    def forward(
+        self, hidden_states: torch.Tensor, residual: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+        if self.pretraining_tp > 1 and self.slow_but_exact:
+            intermediate_output = torch.zeros_like(residual)
+            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+            for i in range(self.pretraining_tp):
+                intermediate_output = intermediate_output + F.linear(
+                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
+                    self.dense_4h_to_h.weight[
+                        :, int(i * slices) : int((i + 1) * slices)
+                    ],
+                )
+        else:
+            intermediate_output = self.dense_4h_to_h(hidden_states)
+
+        # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+        intermediate_output += residual
+
+        return intermediate_output
+
+
+class BloomBlock(nn.Module):
+    def __init__(self, layer_id: int, config: BloomConfig, weights):
+        super().__init__()
+
+        prefix = f"h.{layer_id}"
+        self.input_layernorm = LayerNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
eps=config.layer_norm_epsilon, + ) + self.num_heads = config.n_head + self.self_attention = BloomAttention( + prefix=f"{prefix}.self_attention", config=config, weights=weights + ) + self.post_attention_layernorm = LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm + ) + self.hidden_dropout = config.hidden_dropout + + def forward( + self, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class BloomPreTrainedModel(PreTrainedModel): + config_class = BloomConfig + base_model_prefix = "transformer" + _no_split_modules = ["BloomBlock"] + + @staticmethod + def _convert_to_standard_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, + num_heads, ...])) + """ + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + @staticmethod + def _convert_to_bloom_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + +class BloomModel(BloomPreTrainedModel): + def __init__(self, config: BloomConfig, weights): + super().__init__(config) + + self.embed_dim = config.hidden_size + self.num_heads = config.n_head + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.word_embeddings = TensorParallelEmbedding( + prefix="word_embeddings", weights=weights + ) + + self.word_embeddings_layernorm = LayerNorm.load( + prefix="word_embeddings_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + # Transformer blocks + self.h = nn.ModuleList( + [ + BloomBlock(layer_id=layer_id, config=config, weights=weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + # Final Layer Norm + self.ln_f = LayerNorm.load( + prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon + ) + + def _prepare_attn_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, + ) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + device=device, + past_key_values_length=past_key_values_length, + ) + + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-1] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = build_alibi_tensor(attention_mask, self.num_heads) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + if hasattr(self, "tp_rank"): + assert self.num_heads % self.tp_world_size == 0 + block_size = self.num_heads // self.tp_world_size + alibi = alibi[ + :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size + ] + alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past) + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) + else: + alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past) + causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0) + + alibi = alibi.to(hidden_states.dtype) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) + + # Add last hidden state + hidden_states = 
self.ln_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class BloomForCausalLM(BloomPreTrainedModel):
+    def __init__(self, config, weights):
+        super().__init__(config)
+        self.transformer = BloomModel(config, weights)
+
+        self.lm_head = TensorParallelHead.load(
+            config,
+            prefix="word_embeddings",
+            weights=weights,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> dict:
+        # only last token for input_ids if past is not None
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0.
You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + loss = None + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f4116937..8a35ffa8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -30,21 +30,23 @@ import flash_attn_cuda import dropout_layer_norm from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, + TensorParallelHead, ) class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, prefix, weights, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) self.variance_epsilon = eps def forward(self, hidden_states, residual=None): @@ -91,35 +93,35 @@ class LlamaRMSNorm(nn.Module): class FlashLlamaAttention(torch.nn.Module): def __init__( self, - num_heads, - hidden_size, - process_group=None, + prefix: str, + config, + weights, ): super().__init__() - self.num_heads = num_heads - self.hidden_size = hidden_size - self.head_size = hidden_size // num_heads + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.load( + prefix=f"{prefix}.rotary_emb", weights=weights + ) - self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) self.softmax_scale = self.head_size ** (-0.5) - if process_group is None: - self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) - self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) - else: - self.num_heads = self.num_heads // process_group.size() - self.query_key_value = TensorParallelColumnLinear( - hidden_size, - 3 * hidden_size, - bias=False, - process_group=process_group, - ) - self.o_proj = TensorParallelRowLinear( - hidden_size, - hidden_size, - bias=False, - process_group=process_group, - ) + self.num_heads = self.num_heads // weights.process_group.size() + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + 
self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) def forward( self, @@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module): class LlamaMLP(nn.Module): - def __init__(self, act, hidden_size, intermediate_size, process_group=None): + def __init__(self, prefix, config, weights): super().__init__() + act = config.hidden_act self.act = ( ACT2FN[act] if "gelu" not in act @@ -207,32 +210,23 @@ class LlamaMLP(nn.Module): else "none", ) ) - - if process_group is None: - # Fuse gate and up proj - self.gate_up_proj = FastLinear( - hidden_size, 2 * intermediate_size, bias=False - ) - self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False) - self.intermediate_size = intermediate_size - else: - # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear( - hidden_size, - 2 * intermediate_size, - bias=False, - process_group=process_group, - ) - self.down_proj = TensorParallelRowLinear( - intermediate_size, - hidden_size, - bias=False, - process_group=process_group, - reduce=True, - ) - self.intermediate_size = self.down_proj.in_features - - self.process_group = process_group + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) def forward(self, hidden_states): gate_up_states = self.gate_up_proj(hidden_states) @@ -241,22 +235,22 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): - def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - rms_norm_eps, - process_group=None, - ): + def __init__(self, layer_id, config, weights): super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashLlamaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) - self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) - - self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) + self.input_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) def forward( self, @@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module): class FlashLlamaModel(torch.nn.Module): - def __init__(self, config, process_group=None): - super(FlashLlamaModel, self).__init__() + def __init__(self, config, weights): + super().__init__() self.config = config - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.embed_tokens = TensorParallelEmbedding( - config.vocab_size, config.hidden_size, process_group=process_group - ) - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - + process_group = weights.process_group + self.tp_rank = 
process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) self.layers = nn.ModuleList( [ FlashLlamaLayer( - config.num_attention_heads, - config.hidden_act, - config.hidden_size, - config.intermediate_size, - config.rms_norm_eps, - process_group, + layer_id, + config, + weights, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = LlamaRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) self.gradient_checkpointing = False self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.embed_tokens, TensorParallelEmbedding): - self.embed_tokens.add_null_idx() - for layer in self.layers: - layer: FlashLlamaLayer - layer.self_attn.query_key_value.prepare_weights(quantize) - layer.self_attn.o_proj.prepare_weights(quantize) - layer.mlp.gate_up_proj.prepare_weights(quantize) - layer.mlp.down_proj.prepare_weights(quantize) - def forward( self, input_ids, @@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__() - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 - - self.model = FlashLlamaModel(config, process_group) - - if self.model.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.model.post_load_weights(quantize) - self.lm_head.prepare_weights() + self.model = FlashLlamaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) def forward( self, @@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - - if self.model.tp_embeddings: - # Logits are sharded, so we need to gather them - world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] - torch.distributed.all_gather(world_logits, logits, group=self.process_group) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present return logits, present diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b798750a..0fe43bcb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -31,61 +31,81 @@ from typing import Optional import flash_attn_cuda from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, + TensorParallelHead, FastLayerNorm, PositionRotaryEmbedding, + get_linear, ) +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + if bias and weights.process_group.rank() == 0: + # Rank 
0 loads the bias; the other shards use None so the bias is only applied once after the all-reduce
+        bias = weights.get_tensor(f"{prefix}.bias")
+    else:
+        bias = None
+
+    linear = get_linear(weight, bias, config.quantize)
+    if config.use_parallel_residual:
+        return linear
+    else:
+        return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
+    weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+
+    weight = (
+        weight.view(
+            num_heads,
+            3,
+            head_size,
+            hidden_size,
+        )
+        .permute(1, 0, 2, 3)
+        .reshape(-1, hidden_size)
+    )
+    bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
+
+    linear = get_linear(weight, bias, config.quantize)
+    if config.use_parallel_residual:
+        return linear
+    else:
+        return TensorParallelColumnLinear(linear)
+
+
 class FlashNeoxAttention(torch.nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        hidden_size,
-        rotary_pct,
-        rotary_emb_base,
-        process_group=None,
-        reduce=True,
-    ):
+    def __init__(self, config, prefix, weights):
         super().__init__()
+        num_heads = config.num_attention_heads
+        hidden_size = config.hidden_size
+
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.num_heads = self.num_heads // weights.process_group.size()
+
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )
-        rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
         self.softmax_scale = self.head_size ** (-0.5)
-        if process_group is None:
-            self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
-            self.dense = FastLinear(hidden_size, hidden_size)
-        else:
-            self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                3 * hidden_size,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size, hidden_size, process_group=process_group, reduce=reduce
-            )
-
-    def shuffle_qkv_dims(self):
-        """Swap dims to avoid an additional permute"""
-        self.query_key_value.weight = torch.nn.Parameter(
-            self.query_key_value.weight.view(
-                self.num_heads, 3, self.head_size, self.hidden_size
-            )
-            .permute(1, 0, 2, 3)
-            .reshape(-1, self.hidden_size)
+        self.query_key_value = load_qkv(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            num_heads=self.num_heads,
+            head_size=self.head_size,
+            hidden_size=self.hidden_size,
         )
-        self.query_key_value.bias = torch.nn.Parameter(
-            self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
-            .permute(1, 0, 2)
-            .reshape(-1)
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
 
     def forward(
@@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(
-        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
-    ):
+    def __init__(self, config, prefix, weights):
         super().__init__()
+        act = config.hidden_act
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
@@ -177,22 +196,12 @@ class FlashMLP(nn.Module):
             )
         )
 
-        if process_group is None:
-            self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
-            self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
-        else:
-            self.dense_h_to_4h = TensorParallelColumnLinear(
-                hidden_size,
-                intermediate_size,
-                process_group=process_group,
-            )
-            self.dense_4h_to_h = TensorParallelRowLinear(
-                intermediate_size,
-                hidden_size,
- process_group=process_group, - reduce=reduce, - ) - self.process_group = process_group + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = load_row( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) def forward(self, hidden_states): hidden_states = self.dense_h_to_4h(hidden_states) @@ -202,38 +211,28 @@ class FlashMLP(nn.Module): class FlashNeoXLayer(nn.Module): - def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - rotary_pct, - rotary_emb_base, - layer_norm_eps, - use_parallel_residual, - process_group=None, - ): + def __init__(self, layer_id, config, weights): super().__init__() - self.use_parallel_residual = use_parallel_residual - self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + + layer_norm_eps = config.layer_norm_eps + + prefix = f"gpt_neox.layers.{layer_id}" + + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps + ) + self.post_attention_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=layer_norm_eps, + ) self.attention = FlashNeoxAttention( - num_heads, - hidden_size, - rotary_pct, - rotary_emb_base, - process_group, - reduce=not use_parallel_residual, + config, prefix=f"{prefix}.attention", weights=weights ) - self.mlp = FlashMLP( - act, - hidden_size, - intermediate_size, - process_group, - reduce=not use_parallel_residual, - ) - self.process_group = process_group + + self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) + self.process_group = weights.process_group def forward( self, @@ -266,9 +265,7 @@ class FlashNeoXLayer(nn.Module): mlp_output = self.mlp(ln2_hidden_states) intermediate = mlp_output + attn_output - # Only reduce once and after the addition instead of once per layer - if self.process_group is not None: - torch.distributed.all_reduce(intermediate, group=self.process_group) + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate + hidden_states, None else: @@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) self.config = config - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.embed_in = TensorParallelEmbedding( - config.vocab_size, config.hidden_size, process_group=process_group - ) - else: - self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + self.embed_in = TensorParallelEmbedding( + prefix="gpt_neox.embed_in", weights=weights + ) self.layers = nn.ModuleList( [ - FlashNeoXLayer( - config.num_attention_heads, - config.hidden_act, - config.hidden_size, - config.intermediate_size, - config.rotary_pct, - config.rotary_emb_base, - config.layer_norm_eps, - config.use_parallel_residual, - process_group, - ) - for _ in range(config.num_hidden_layers) + FlashNeoXLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) ] ) - self.final_layer_norm = 
FastLayerNorm( - config.hidden_size, eps=config.layer_norm_eps + self.final_layer_norm = FastLayerNorm.load( + prefix="gpt_neox.final_layer_norm", + weights=weights, + eps=config.layer_norm_eps, ) self.gradient_checkpointing = False @@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self.head_size = self.layers[0].attention.head_size self.num_heads = self.layers[0].attention.num_heads - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.embed_in, TensorParallelEmbedding): - self.embed_in.add_null_idx() - for layer in self.layers: - layer: FlashNeoXLayer - layer.attention.shuffle_qkv_dims() - layer.attention.query_key_value.prepare_weights(quantize) - layer.attention.dense.prepare_weights(quantize) - layer.mlp.dense_h_to_4h.prepare_weights(quantize) - layer.mlp.dense_4h_to_h.prepare_weights(quantize) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashGPTNeoXModel, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs - ) - - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model - def forward( self, input_ids, @@ -435,42 +391,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) + self.gpt_neox = FlashGPTNeoXModel(config, weights) - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 - - self.gpt_neox = FlashGPTNeoXModel(config, process_group) - - if self.gpt_neox.tp_embeddings: - self.embed_out = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.embed_out = FastLinear( - config.hidden_size, config.vocab_size, bias=False - ) - - def post_load_weights(self, quantize: Optional[str] = None): - self.gpt_neox.post_load_weights(quantize) - self.embed_out.prepare_weights() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.embed_out = TensorParallelHead.load( + config, prefix="embed_out", weights=weights ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model def forward( self, @@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) - - if self.gpt_neox.tp_embeddings: - # Logits are sharded, so we need to gather them - world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] - torch.distributed.all_gather(world_logits, logits, group=self.process_group) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present return logits, present diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 03487703..55195162 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -1,5 +1,3 @@
-import os
-
 import torch
 import torch.distributed
 
@@ -12,15 +10,31 @@ from typing import Optional
 import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
-    FastLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
+    TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
+    get_linear,
 )
 
 
+def load_row(config, prefix: str, weights, bias: bool):
+    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    if bias and weights.process_group.rank() == 0:
+        # Only rank 0 loads the bias; the other shards use None so the bias is
+        # only applied once after the all-reduce
+        bias = weights.get_tensor(f"{prefix}.bias")
+    else:
+        bias = None
+
+    linear = get_linear(weight, bias, config.quantize)
+    if config.parallel_attn:
+        return linear
+    else:
+        return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
 class RWConfig(PretrainedConfig):
     attribute_map = {
         "num_hidden_layers": "n_layer",
@@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
-        reduce=True,
+        config,
+        prefix,
+        weights,
     ):
         super().__init__()
-        self.num_heads = num_heads
-        self.num_heads_kv = num_heads_kv
-        self.hidden_size = hidden_size
-        self.head_size = hidden_size // num_heads
+        self.num_heads = config.n_head
+        self.num_heads_kv = config.n_head_kv
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
+        self.num_heads = self.num_heads // weights.process_group.size()
 
-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-            )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-            self.num_heads = self.num_heads // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )
 
     def forward(
         self,
@@ -212,57 +213,48 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
-        reduce=True,
+        config,
+        prefix,
+        weights,
     ):
         super().__init__()
 
+        hidden_size = config.hidden_size
+        num_heads = config.n_head
+        num_heads_kv = config.n_head_kv
+
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
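+        # Grouped multi-query attention (the "RefinedWeb" / Falcon-40B layout):
+        # several query heads share one key/value head, and sharding is done per
+        # group rather than per head (hence the num_groups checks below).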
 
         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
+        process_group = weights.process_group
 
-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
+        if process_group.size() > self.num_groups:
+            raise NotImplementedError(
+                "Tensor parallelism is not implemented for world_size > num_groups"
             )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            if process_group.size() > self.num_groups:
-                raise NotImplementedError(
-                    f"Tensor Parallelism is not implemented for world_size > n groups"
-                )
+        if self.num_groups % process_group.size() != 0:
+            raise NotImplementedError(
+                f"Tensor parallelism is not implemented when num_groups ({self.num_groups}) is not divisible by world_size ({process_group.size()})"
+            )
+        self.num_groups = self.num_groups // process_group.size()
 
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.num_groups
-                * self.head_size
-                * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-
-            self.num_groups = self.num_groups // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )
 
     def forward(
         self,
@@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, hidden_size, bias, process_group=None, reduce=True):
+    def __init__(self, config, prefix, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu
 
-        if process_group is None:
-            self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
-            self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
-        else:
-            self.dense_h_to_4h = TensorParallelColumnLinear(
-                hidden_size,
-                4 * hidden_size,
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense_4h_to_h = TensorParallelRowLinear(
-                4 * hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-        self.process_group = process_group
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+        )
+        self.dense_4h_to_h = load_row(
+            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+        )
 
     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
@@ -392,38 +372,44 @@ class FlashMLP(nn.Module):
 class FlashRWLayer(nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        layer_norm_eps,
-        parallel_attn,
-        process_group=None,
+        layer_id,
+        config,
+        weights,
     ):
         super().__init__()
 
+        parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn
 
-        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        prefix = f"transformer.h.{layer_id}"
+
+        self.input_layernorm = FastLayerNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
         self.self_attention = FlashRWAttention(
-            num_heads,
-            num_heads_kv,
-            hidden_size,
-            bias,
-            process_group=process_group,
-            reduce=False,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
         )
         self.post_attention_layernorm = (
-            FastLayerNorm(hidden_size, eps=layer_norm_eps)
+            FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
             if not parallel_attn
             else None
         )
 
         self.mlp = FlashMLP(
-            hidden_size, bias, process_group=process_group, reduce=False
+            config,
+            prefix=f"{prefix}.mlp",
+            weights=weights,
         )
-        self.process_group = process_group
+        self.process_group = weights.process_group
 
     def forward(
         self,
@@ -454,9 +440,7 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output
 
-            # Only reduce once and after the addition instead of once per layer
-            if self.process_group is not None:
-                torch.distributed.all_reduce(intermediate, group=self.process_group)
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate, residual
         else:
@@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module):
 
 
 class FlashRWLargeLayer(nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        layer_norm_eps,
-        process_group=None,
-    ):
+    def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
-        self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        prefix = f"transformer.h.{layer_id}"
+        self.ln_attn = FastLayerNorm.load(
+            prefix=f"{prefix}.ln_attn",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
+        self.ln_mlp = FastLayerNorm.load(
+            prefix=f"{prefix}.ln_mlp",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
 
         self.self_attention = FlashRWLargeAttention(
-            num_heads,
-            num_heads_kv,
-            hidden_size,
-            bias,
-            process_group=process_group,
-            reduce=False,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
         )
+        assert config.parallel_attn, "This version doesn't support non-parallel attention"
 
-        self.mlp = FlashMLP(
-            hidden_size, bias, process_group=process_group, reduce=False
-        )
+        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
 
-        self.process_group = process_group
+        self.process_group = weights.process_group
 
     def forward(
         self,
@@ -543,9 +524,7 @@ class FlashRWLargeLayer(nn.Module):
 
         intermediate = attn_output + mlp_output
 
-        # Only reduce once and after the addition instead of once per layer
-        if self.process_group is not None:
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+        torch.distributed.all_reduce(intermediate, group=self.process_group)
 
         return intermediate, residual
@@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
 
 
 class FlashRWModel(FlashRWPreTrainedModel):
-    def __init__(self, config, process_group=None):
+    def __init__(self, config, weights):
         super().__init__(config)
         self.config = config
 
-        self.tp_embeddings = False
-        if process_group is not None:
-            self.tp_rank = process_group.rank()
-            self.tp_world_size = process_group.size()
-            if config.vocab_size % self.tp_world_size == 0:
-                self.tp_embeddings = True
-
-        if self.tp_embeddings:
-            self.word_embeddings = TensorParallelEmbedding(
-                config.vocab_size, config.hidden_size, process_group=process_group
-            )
-        else:
-            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
-
+        self.word_embeddings = TensorParallelEmbedding(
+            prefix="transformer.word_embeddings", weights=weights
+        )
         if config.model_type == "RefinedWebModel":
            self.h = nn.ModuleList(
                [
-                    FlashRWLayer(
-                        config.n_head,
-                        config.n_head_kv,
-                        config.hidden_size,
-                        config.bias,
-                        config.layer_norm_epsilon,
-                        
config.parallel_attn, - process_group, - ) - for _ in range(config.num_hidden_layers) + FlashRWLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = ( @@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel): elif config.model_type == "RefinedWeb": self.h = nn.ModuleList( [ - FlashRWLargeLayer( - config.n_head, - config.n_head_kv, - config.hidden_size, - config.bias, - config.layer_norm_epsilon, - process_group, - ) - for _ in range(config.num_hidden_layers) + FlashRWLargeLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = ( @@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel): f"model_type {config.model_type} is not supported." ) - self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - self.head_size = self.h[0].self_attention.head_size - - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.word_embeddings, TensorParallelEmbedding): - self.word_embeddings.add_null_idx() - for layer in self.h: - layer: FlashRWLayer - layer.self_attention.query_key_value.prepare_weights(quantize) - layer.self_attention.dense.prepare_weights(quantize) - layer.mlp.dense_h_to_4h.prepare_weights(quantize) - layer.mlp.dense_4h_to_h.prepare_weights(quantize) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashRWModel, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.ln_f = FastLayerNorm.load( + prefix="transformer.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model + self.head_size = self.h[0].self_attention.head_size def forward( self, @@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 + self.transformer = FlashRWModel(config, weights) - self.transformer = FlashRWModel(config, process_group) - - if self.transformer.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.transformer.post_load_weights(quantize) - self.lm_head.prepare_weights() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashRWForCausalLM, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.lm_head = TensorParallelHead.load( + config, prefix="lm_head", weights=weights ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model def forward( self, @@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): if lm_head_indices is not None: 
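+            # lm_head_indices keeps only the positions whose logits are needed
+            # (typically the last token of each sequence during prefill), so the
+            # vocabulary projection below runs on a small slice of hidden_states.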
             hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
-
-        if self.transformer.tp_embeddings:
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-
-            return world_logits, present
         return logits, present
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index b61ec873..888a6066 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -8,39 +8,142 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
 from text_generation_server.utils.layers import (
-    FastLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
+    TensorParallelHead,
     TensorParallelEmbedding,
     FastLayerNorm,
+    get_linear,
 )
 
 
-class FlashMQAttention(torch.nn.Module):
-    def __init__(
-        self,
-        num_heads,
+def load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()):
+        slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
+        shape = slice_.get_shape()
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+        if config.transpose:
+            block_size = (shape[1] - 2 * head_size) // world_size
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            assert (shape[1] - 2 * head_size) % world_size == 0
+            q_tensor = slice_[:, start:stop]
+            kv_tensor = slice_[:, -2 * head_size :]
+            weight = torch.cat([q_tensor, kv_tensor], dim=1).T
+        else:
+            block_size = (shape[0] - 2 * head_size) // world_size
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            weight = torch.cat([q_tensor, kv_tensor], dim=0)
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+    else:
+        if config.transpose:
+            w = [
+                weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
+                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
+            ]
+            weight = torch.cat(w, dim=0)
+        else:
+            w = [
+                weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
+                weights.get_tensor(f"{prefix}.kv_attn.weight"),
+            ]
+            weight = torch.cat(w, dim=1)
+
+        if bias:
+            b = [
+                weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
+                weights.get_tensor(f"{prefix}.kv_attn.bias"),
+            ]
+            bias = torch.cat(b, dim=0)
+        else:
+            bias = None
+
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    assert list(weight.shape) == [
+        (num_heads + 2) * head_size,
         hidden_size,
-        process_group=None,
-    ):
+    ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
+    if bias is not None:
+        bias = bias.to(dtype=weights.dtype).to(device=weights.device)
+        assert list(bias.shape) == [
+            (num_heads + 2) * head_size
+        ], f"{bias.shape} != {[(num_heads + 2) * head_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
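+
+
+# The helpers below handle the plain (non-fused) checkpoint layout:
+# column-parallel weights are sharded along the output dimension and
+# row-parallel weights along the input dimension; `config.transpose`
+# accounts for checkpoints that store transposed (Conv1D-style) weights.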
+def load_col(config, prefix: str, weights, bias: bool):
+    if config.transpose:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+    else:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+
+    if bias:
+        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+    else:
+        bias = None
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+    if config.transpose:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+    else:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+
+    if bias and weights.process_group.rank() == 0:
+        # Only rank 0 loads the bias; the other shards use None so the bias is
+        # only applied once after the all-reduce
+        bias = weights.get_tensor(f"{prefix}.bias")
+    else:
+        bias = None
+    return TensorParallelRowLinear(
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+    )
+
+
+class FlashMQAttention(torch.nn.Module):
+    def __init__(self, prefix, config, weights):
         super().__init__()
+        num_heads = config.num_attention_heads
+        hidden_size = config.hidden_size
+
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        assert self.num_heads % weights.process_group.size() == 0
+        self.num_heads = self.num_heads // weights.process_group.size()
+
         self.softmax_scale = self.head_size ** (-0.5)
 
-        if process_group is None:
-            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
-            self.c_proj = FastLinear(hidden_size, hidden_size)
-        else:
-            self.num_heads = self.num_heads // process_group.size()
-            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
-            self.c_proj = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                process_group=process_group,
-            )
+        self.c_attn = load_multi_mqa(
+            config,
+            prefix=prefix,
+            weights=weights,
+            bias=True,
+            head_size=self.head_size,
+            hidden_size=hidden_size,
+            num_heads=self.num_heads,
+        )
+        self.c_proj = load_row(
+            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
+        )
 
     def forward(
         self,
@@ -121,8 +224,9 @@ class FlashMQAttention(torch.nn.Module):
 
 
 class MLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(self, prefix, config, weights):
         super().__init__()
+        act = config.activation_function
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
@@ -134,20 +238,12 @@ class MLP(nn.Module):
             )
         )
 
-        if process_group is None:
-            self.c_fc = FastLinear(hidden_size, intermediate_size)
-            self.c_proj = FastLinear(intermediate_size, hidden_size)
-        else:
-            self.c_fc = TensorParallelColumnLinear(
-                hidden_size,
-                intermediate_size,
-                process_group=process_group,
-            )
-            self.c_proj = TensorParallelRowLinear(
-                intermediate_size,
-                hidden_size,
-                process_group=process_group,
-            )
+        self.c_fc = load_col(
+            config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
+        )
+        self.c_proj = load_row(
+            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
+        )
 
     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
@@ -157,28 +253,24 @@ class MLP(nn.Module):
 
 
 class Block(nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        act,
-        hidden_size,
-        intermediate_size,
-        layer_norm_eps,
-        process_group=None,
-    ):
+    def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
-        self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        prefix = f"transformer.h.{layer_id}"
+        self.ln_1 = FastLayerNorm.load(
+            
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + self.ln_2 = FastLayerNorm.load( + prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon + ) self.attn = FlashMQAttention( - num_heads, - hidden_size, - process_group, + prefix=f"{prefix}.attn", + config=config, + weights=weights, ) self.mlp = MLP( - act, - hidden_size, - intermediate_size, - process_group, + prefix=f"{prefix}.mlp", + config=config, + weights=weights, ) def forward( @@ -210,66 +302,39 @@ class Block(nn.Module): class FlashSantacoderModel(nn.Module): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__() self.config = config - self.process_group = process_group - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.wte = TensorParallelEmbedding( - config.vocab_size, - config.hidden_size, - reduce=False, - process_group=process_group, - ) - self.wpe = TensorParallelEmbedding( - config.max_position_embeddings, - config.hidden_size, - reduce=False, - process_group=process_group, - ) - else: - self.wte = nn.Embedding(config.vocab_size, config.hidden_size) - self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.process_group = weights.process_group + self.wte = TensorParallelEmbedding( + prefix="transformer.wte", + weights=weights, + reduce=False, + ) + self.wpe = TensorParallelEmbedding( + prefix="transformer.wpe", + weights=weights, + reduce=False, + ) self.h = nn.ModuleList( [ Block( - config.num_attention_heads, - config.activation_function, - config.hidden_size, - config.n_inner - if config.n_inner is not None - else 4 * config.hidden_size, - config.layer_norm_epsilon, - process_group, + layer_id, + config, + weights, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) - self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_f = FastLayerNorm.load( + prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon + ) self.head_size = self.h[0].attn.head_size self.num_heads = self.h[0].attn.num_heads - def post_load_weights(self, quantize: Optional[str] = None): - if self.tp_embeddings: - self.wte.add_null_idx() - self.wpe.add_null_idx() - for layer in self.h: - layer: Block - layer.attn.c_attn.prepare_weights(quantize) - layer.attn.c_proj.prepare_weights(quantize) - layer.mlp.c_fc.prepare_weights(quantize) - layer.mlp.c_proj.prepare_weights(quantize) - def forward( self, input_ids, @@ -281,8 +346,7 @@ class FlashSantacoderModel(nn.Module): pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.wte(input_ids) + self.wpe(position_ids) - if self.tp_embeddings: - torch.distributed.all_reduce(hidden_states, group=self.process_group) + torch.distributed.all_reduce(hidden_states, group=self.process_group) # Prefill if past_key_values is None: @@ -331,23 +395,12 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__() - - self.transformer = FlashSantacoderModel(config, process_group) - - if self.transformer.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = 
FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.transformer.post_load_weights(quantize) - self.lm_head.prepare_weights() + self.transformer = FlashSantacoderModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, prefix="transformer.wte", weights=weights + ) def forward( self, @@ -372,29 +425,4 @@ class FlashSantacoderForCausalLM(nn.Module): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - - if self.transformer.tp_embeddings: - # Logits are sharded, so we need to gather them - if logits.shape[0] == 1: - # Fast path when batch size is 1 - world_logits = logits.new_empty( - (logits.shape[1] * self.transformer.tp_world_size) - ) - torch.distributed.all_gather_into_tensor( - world_logits, logits.view(-1), group=self.transformer.process_group - ) - world_logits = world_logits.view(1, -1) - else: - # We cannot use all_gather_into_tensor as it only support concatenating on the first dim - world_logits = [ - torch.empty_like(logits) - for _ in range(self.transformer.tp_world_size) - ] - torch.distributed.all_gather( - world_logits, logits, group=self.transformer.process_group - ) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present - return logits, present diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py new file mode 100644 index 00000000..bf2656d1 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -0,0 +1,794 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
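+#
+# NOTE: This file is a tensor-parallel port of the upstream GPT-NeoX modeling
+# code: nn.Linear / nn.Embedding are swapped for TGI's TensorParallel* layers,
+# which read their shard of each tensor straight from the checkpoint through
+# the `weights` handle, and attention can optionally run through a fused CUDA
+# kernel (see CUSTOM_KERNELS_ENABLED below).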
+""" PyTorch GPTNeoX model.""" + +from typing import Optional, Tuple, Union + +import os +import torch +import torch.distributed +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers import GPTNeoXConfig +from loguru import logger +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, +) + + +CUSTOM_KERNELS_ENABLED = False +if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": + try: + from custom_kernels import fused_attention_cuda + + CUSTOM_KERNELS_ENABLED = True + except ImportError: + pass + +if not CUSTOM_KERNELS_ENABLED: + logger.warning("We're not using custom kernels.") + + +def make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.ones( + (target_length, target_length + past_key_values_length), + dtype=torch.bool, + device=device, + ) + mask = mask.triu(1 + past_key_values_length) + + expanded_mask = mask.unsqueeze(0).expand( + batch_size, target_length, target_length + past_key_values_length + ) + return expanded_mask + + +def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, tgt_length, src_length) + + +def prepare_attn_mask( + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, +) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +class GPTNeoXPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + +class GPTNeoXAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + max_positions = config.max_position_embeddings + # ??? TODO + # self.register_buffer( + # "bias", + # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + # 1, 1, max_positions, max_positions + # ), + # ) + # self.register_buffer("masked_bias", torch.tensor(-1e9)) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, + config.max_position_embeddings, + base=config.rotary_emb_base, + ) + self.rotary_emb.inv_freq = nn.Parameter( + weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") + ) + self.inv_norm_factor = 1.0 / torch.sqrt( + torch.tensor(self.head_size, dtype=torch.float32) + ).to(torch.get_default_dtype()) + + assert self.num_attention_heads % weights.process_group.size() == 0 + self.num_attention_heads = ( + self.num_attention_heads // weights.process_group.size() + ) + self.query_key_value = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True + ) + self.dense = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.dense", weights=weights, bias=True + ) + + def forward( + self, + hidden_states, + position_ids, + attention_mask, + head_mask=None, + layer_past=None, + use_cache=False, + output_attentions=False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query, key, value = qkv.split(self.head_size, -1) + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + key_rot = key[..., : self.rotary_ndims] + + query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) + + query[..., : self.rotary_ndims] = query_rot + key[..., : self.rotary_ndims] = key_rot + + if CUSTOM_KERNELS_ENABLED: + attn_output, present, attn_weights = fused_attention_cuda.forward( + query, + key, + value, + layer_past, + attention_mask, + head_mask, + self.inv_norm_factor, + self.num_attention_heads, + use_cache, + ) + else: + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # Compute attention + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask + ) + + # Reshape outputs + attn_output = self._merge_heads( + attn_output, self.num_attention_heads, self.head_size + ) + + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + @classmethod + def _split_heads(cls, tensor, 
num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view( + tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size + ) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + query = query.view( + batch_size * num_attention_heads, query_length, attn_head_size + ) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + 1, + dtype=query.dtype, + device=key.device, + ).expand(batch_size * num_attention_heads, query_length, key_length) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=self.inv_norm_factor, + ) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attn_scores.dtype + if input_dtype in [torch.float16, torch.bfloat16]: + attn_scores = attn_scores.to(torch.float) + attn_scores = torch.where( + attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores + ) + attn_scores = attn_scores.view( + batch_size, num_attention_heads, query_length, key_length + ) + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + self.true_inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2).float().to(device) / dim) + ) + self.register_buffer("inv_freq", self.true_inv_freq) + + # Build here to make `torch.jit.trace` work. 
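+        # The cos/sin tables start out empty: they are built lazily in
+        # forward() from `true_inv_freq` (kept in float32) and cast to the
+        # query's dtype/device, growing only when a longer sequence is seen.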
+ self.max_seq_len_cached = max_position_embeddings + self.cos_cached = None + self.sin_cached = None + + @staticmethod + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + @staticmethod + def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): + t = torch.arange( + max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype + ) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) + + def forward(self, q, k, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if ( + seq_len > self.max_seq_len_cached + or self.cos_cached is None + or self.sin_cached is None + ): + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.cos_cached, self.sin_cached = self._create_cos_sin( + self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device + ) + return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) + + +@torch.jit.script +def rotary_forward(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + + chunk_size = q.shape[-1] // 2 + q1, q2 = q.split(chunk_size, -1) + q_rotated = torch.cat((-q2, q1), dim=-1) + k1, k2 = k.split(chunk_size, -1) + k_rotated = torch.cat((-k2, k1), dim=-1) + + q_embed = (q * cos) + (q_rotated * sin) + k_embed = (k * cos) + (k_rotated * sin) + return q_embed, k_embed + + +class GPTNeoXMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.act = ( + ACT2FN[config.hidden_act] + if "gelu_fast" not in config.hidden_act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm.load( + prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.attention = GPTNeoXAttention( + config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights + ) + self.mlp = GPTNeoXMLP( + config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights + ) + + def forward( + self, + hidden_states, + position_ids, + attention_mask=None, + head_mask=None, + use_cache=False, + layer_past=None, + output_attentions=False, + ): + attention_layer_outputs = self.attention( + self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + 
attn_output = attention_layer_outputs[ + 0 + ] # output_attn: attn_output, present, (attn_weights) + outputs = attention_layer_outputs[1:] + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + hidden_states = mlp_output + attn_output + + if use_cache: + outputs = ( + hidden_states, + ) + outputs # hidden_states, present, (attn_weights) + else: + outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) + + return outputs + + +class GPTNeoXModel(GPTNeoXPreTrainedModel): + def __init__(self, config, weights): + super().__init__(config) + self.config = config + + self.num_attention_heads = config.num_attention_heads + + self.embed_in = TensorParallelEmbedding( + prefix="gpt_neox.embed_in", weights=weights + ) + self.layers = nn.ModuleList( + [ + GPTNeoXLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm.load( + prefix="gpt_neox.final_layer_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.tp_world_size = weights.process_group.size() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids=None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * self.config.num_hidden_layers) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_length, seq_length + past_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + hidden_states = inputs_embeds + + # Attention mask. + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-1] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) + else: + attention_mask = attention_mask.to(hidden_states.device) + + causal_mask = prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + assert self.num_attention_heads % self.tp_world_size == 0 + block_size = self.num_attention_heads // self.tp_world_size + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = layer( + hidden_states, + position_ids=position_ids, + attention_mask=causal_mask, + head_mask=head_mask[i], + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_layer_norm(hidden_states) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if 
not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+        )
+
+
+class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config, weights):
+        super().__init__(config)
+        self.gpt_neox = GPTNeoXModel(config, weights)
+        self.embed_out = TensorParallelHead.load(
+            config, prefix="embed_out", weights=weights
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
+            only required when the model is used as a decoder in a Sequence to Sequence model.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config.is_decoder = True + >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.gpt_neox( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # we are doing next-token prediction; shift prediction scores and input ids by one + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) + ) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + input_shape = input_ids.shape + + # cut decoder_input_ids if past is used + if past_key_values and past_key_values[0] is not None: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + ) + + return model_inputs + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past[:2] + ) + + layer_past[2:], + ) + return reordered_past diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py new file mode 100644 index 
00000000..03fded50
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py
@@ -0,0 +1,837 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch OPT model."""
+import random
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers import OPTConfig
+from text_generation_server.utils.layers import (
+    FastLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+    TensorParallelHead,
+)
+
+EPS = 1e-5
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full(
+        (tgt_len, tgt_len),
+        torch.tensor(torch.finfo(dtype).min, device=device),
+        device=device,
+    )
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
+
+
+class OPTLearnedPositionalEmbedding(nn.Module):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
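+    OPT reserves the first two rows of the position-embedding table (a legacy of the
+    original fairseq padding handling), so `forward` adds a fixed offset of 2 to
+    every position index before the embedding lookup.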
+    """
+
+    def __init__(self, weights):
+        super().__init__()
+        self.offset = 2
+        self.weight = nn.Parameter(
+            weights.get_tensor("model.decoder.embed_positions.weight")
+        )
+
+    def forward(
+        self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
+    ):
+        """`attention_mask` is expected to be [bsz x seqlen]."""
+        attention_mask = attention_mask.long()
+
+        # create positions depending on attention_mask
+        positions = (
+            torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
+        ).long() - 1
+
+        # cut positions if `past_key_values_length` is > 0
+        positions = positions[:, past_key_values_length:]
+
+        return torch.nn.functional.embedding(positions + self.offset, self.weight)
+
+
+class OPTAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config,
+        prefix,
+        weights,
+        is_decoder: bool = False,
+        bias: bool = True,
+        process_group=None,
+    ):
+        super().__init__()
+        embed_dim = config.hidden_size
+        num_heads = config.num_attention_heads
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = config.dropout
+        self.head_dim = embed_dim // num_heads
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+
+        process_group = weights.process_group
+        assert self.num_heads % process_group.size() == 0
+        self.num_heads = self.num_heads // process_group.size()
+        self.embed_dim = self.embed_dim // process_group.size()
+
+        self.q_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
+        )
+        self.k_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
+        )
+        self.out_proj = TensorParallelRowLinear.load(
+            config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
+        )
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return (
+            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states =
self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len
+            )
+            attn_weights = attn_weights_reshaped.view(
+                bsz * self.num_heads, tgt_len, src_len
+            )
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` stored in the class (already divided by the tensor-parallel world size)
+        # rather than the last dim of `hidden_states`, because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class OPTDecoderLayer(nn.Module):
+    def __init__(self, layer_id: int, config: OPTConfig, weights):
+        super().__init__()
+        self.process_group = weights.process_group
+        self.embed_dim = config.hidden_size
+        prefix = f"model.decoder.layers.{layer_id}"
+        self.self_attn = OPTAttention(
+            config,
+            prefix=f"{prefix}.self_attn",
+            weights=weights,
+            is_decoder=True,
+            bias=config.enable_bias,
+        )
+        self.do_layer_norm_before = config.do_layer_norm_before
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+
+        self.self_attn_layer_norm = nn.LayerNorm.load(
+            prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
+        )
+        self.fc1 = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
+        )
+        self.fc2 = TensorParallelRowLinear.load(
+            config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
+        )
+        self.final_layer_norm = nn.LayerNorm.load(
+            prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+        if self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+        hidden_states = residual + hidden_states
+
+        # 350m applies layer norm AFTER attention
+        if not self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Fully Connected
+        hidden_states_shape = hidden_states.shape
+        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
+        residual = hidden_states
+
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
+        if self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+
+        hidden_states = (residual + hidden_states).view(hidden_states_shape)
+
+        # 350m applies layer norm AFTER the feed-forward block
+        if not self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class OPTPreTrainedModel(PreTrainedModel):
+    config_class = OPTConfig
+
+
+class OPTDecoder(OPTPreTrainedModel):
+    def __init__(self, config: OPTConfig, weights):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.decoder.embed_tokens", weights=weights
+        )
+        self.embed_positions = OPTLearnedPositionalEmbedding(weights)
+
+        if config.word_embed_proj_dim != config.hidden_size:
+            self.project_out = FastLinear.load(
+                config, prefix="model.decoder.project_out", weights=weights, bias=False
+            )
+        else:
+            self.project_out = None
+
+        if config.word_embed_proj_dim != config.hidden_size:
+            self.project_in = FastLinear.load(
+                config, prefix="model.decoder.project_in", weights=weights, bias=False
+            )
+        else:
+            self.project_in = None
+
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm.load(
+                prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
+            )
+        else:
+            self.final_layer_norm = None
+
+        self.layers = nn.ModuleList(
+            [
+                OPTDecoderLayer(layer_id, config, weights)
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(
+        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    ):
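+        # Illustrative shapes (assuming a batch of 2, a query length of 3 and 2
+        # cached tokens): _make_causal_mask yields a (2, 1, 3, 5) causal mask,
+        # _expand_mask lifts the (2, 5) padding mask to the same (2, 1, 3, 5) shape,
+        # and the two are summed so masked positions sit near the dtype minimum.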
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            ).to(inputs_embeds.device)
+            combined_attention_mask = (
+                expanded_attn_mask
+                if combined_attention_mask is None
+                else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + causal_attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, weights): + super().__init__(config) + self.decoder = OPTDecoder(config, weights) + # Initialize weights and apply final processing + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + def __init__(self, config, weights): + super().__init__(config) + + self.model = OPTModel(config, weights) + + self.lm_head = TensorParallelHead.load( + config, 
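+            # OPT ties the LM head to the input embeddings, so the output projection
+            # is loaded from the `embed_tokens` weight rather than a separate tensor.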
prefix="model.decoder.embed_tokens", weights=weights + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py new file mode 100644 index 00000000..51862e3c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -0,0 +1,1200 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch T5 model."""
+
+import copy
+import math
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.distributed
+from loguru import logger
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+    is_torch_fx_proxy,
+)
+from transformers import T5Config
+from text_generation_server.utils.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+    TensorParallelHead,
+)
+
+
+class PartialTPEmbedding(nn.Module):
+    def __init__(self, prefix: str, weights):
+        super().__init__()
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        self.weight = nn.Parameter(weight)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.embedding(input, self.weight)
+
+
+@torch.jit.script
+def layer_norm(hidden_states, weight, epsilon):
+    # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+    # Square Layer Normalization (https://arxiv.org/abs/1910.07467); thus the variance is calculated
+    # without the mean and there is no bias. Additionally, we want to make sure that the accumulation
+    # for half-precision inputs is done in fp32
+
+    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
+
+    # convert into half-precision if necessary
+    if weight.dtype in [torch.float16, torch.bfloat16]:
+        hidden_states = hidden_states.to(weight.dtype)
+
+    return weight * hidden_states
+
+
+class T5LayerNorm(nn.Module):
+    def __init__(self, prefix, weights, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        weight = weights.get_tensor(f"{prefix}.weight")
+        self.weight = nn.Parameter(weight)
+        self.variance_epsilon = torch.tensor(eps)
+
+    def forward(self, hidden_states):
+        return layer_norm(hidden_states, self.weight, self.variance_epsilon)
+
+
+try:
+    from apex.normalization import FusedRMSNorm
+
+    # apex's FusedRMSNorm is constructed as FusedRMSNorm(normalized_shape, eps),
+    # which is incompatible with the weight-loading T5LayerNorm(prefix, weights, eps)
+    # defined above, so the fused kernel is not swapped in here.
+    logger.info(
+        "Discovered apex.normalization.FusedRMSNorm - keeping the custom T5LayerNorm"
+    )
+except ImportError:
+    # using the normal T5LayerNorm
+    pass
+except Exception:
+    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
+    pass
+
+ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
+
+
+class T5DenseActDense(nn.Module):
+    def __init__(self, config: T5Config, prefix, weights):
+        super().__init__()
+        self.wi = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.wi", weights=weights, bias=False
+        )
+
+        ### XXX: T5 models do not handle well both f16 and quantization.
+        ### Overriding specifically this layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
+        self.wo = TensorParallelRowLinear.load(
+            config, prefix=f"{prefix}.wo", weights=weights, bias=False
+        )
+        weights.dtype = _dtype
+        config.quantize = _q
+
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = (
+            ACT2FN[config.dense_act_fn]
+            if "gelu" not in config.dense_act_fn
+            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.wi(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
+        hidden_states = self.wo(hidden_states)
+        # XXX: Recasting is already done within the layer norm.
+        # Casting back to float16 here modifies results
+        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
+        return hidden_states
+
+
+class T5DenseGatedActDense(nn.Module):
+    def __init__(self, config: T5Config, prefix, weights):
+        super().__init__()
+        self.wi_0 = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.wi_0", weights=weights, bias=False
+        )
+        self.wi_1 = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
+        )
+        ### XXX: T5 models do not handle well both f16 and quantization.
+        ### Overriding specifically this layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
+        self.wo = TensorParallelRowLinear.load(
+            config, prefix=f"{prefix}.wo", weights=weights, bias=False
+        )
+        weights.dtype = _dtype
+        config.quantize = _q
+
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = (
+            ACT2FN[config.dense_act_fn]
+            if "gelu" not in config.dense_act_fn
+            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+        )
+
+    def forward(self, hidden_states):
+        hidden_gelu = self.act(self.wi_0(hidden_states))
+        hidden_linear = self.wi_1(hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        hidden_states = self.dropout(hidden_states)
+
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
+        hidden_states = self.wo(hidden_states)
+        # XXX: Recasting is already done within the layer norm.
+ # Casting back to float16 here modifies results + # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense( + config, prefix=f"{prefix}.DenseReluDense", weights=weights + ) + else: + self.DenseReluDense = T5DenseActDense( + config, prefix=f"{prefix}.DenseReluDense", weights=weights + ) + + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__( + self, config: T5Config, prefix, weights, has_relative_attention_bias=False + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + process_group = weights.process_group + # Mesh TensorFlow initialization to avoid scaling before softmax + assert self.n_heads % process_group.size() == 0 + self.q = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q", weights=weights, bias=False + ) + self.k = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k", weights=weights, bias=False + ) + self.v = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v", weights=weights, bias=False + ) + self.o = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.o", weights=weights, bias=False + ) + self.n_heads = self.n_heads // process_group.size() + self.inner_dim = self.inner_dim // process_group.size() + + if self.has_relative_attention_bias: + self.relative_attention_bias = PartialTPEmbedding( + prefix=f"{prefix}.relative_attention_bias", weights=weights + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got {len(past_key_value)} past states" + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + 
(present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, prefix, weights, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, + prefix=f"{prefix}.SelfAttention", + weights=weights, + has_relative_attention_bias=has_relative_attention_bias, + ) + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.EncDecAttention = T5Attention( + config, + prefix=f"{prefix}.EncDecAttention", + weights=weights, + has_relative_attention_bias=False, + ) + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + prefix=f"{prefix}.layer.0", + weights=weights, + has_relative_attention_bias=has_relative_attention_bias, + ) + ) + if self.is_decoder: + i = 2 + self.layer.append( + T5LayerCrossAttention( + config, prefix=f"{prefix}.layer.1", weights=weights + ) + ) + else: + i = 1 + + self.layer.append( + T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is 
passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), 
(cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert ( + pad_token_id is not None + ), "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, prefix, weights, embed_tokens): + super().__init__(config) + + self.is_decoder = config.is_decoder + + self.embed_tokens = embed_tokens + self.block = nn.ModuleList( + [ + T5Block( + config, + prefix=f"{prefix}.block.{layer_id}", + weights=weights, + has_relative_attention_bias=(layer_id == 0), + ) + for layer_id in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + prefix=f"{prefix}.final_layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + assert ( + self.embed_tokens is not None + ), "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask 
seq length can be calculated via length of past
+        mask_seq_length = (
+            past_key_values[0][0].shape[2] + seq_length
+            if past_key_values is not None
+            else seq_length
+        )
+
+        if use_cache is True:
+            assert (
+                self.is_decoder
+            ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                batch_size, mask_seq_length, device=inputs_embeds.device
+            )
+        if (
+            self.is_decoder
+            and encoder_attention_mask is None
+            and encoder_hidden_states is not None
+        ):
+            encoder_seq_length = encoder_hidden_states.shape[1]
+            encoder_attention_mask = torch.ones(
+                batch_size,
+                encoder_seq_length,
+                device=inputs_embeds.device,
+                dtype=torch.long,
+            )
+
+        # initialize past_key_values with `None` if past does not exist
+        if past_key_values is None:
+            past_key_values = [None] * len(self.block)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, input_shape
+        )
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.is_decoder and encoder_hidden_states is not None:
+            (
+                encoder_batch_size,
+                encoder_sequence_length,
+                _,
+            ) = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device
+                )
+            encoder_extended_attention_mask = self.invert_attention_mask(
+                encoder_attention_mask
+            )
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+        cross_attn_head_mask = self.get_head_mask(
+            cross_attn_head_mask, self.config.num_layers
+        )
+        present_key_value_states = () if use_cache else None
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+        position_bias = None
+        encoder_decoder_position_bias = None
+
+        hidden_states = self.dropout(inputs_embeds)
+
+        for i, (layer_module, past_key_value) in enumerate(
+            zip(self.block, past_key_values)
+        ):
+            layer_head_mask = head_mask[i]
+            cross_attn_layer_head_mask = cross_attn_head_mask[i]
+            # Model parallel
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                position_bias=position_bias,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_extended_attention_mask,
+                encoder_decoder_position_bias=encoder_decoder_position_bias,
+                layer_head_mask=layer_head_mask,
+                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+
+            # layer_outputs is a tuple with:
+            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+            hidden_states, present_key_value_state = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer stores them
+            # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
+            # (cross-attention position bias), (cross-attention weights)
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[
+                    4 if output_attentions else 3
+                ]
+            # append next layer key value states
+            if use_cache:
+                present_key_value_states = present_key_value_states + (
+                    present_key_value_state,
+                )
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[3],)
+                if self.is_decoder:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        # Add last layer
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    present_key_value_states,
+                    all_hidden_states,
+                    all_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=present_key_value_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class T5ForConditionalGeneration(T5PreTrainedModel):
+    def __init__(self, config: T5Config, weights):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = T5Stack(
+            config=encoder_config,
+            prefix="encoder",
+            weights=weights,
+            embed_tokens=self.shared,
+        )
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = T5Stack(
+            config=decoder_config,
+            prefix="decoder",
+            weights=weights,
+            embed_tokens=self.shared,
+        )
+
+        self.lm_head = TensorParallelHead.load(
+            config, prefix="lm_head", weights=weights
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + decoder_attention_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": 
cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) + return reordered_decoder_past diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index fe28580d..eb216a20 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -1,154 +1,25 @@ import torch import torch.distributed -from accelerate import init_empty_weights from opentelemetry import trace -from pathlib import Path -from safetensors import safe_open from transformers import AutoConfig from transformers.models.llama import LlamaTokenizer -from typing import Optional, List +from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, - download_weights, - weight_hub_files, - LocalEntryNotFoundError, + Weights, ) tracer = trace.get_tracer(__name__) class FlashLlama(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("FlashLlama is only available on GPU") - - tokenizer = LlamaTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - # We do not use from_pretrained as we modified the model internal module layout - try: - filenames = weight_files(model_id, revision, ".bin") - # Local files not found - except LocalEntryNotFoundError: - hub_files = weight_hub_files(model_id, revision, ".bin") - filenames = download_weights(hub_files, model_id, revision) - - with init_empty_weights(): - model = FlashLlamaForCausalLM(config) - - self.load_weights(model, filenames, quantize, device, dtype) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def 
load_weights( - model, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - ): - for filename in filenames: - state_dict = torch.load(filename, map_location="cpu") - for key, value in state_dict.items(): - value = value.to(device if quantize is None else "cpu").to(dtype) - - layer_name = ".".join(key.split(".")[:4]) - - # Fused qkv - if "q_proj" in key or "k_proj" in key or "v_proj" in key: - final_key = layer_name + ".query_key_value.weight" - - # Fused gate and up projs - elif "gate_proj" in key or "up_proj" in key: - final_key = layer_name + ".gate_up_proj.weight" - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "query_key_value" in final_key: - module._parameters[param_name] = value.new_empty( - (value.shape[0] * 3, value.shape[1]) - ) - # Init gate and up proj - elif "gate_up_proj" in final_key: - module._parameters[param_name] = value.new_empty( - (value.shape[0] * 2, value.shape[1]) - ) - - # Copy to correct slice - if "q_proj" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "k_proj" in key: - module._parameters[param_name][ - value.shape[0] : value.shape[0] * 2 - ] = value - elif "v_proj" in key: - module._parameters[param_name][value.shape[0] * 2 :] = value - elif "gate_proj" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "up_proj" in key: - module._parameters[param_name][value.shape[0] :] = value - else: - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - else: - module._buffers[param_name] = value - - del value - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - -class FlashLlamaSharded(FlashLlama): def __init__( self, model_id: str, @@ -176,24 +47,16 @@ class FlashLlamaSharded(FlashLlama): ) torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) - with init_empty_weights(): - model = FlashLlamaForCausalLM(config, process_group=self.process_group) + config.quantize = quantize + model = FlashLlamaForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( - model=model.to(device), + model=model, tokenizer=tokenizer, requires_padding=False, dtype=dtype, @@ -201,114 +64,3 @@ class FlashLlamaSharded(FlashLlama): rank=rank, world_size=world_size, ) - - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - slice_ = f.get_slice(name) - - layer_name = ".".join(name.split(".")[:4]) - - # Fused qkv - if "q_proj" in name or "k_proj" 
in name or "v_proj" in name: - final_name = layer_name + ".query_key_value.weight" - - # Fused gate and up projs - elif "gate_proj" in name or "up_proj" in name: - final_name = layer_name + ".gate_up_proj.weight" - else: - final_name = name - - module_name, param_name = final_name.rsplit(".", 1) - module = model.get_submodule(module_name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight" and model.model.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - tensor = tensor.contiguous().to(dtype) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "query_key_value" in final_name: - module._parameters[param_name] = tensor.new_empty( - (tensor.shape[0] * 3, tensor.shape[1]) - ) - # Init gate and up proj - elif "gate_up_proj" in final_name: - module._parameters[param_name] = tensor.new_empty( - (tensor.shape[0] * 2, tensor.shape[1]) - ) - - # Init gate and up proj - if "q_proj" in name: - module._parameters[param_name][: tensor.shape[0]] = tensor - elif "k_proj" in name: - module._parameters[param_name][ - tensor.shape[0] : tensor.shape[0] * 2 - ] = tensor - elif "v_proj" in name: - module._parameters[param_name][ - tensor.shape[0] * 2 : - ] = tensor - elif "gate_proj" in name: - module._parameters[param_name][: tensor.shape[0]] = tensor - elif "up_proj" in name: - module._parameters[param_name][tensor.shape[0] :] = tensor - else: - if current_parameter_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - module._parameters[param_name] = tensor - - else: - module._buffers[param_name] = tensor - - torch.cuda.empty_cache() - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 31ae7914..4847571d 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -1,45 +1,24 @@ import torch import torch.distributed -from accelerate import init_empty_weights from opentelemetry import trace -from safetensors import safe_open from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List +from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + 
Weights, ) tracer = trace.get_tracer(__name__) -class FlashNeoX(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - super(FlashNeoX, self).__init__( - FlashGPTNeoXForCausalLM, - model_id, - revision, - quantize, - trust_remote_code=trust_remote_code, - ) - - -class FlashNeoXSharded(FlashNeoX): +class FlashNeoXSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX): config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) + config.quantize = quantize torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = FlashGPTNeoXForCausalLM(config, self.process_group) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = FlashGPTNeoXForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX): rank=rank, world_size=world_size, ) - - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. 
- if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 4fc4c389..5f963bfb 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -1,119 +1,25 @@ import torch import torch.distributed -from pathlib import Path -from accelerate import init_empty_weights from opentelemetry import trace -from safetensors import safe_open -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List +from transformers import AutoTokenizer +from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_rw_modeling import ( RWConfig, FlashRWForCausalLM, - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, - download_weights, - weight_hub_files, - LocalEntryNotFoundError, + Weights, ) tracer = trace.get_tracer(__name__) -class FlashRW(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("RW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, - revision=revision, - ) - - # We do not use from_pretrained as it is too slow - try: - filenames = weight_files(model_id, revision, ".bin") - # Local files not found - except LocalEntryNotFoundError: - hub_files = weight_hub_files(model_id, revision, ".bin") - filenames = download_weights(hub_files, model_id, revision) - - with init_empty_weights(): - model = FlashRWForCausalLM(config) - - self.load_weights( - model, - filenames, - quantize, - device, - dtype, - ) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def load_weights( - model: FlashRWForCausalLM, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - ): - for filename in filenames: - state_dict = torch.load(filename, map_location="cpu") - for key, value in state_dict.items(): - value = value.to(device if 
quantize is None else "cpu").to(dtype) - - module_name, param_name = key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - except KeyError: - module._buffers[param_name] = value - - del value - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - -class FlashRWSharded(FlashRW): +class FlashRWSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -142,20 +48,12 @@ class FlashRWSharded(FlashRW): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) - with init_empty_weights(): - model = FlashRWForCausalLM(config, self.process_group) + config.quantize = quantize + + model = FlashRWForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -166,79 +64,3 @@ class FlashRWSharded(FlashRW): rank=rank, world_size=world_size, ) - - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. 
- if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight" and model.transformer.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index e1c893d0..54634e4a 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -1,197 +1,24 @@ import torch import torch.distributed -from accelerate import init_empty_weights from opentelemetry import trace -from safetensors import safe_open -from pathlib import Path -from transformers import AutoTokenizer, GPT2Config +from transformers import AutoTokenizer, AutoConfig from typing import Optional, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( FlashSantacoderForCausalLM, - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, - download_weights, - weight_hub_files, - LocalEntryNotFoundError, + Weights, ) tracer = trace.get_tracer(__name__) -class FlashSantacoder(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("FlashSantacoder is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = GPT2Config.from_pretrained( - model_id, - revision=revision, - ) - - # We do not use from_pretrained as we modified the model internal module layout - filenames = weight_files(model_id, revision, ".safetensors") - - with init_empty_weights(): - model = FlashSantacoderForCausalLM(config) - - self.load_weights( - model, - filenames, - quantize, - device, - dtype, - config.architectures[0].startswith("GPT2"), - ) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def load_weights( - model: FlashSantacoderForCausalLM, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - transpose: bool, - ): - for filename in filenames: - with safe_open( - filename, - framework="pt", - device=str(device) if quantize is None else "cpu", - ) as f: - for key in f.keys(): - value = 
f.get_tensor(key) - value = value.to(device if quantize is None else "cpu").to(dtype) - - layer_name = ".".join(key.split(".")[:4]) - - # Fused qkv - if "q_attn.weight" in key or "kv_attn.weight" in key: - final_key = layer_name + ".c_attn.weight" - elif "q_attn.bias" in key or "kv_attn.bias" in key: - final_key = layer_name + ".c_attn.bias" - - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if transpose and ( - "c_fc.weight" in key - or "c_proj.weight" in key - or "q_attn.weight" in key - or "kv_attn.weight" in key - or "c_attn.weight" in key - ): - # Tranpose as we use nn.Linear instead of Conv1D - value = value.T - - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "c_attn.weight" in final_key: - module._parameters[param_name] = value.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2), - value.shape[1], - ) - ) - elif "c_attn.bias" in final_key: - module._parameters[param_name] = value.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2) - ) - ) - - # Copy to correct slice - if "q_attn.weight" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "q_attn.bias" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "kv_attn.weight" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = value - elif "kv_attn.bias" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = value - else: - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - else: - module._buffers[param_name] = value - - del value - - if model.lm_head.weight.device == torch.device("meta"): - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model : {uninitialized_parameters}" - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - -class FlashSantacoderSharded(FlashSantacoder): +class FlashSantacoderSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -214,28 +41,22 @@ class FlashSantacoderSharded(FlashSantacoder): trust_remote_code=trust_remote_code, ) - config = GPT2Config.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, + trust_remote_code=True, ) + config.quantize = quantize + config.transpose = config.architectures[0].startswith("GPT2") torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = FlashSantacoderForCausalLM(config, self.process_group) - - 
torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - transpose=config.architectures[0].startswith("GPT2"), + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = FlashSantacoderForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -247,164 +68,8 @@ class FlashSantacoderSharded(FlashSantacoder): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - transpose: bool, - ): - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for key in f.keys(): - slice_ = f.get_slice(key) - - layer_name = ".".join(key.split(".")[:4]) - - # Fused qkv - if "q_attn.weight" in key or "kv_attn.weight" in key: - final_key = layer_name + ".c_attn.weight" - elif "q_attn.bias" in key or "kv_attn.bias" in key: - final_key = layer_name + ".c_attn.bias" - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - if isinstance(module, TensorParallelColumnLinear): - dim = 1 if transpose and "weight" in param_name else 0 - size = slice_.get_shape()[dim] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = ( - slice_[start:stop] if dim == 0 else slice_[:, start:stop] - ) - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - dim = 0 if transpose else 1 - size = slice_.get_shape()[dim] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = ( - slice_[start:stop] - if dim == 0 - else slice_[:, start:stop] - ) - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. 
- if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif key == "lm_head.weight" and model.transformer.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(key) - - tensor = tensor.contiguous().to(dtype) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if transpose and ( - "c_fc.weight" in key - or "c_proj.weight" in key - or "q_attn.weight" in key - or "kv_attn.weight" in key - or "c_attn.weight" in key - ): - # Tranpose as we use nn.Linear instead of Conv1D - tensor = tensor.T - - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "c_attn.weight" in final_key: - module._parameters[param_name] = tensor.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2), - tensor.shape[1], - ) - ) - elif "c_attn.bias" in final_key: - module._parameters[param_name] = tensor.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2) - ) - ) - - # Copy to correct slice - if "q_attn" in key: - size = tensor.shape[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = tensor[start:stop] - module._parameters[param_name][: tensor.shape[0]] = tensor - elif "kv_attn.weight" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = tensor - elif "kv_attn.bias" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = tensor - elif "c_attn" in key: - # Slice q_tensor by shard - q_tensor = tensor[: -2 * model.transformer.head_size] - block_size = q_tensor.shape[0] // world_size - start = rank * block_size - stop = (rank + 1) * block_size - q_tensor = q_tensor[start:stop] - - module._parameters[param_name][ - : q_tensor.shape[0] - ] = q_tensor - - # Kv tensor is copied for every shard - kv_tensor = tensor[-2 * model.transformer.head_size :] - module._parameters[param_name][ - q_tensor.shape[0] : - ] = kv_tensor - else: - if current_parameter_tensor.shape != tensor.shape: - raise ValueError( - f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - - if model.lm_head.weight.device == torch.device("meta"): - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - - torch.cuda.empty_cache() - model.post_load_weights(quantize) + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 37ccc398..01e1c773 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -2,41 +2,25 @@ import re import torch import torch.distributed -from typing import List, 
Optional, Type, Tuple +from typing import List, Optional, Type -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase, ) -from transformers.models.opt.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.models.opt import OPT +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - - # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py # we split individual characters inside special tokens like [START_DNA] @@ -168,33 +152,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): ) -class Galactica(OPT): - @property - def batch_type(self) -> Type[CausalLMBatch]: - return GalacticaCausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - """Overwrite forward to ignore position_ids""" - - # Model Forward - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, outputs.past_key_values - - -class GalacticaSharded(Galactica): +class GalacticaSharded(CausalLM): def __init__( self, model_id: str, @@ -224,26 +182,17 @@ class GalacticaSharded(Galactica): tp_parallel=True, trust_remote_code=trust_remote_code, ) + config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = OPTForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -255,127 +204,15 @@ class GalacticaSharded(Galactica): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - if name == "lm_head.weight": - continue + @property + def 
batch_type(self) -> Type[CausalLMBatch]: + return GalacticaCausalLMBatch - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - module._parameters[param_name] = tensor - if name == "model.decoder.embed_tokens.weight": - model.lm_head._parameters["weight"] = tensor + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None @@ -386,10 +223,4 @@ class GalacticaSharded(Galactica): past_key_values=past_key_values, use_cache=True, ) - - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - - return logits, outputs.past_key_values + return 
outputs.logits, outputs.past_key_values diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 5ab8a624..0abf0239 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -1,34 +1,22 @@ import torch import torch.distributed -from typing import List, Optional +from typing import Optional -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, ) -from transformers.models.gpt_neox.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - from text_generation_server.models import CausalLM +from text_generation_server.models.custom_modeling.neox_modeling import ( + GPTNeoxForCausalLM, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - class GPTNeoxSharded(CausalLM): def __init__( @@ -58,28 +46,18 @@ class GPTNeoxSharded(CausalLM): config = AutoConfig.from_pretrained( model_id, revision=revision, - tp_parallel=True, trust_remote_code=trust_remote_code, ) + config.quantize = quantize torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = GPTNeoxForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -91,161 +69,16 @@ class GPTNeoxSharded(CausalLM): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. 
- if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - if self.model.gpt_neox.tp_embeddings: - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather( - logits, outputs.logits, group=self.process_group - ) - logits = torch.cat(logits, dim=2) - - return logits, outputs.past_key_values - # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard - else: - return super(GPTNeoxSharded, self).forward( - input_ids, attention_mask, position_ids, past_key_values - ) + logits = outputs.logits + return logits, outputs.past_key_values diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 
9cc4d5e1..16cb48b7 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -1,52 +1,22 @@ import torch import torch.distributed -from typing import List, Optional, Tuple +from typing import Optional -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, ) -from transformers.models.opt.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models import CausalLM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - -class OPT(CausalLM): - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - """Overwrite forward to ignore position_ids""" - - # Model Forward - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, outputs.past_key_values - - -class OPTSharded(OPT): +class OPTSharded(CausalLM): def __init__( self, model_id: str, @@ -73,29 +43,19 @@ class OPTSharded(OPT): config = AutoConfig.from_pretrained( model_id, revision=revision, - tp_parallel=True, trust_remote_code=trust_remote_code, ) + config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = OPTForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -107,128 +67,6 @@ class OPTSharded(OPT): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - if name == "lm_head.weight": - continue - - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # 
XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - module._parameters[param_name] = tensor - if name == "model.decoder.embed_tokens.weight": - model.lm_head._parameters["weight"] = tensor - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): @@ -239,9 +77,4 @@ class OPTSharded(OPT): use_cache=True, ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - - return logits, outputs.past_key_values + return outputs.logits, outputs.past_key_values diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index d12b89d2..c89462fc 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -3,31 +3,20 @@ import torch.distributed from typing import List, Optional, Tuple -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForSeq2SeqLM, AutoConfig, ) from text_generation_server.models import Seq2SeqLM +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -from transformers.models.t5.parallel_layers import ( - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, -) - -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except 
ImportError as e: - HAS_BITS_AND_BYTES = False class T5Sharded(Seq2SeqLM): @@ -46,6 +35,13 @@ class T5Sharded(Seq2SeqLM): device = torch.device("cpu") dtype = torch.float32 + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, @@ -53,33 +49,16 @@ class T5Sharded(Seq2SeqLM): truncation_side="left", trust_remote_code=trust_remote_code, ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) tokenizer.bos_token_id = config.decoder_start_token_id torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - with init_empty_weights(): - model = AutoModelForSeq2SeqLM.from_config( - config, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group ) + + model = T5ForConditionalGeneration(config, weights) + torch.distributed.barrier(group=self.process_group) super(Seq2SeqLM, self).__init__( model=model, @@ -91,151 +70,6 @@ class T5Sharded(Seq2SeqLM): world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. 
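# Every branch above slices with the same block arithmetic that the new
# Weights.get_sharded helper centralizes: rank r owns the half-open block
# [r * block_size, (r + 1) * block_size) of the sharded dimension. A small
# sketch, assuming the dimension divides evenly by the world size:
import torch

world_size, size = 4, 1024
full = torch.arange(size)
block_size = size // world_size
shards = [full[r * block_size : (r + 1) * block_size] for r in range(world_size)]
assert torch.equal(torch.cat(shards), full)  # the shards tile the tensor exactly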
- if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight": - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif "relative_attention_bias.weight" in name: - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous() - - # See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71 - if module_name.endswith("wo"): - tensor = tensor.to(torch.float32) - else: - tensor = tensor.to(dtype) - - if quantize == "bitsandbytes" and not module_name.endswith("wo"): - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq" and not module_name.endswith("wo"): - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None or module_name.endswith("wo"): - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - def forward( self, input_ids, @@ -260,13 +94,8 @@ class T5Sharded(Seq2SeqLM): use_cache=True, ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - return ( - logits, + outputs.logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 6a351d66..befedcf0 100644 --- a/server/text_generation_server/utils/__init__.py +++ 
b/server/text_generation_server/utils/__init__.py
@@ -1,5 +1,6 @@
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.weights import Weights
 from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
@@ -35,4 +36,5 @@ __all__ = [
     "StoppingCriteria",
     "StopSequenceCriteria",
     "FinishReason",
+    "Weights",
 ]
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 9785493e..fe9c3b7b 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -4,6 +4,37 @@ import torch
 from datetime import timedelta
 
+class FakeBarrier:
+    def wait(self):
+        pass
+
+
+class FakeGroup:
+    def __init__(self, rank, size):
+        self._rank = rank
+        self._size = size
+
+    def allreduce(self, *args, **kwargs):
+        return FakeBarrier()
+
+    def allgather(self, inputs, local_tensor, **kwargs):
+        assert (
+            len(inputs[0]) == len(local_tensor) == 1
+        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
+        for input_ in inputs:
+            input_[0].data = local_tensor[0].data
+        return FakeBarrier()
+
+    def barrier(self, *args, **kwargs):
+        return FakeBarrier()
+
+    def size(self):
+        return self._size
+
+    def rank(self):
+        return self._rank
+
+
 def initialize_torch_distributed():
     rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -23,13 +54,18 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    # Call the init process.
-    torch.distributed.init_process_group(
-        backend=backend,
-        world_size=world_size,
-        rank=rank,
-        timeout=timedelta(seconds=60),
-        pg_options=options,
-    )
+    if world_size == 1:
+        return FakeGroup(rank, world_size), rank, world_size
+    else:
+        if os.getenv("DEBUG", None) == "1":
+            return FakeGroup(rank, world_size), rank, world_size
+        # Call the init process.
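# Note: with world_size == 1 (or DEBUG=1) the branches above return a FakeGroup
# instead of initializing a real process group, so the tensor-parallel layers
# run unsharded in a single process with no real collective ops. A sketch of
# the contract it fakes (assumes the server package is importable):
import torch
from text_generation_server.utils.dist import FakeGroup

group = FakeGroup(rank=0, size=1)
outputs, local = [[torch.empty(1)]], [torch.tensor([1.0])]
group.allgather(outputs, local).wait()  # copies the local tensor; no communication
assert outputs[0][0].item() == 1.0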
+ torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=60), + pg_options=options, + ) - return torch.distributed.group.WORLD, rank, world_size + return torch.distributed.group.WORLD, rank, world_size diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 127f9ba4..ee32a0dc 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -1,176 +1,240 @@ import torch +import torch.distributed from torch import nn from torch.nn import functional as F -from typing import Optional +from typing import List HAS_BITS_AND_BYTES = True try: - from bitsandbytes.nn import Linear8bitLt -except ImportError as e: + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params + +except ImportError: HAS_BITS_AND_BYTES = False +from accelerate import init_empty_weights -class FastLinear(nn.Linear): + +# Monkey patching +@classmethod +def load_layer_norm(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = nn.Parameter(weight) + ln.bias = nn.Parameter(bias) + return ln + + +torch.nn.LayerNorm.load = load_layer_norm + + +class FastLinear(nn.Module): def __init__( self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, + weight, + bias, ) -> None: - super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) - self.quantized = False - self.bnb_linear = None - - def prepare_weights(self, quantize: Optional[str] = None): - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." 
- ) - - self.quantized = True - self.bnb_linear = Linear8bitLt( - self.in_features, - self.out_features, - has_fp16_weights=False, - threshold=6.0, - bias=False, - ) - # Copy data to bnb_linear - self.bnb_linear.weight.data = self.weight.data - if self.bias is not None: - self.bnb_linear.bias = nn.Parameter(self.bias) - - # Delete reference to data - self.weight = None + super().__init__() + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = nn.Parameter(bias) + else: self.bias = None - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - self.weight = nn.Parameter(self.weight.T) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") else: - raise ValueError(f"Unexpected quantize `{quantize}`") + bias = None + return cls(weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.quantized: - return self.bnb_linear(input) - else: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) + return F.linear(input, self.weight, self.bias) -class TensorParallelColumnLinear(FastLinear): +class Linear8bitLt(nn.Module): def __init__( self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, ): - self.process_group = process_group - self.tp_world_size = process_group.size() - assert out_features % self.tp_world_size == 0 - out_features = out_features // self.tp_world_size + super().__init__() + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, ) + self.weight.cuda(weight.device) + self.bias = bias + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None -class TensorParallelRowLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - reduce=True, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - self.reduce = reduce - assert in_features % self.tp_world_size == 0 - in_features = in_features // self.tp_world_size + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = 
self.bias.data.to(x.dtype) - def forward(self, input: torch.Tensor) -> torch.Tensor: - out = super(TensorParallelRowLinear, self).forward(input) - if self.reduce: - torch.distributed.all_reduce(out, group=self.process_group) + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB return out -class TensorParallelEmbedding(nn.Embedding): - def __init__( - self, - num_embeddings, - embedding_dim, - process_group: torch.distributed.ProcessGroup, - reduce=True, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - _weight=None, - device=None, - dtype=None, - ): - self.reduce = reduce +def get_linear(weight, bias, quantize): + if quantize is None: + linear = FastLinear(weight, bias) + elif quantize == "bitsandbytes": + linear = Linear8bitLt( + weight, + bias, + has_fp16_weights=False, + threshold=6.0, + ) + if bias is not None: + linear.bias = nn.Parameter(bias) + elif quantize == "gptq": + raise NotImplementedError("Soon") + else: + raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") + return linear + + +class SuperLayer(nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) self.process_group = process_group - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - self.original_num_embeddings = num_embeddings - - assert num_embeddings % self.tp_world_size == 0 - block_size = num_embeddings // self.tp_world_size - # inputs in `[min_id, max_id[` are handled by `self` to get embeddings - self.min_id = self.tp_rank * block_size - self.max_id = (self.tp_rank + 1) * block_size - - # Additional entry that will map to zero - # Used for masking - self.null_idx = block_size - - super().__init__( - block_size, - embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=_weight, - device=device, - dtype=dtype, + @staticmethod + def load(config, prefix: str, weights): + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + return TensorParallelHead( + get_linear(weight, bias=None, quantize=config.quantize), + process_group=weights.process_group, ) - def add_null_idx(self): + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = super().forward(input) + # Logits are sharded, so we need to gather them + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + return cls(get_linear(weight, bias, config.quantize)) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) 
for p in prefixes]
+        weight = torch.cat(w, dim=dim)
+
+        if bias:
+            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
+            bias = torch.cat(b, dim=0)
+        else:
+            bias = None
+        return cls(get_linear(weight, bias, config.quantize))
+
+
+class TensorParallelRowLinear(SuperLayer):
+    def __init__(self, linear, process_group):
+        super().__init__(linear)
+        self.process_group = process_group
+
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        if bias and weights.process_group.rank() == 0:
+            # Bias is kept only on the rank 0 process so that, after the
+            # row-parallel all_reduce, it is added exactly once
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        return cls(
+            get_linear(weight, bias, config.quantize),
+            process_group=weights.process_group,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super().forward(input)
+        torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
+class TensorParallelEmbedding(nn.Module):
+    def __init__(self, prefix: str, weights, reduce=True):
+        super().__init__()
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
+
+        process_group = weights.process_group
+
+        world_size = process_group.size()
+        rank = process_group.rank()
+
+        block_size = num_embeddings // world_size
+        self.min_id = rank * block_size
+        self.max_id = min(num_embeddings, (rank + 1) * block_size)
+        self.null_idx = block_size
+        self.process_group = weights.process_group
+        self.reduce = reduce
+
         """Additional 0 entry used for masking"""
-        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # default all out of bounds values to `self.null_idx` that will then be mapped to 0
@@ -180,7 +244,7 @@ class TensorParallelEmbedding(nn.Embedding):
             self.null_idx,
             input - self.min_id,
         )
-        out = super().forward(input)
+        out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
@@ -232,7 +296,34 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
-    class PositionRotaryEmbedding(RotaryEmbedding):
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq):
+            super().__init__()
+
+            self.register_buffer("inv_freq", inv_freq)
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+
+        @classmethod
+        def static(cls, dim, base, device):
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
+            return cls(inv_freq)
+
+        @classmethod
+        def load(cls, prefix, weights):
+            # XXX: Always load this in float32 !
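# Why float32: the cos/sin caches are built by multiplying inv_freq with
# positions up to the maximum sequence length, so rounding inv_freq to half
# precision drifts the rotation angles by a visible fraction of a radian at
# large positions. A small sketch (dim and base are illustrative):
import torch

dim, base = 64, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(2048, dtype=torch.float32)
drift = (torch.outer(t, inv_freq) - torch.outer(t, inv_freq.half().float())).abs().max()
print(f"max angle drift with a fp16 inv_freq: {drift:.4f} rad")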
+            dtype = weights.dtype
+            weights.dtype = torch.float32
+            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
+            weights.dtype = dtype
+            return cls(inv_freq)
+
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
new file mode 100644
index 00000000..76a4f65a
--- /dev/null
+++ b/server/text_generation_server/utils/weights.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+from typing import List
+from safetensors import safe_open
+
+
+class Weights:
+    def __init__(self, filenames: List[Path], device, dtype, process_group):
+        routing = {}
+        for filename in filenames:
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename
+        self.routing = routing
+        self.device = device
+        self.dtype = dtype
+        self.process_group = process_group
+        self._handles = {}
+
+    def _get_handle(self, filename):
+        if filename not in self._handles:
+            f = safe_open(filename, framework="pytorch")
+            self._handles[filename] = f
+
+        return self._handles[filename]
+
+    def get_filename(self, tensor_name: str) -> str:
+        filename = self.routing.get(tensor_name, None)
+        if filename is None:
+            raise RuntimeError(f"weight {tensor_name} does not exist")
+        return str(filename)
+
+    def _get_slice(self, tensor_name: str):
+        filename = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        return slice_
+
+    def get_shape(self, tensor_name: str):
+        return self._get_slice(tensor_name).get_shape()
+
+    def get_tensor(self, tensor_name: str):
+        filename = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        tensor = f.get_tensor(tensor_name)
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
+    def get_sharded(self, tensor_name: str, dim: int):
+        filename = self.get_filename(tensor_name)
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        size = slice_.get_shape()[dim]
+        block_size = size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        assert (
+            size % world_size == 0
+        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
+
+        if dim == 0:
+            tensor = slice_[start:stop]
+        elif dim == 1:
+            tensor = slice_[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor

From e496c9ba5b574ce4e9d04d3b16bce67759ff0445 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 9 Jun 2023 11:55:29 +0200
Subject: [PATCH 11/22] feat(server): optimize dist ops (#434)

---
 .../custom_modeling/flash_neox_modeling.py    |  3 +-
 .../custom_modeling/flash_rw_modeling.py      |  6 ++--
 .../flash_santacoder_modeling.py              |  4 ++-
 server/text_generation_server/utils/layers.py | 32 +++++++++++++++++--
 4 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 0fe43bcb..c045f16e 100644
---
a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -265,7 +265,8 @@ class FlashNeoXLayer(nn.Module): mlp_output = self.mlp(ln2_hidden_states) intermediate = mlp_output + attn_output - torch.distributed.all_reduce(intermediate, group=self.process_group) + if self.process_group.size() > 1: + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate + hidden_states, None else: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 55195162..af9fa548 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -440,7 +440,8 @@ class FlashRWLayer(nn.Module): mlp_output = self.mlp(ln_hidden_states) intermediate = mlp_output + attn_output - torch.distributed.all_reduce(intermediate, group=self.process_group) + if self.process_group.size() > 1: + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate, residual else: @@ -524,7 +525,8 @@ class FlashRWLargeLayer(nn.Module): intermediate = attn_output + mlp_output - torch.distributed.all_reduce(intermediate, group=self.process_group) + if self.process_group.size() > 1: + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate, residual diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 888a6066..fcf6be68 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -346,7 +346,9 @@ class FlashSantacoderModel(nn.Module): pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.wte(input_ids) + self.wpe(position_ids) - torch.distributed.all_reduce(hidden_states, group=self.process_group) + + if self.process_group.size() > 1: + torch.distributed.all_reduce(hidden_states, group=self.process_group) # Prefill if past_key_values is None: diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index ee32a0dc..93865d52 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -158,8 +158,33 @@ class TensorParallelHead(SuperLayer): ) def forward(self, input: torch.Tensor) -> torch.Tensor: + world_size = self.process_group.size() + if world_size == 1: + return super().forward(input) + + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + out_dim = self.linear.weight.shape[0] + + if input.shape[0] == 1: + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) + gather_input = local_out + else: + world_out = input.new_empty(out_dim * world_size, input.shape[0]) + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T + + torch.mm(input, self.linear.weight.T, out=local_out) + + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + + if input.shape[0] == 1: + return world_out + return world_out.T + output = super().forward(input) - # Logits are sharded, so we need to gather them world_output = [ torch.empty_like(output) for _ in 
range(self.process_group.size())
         ]
@@ -211,7 +236,8 @@ class TensorParallelRowLinear(SuperLayer):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
@@ -245,7 +271,7 @@ class TensorParallelEmbedding(nn.Module):
             input - self.min_id,
         )
         out = torch.nn.functional.embedding(input, self.weight)
-        if self.reduce:
+        if self.reduce and self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out

From d4eb60f48d22f4ae97acef4a239dccdf9549877d Mon Sep 17 00:00:00 2001
From: "A.J" <47569261+antferdom@users.noreply.github.com>
Date: Mon, 12 Jun 2023 13:59:22 +0200
Subject: [PATCH 12/22] docs(launcher): fix CUDA_VISIBLE_DEVICES helper comment (#441)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

It fixes a typo in the comment sections referencing the environment variable
`CUDA_VISIBLE_DEVICES`. No misspelled references to this variable were found
in the code logic, so the typo could not have led to undefined behaviour or
bugs. This PR does not modify any code logic.

---
 launcher/src/main.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f59ff685..36f6f6b6 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -60,8 +60,8 @@ struct Args {
     sharded: Option<bool>,
 
     /// The number of shards to use if you don't want to use all GPUs on a given machine.
-    /// You can use `CUDA_VISIBLE_DEVICE=0,1 text-generation-launcher... --num_shard 2`
-    /// and `CUDA_VISIBLE_DEVICE=2,3 text-generation-launcher... --num_shard 2` to
+    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
+    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
     /// launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance.
     #[clap(long, env)]
     num_shard: Option<usize>,

From ca650e5bff1af8c8580c2be0d4dad37dd3285247 Mon Sep 17 00:00:00 2001
From: sayf eddine hammemi
Date: Mon, 12 Jun 2023 15:24:53 +0200
Subject: [PATCH 13/22] fix(makefile): Fix typo and use POSIX comparison in the makefile (#443)

# What does this PR do?

This PR fixes:
- The usage of a non-POSIX comparison, which may fail depending on the shell
  used (`=` will always work, `==` only with bash)
- A typo in the env variable name displayed in the error message:
  `BUILD_EXTENSION` instead of `BUILD_EXTENSIONS`

Fixes #422

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index c7f649ec..3c2f2b9d 100644
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@ install-server:
 	cd server && make install
 
 install-custom-kernels:
-	if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need set to BUILD_EXTENSION environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
+	if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. 
(Please read the docs, kernels might not work on all hardware)"; fi install-integration-tests: cd integration-tests && pip install -r requirements.txt From 5ce89059f8149eaf313c63e9ded4199670cd74bb Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 12 Jun 2023 18:30:29 +0200 Subject: [PATCH 14/22] feat(server): pre-allocate past key values for flash causal LM (#412) --- server/Makefile-flash-att | 4 +- .../custom_modeling/flash_llama_modeling.py | 114 ++++--- .../custom_modeling/flash_neox_modeling.py | 121 +++++--- .../custom_modeling/flash_rw_modeling.py | 182 +++++++---- .../flash_santacoder_modeling.py | 125 ++++---- .../models/flash_causal_lm.py | 293 +++++++++--------- 6 files changed, 494 insertions(+), 345 deletions(-) diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index ad894bfa..0e67a9e4 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -1,9 +1,9 @@ -flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7 +flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f flash-attention: # Clone flash attention pip install packaging - git clone https://github.com/HazyResearch/flash-attention.git + git clone https://github.com/OlivierDehaene/flash-attention.git build-flash-attention: flash-attention cd flash-attention && git fetch && git checkout $(flash_att_commit) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 8a35ffa8..993e1e2a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(qkv[:, 1], cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] 
= qkv[:, 1:] @@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module): qkv[:, 1], qkv[:, 2], attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module): else: query = qkv[:, 0] # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = qkv[:, 1:] + layer_past[past_present_indices] = qkv[:, 1:] # output attn_output = torch.empty_like(query) @@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module): layer_past[:, 0], layer_past[:, 1], attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module): normed_hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) # faster post attention rms norm @@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - past_key_values: Optional[torch.Tensor] = None, + past_present_indices, + past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.embed_tokens(input_ids) # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor + # We create a tensor of the same size as input_ids as we don't want to slice at every layer past_key_values = hidden_states.new_empty( ( + len(input_ids), len(self.layers), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, 2, self.num_heads, self.head_size, ) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -360,25 +375,36 @@ class FlashLlamaModel(torch.nn.Module): residual = None for i, layer in enumerate(self.layers): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + past_key_values[:, i], + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.layers), + 2, + self.num_heads, + self.head_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, past_key_values @@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, 
+ end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): hidden_states, present = self.model( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index c045f16e..3586b85a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 1], cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = qkv[:, 1:] @@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 1], qkv[:, 2], attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module): else: query = qkv[:, 0] # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = qkv[:, 1:] + layer_past[past_present_indices] = qkv[:, 1:] # output attn_output = torch.empty_like(query) @@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module): layer_past[:, 0], layer_past[:, 1], attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module): ln1_hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): @@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = 
True + # Create past tensor + # We create a tensor of the same size as input_ids as we don't want to slice at every layer past_key_values = hidden_states.new_empty( ( + len(input_ids), len(self.layers), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, 2, self.num_heads, self.head_size, ) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -367,25 +385,36 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): residual = None for i, layer in enumerate(self.layers): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + past_key_values[:, i], + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.layers), + 2, + self.num_heads, + self.head_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.final_layer_norm(hidden_states, residual) return hidden_states, past_key_values @@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): hidden_states, present = self.gpt_neox( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index af9fa548..4a9063eb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) @@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, cos, sin) - self.rotary_emb(kv[:, 0], cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] 
= kv # Expand to query shape @@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, 0], - kv[:, 1], + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module): # Decode else: # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = kv + layer_past[past_present_indices] = kv # Expand to query shape kv = layer_past.expand(-1, 2, self.num_heads, self.head_size) @@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, 0], - kv[:, 1], + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, cos, sin) - self.rotary_emb(kv[:, :, 0], cos, sin) + self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = kv # Expand to query shape @@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, :, 0], - kv[:, :, 1], + torch.select(kv, dim=2, index=0), + torch.select(kv, dim=2, index=1), attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Decode else: # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = kv + layer_past[past_present_indices] = kv # Expand to query shape kv = ( layer_past.unsqueeze(2) @@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, :, 0], - kv[:, :, 1], + torch.select(kv, dim=2, index=0), + torch.select(kv, dim=2, index=1), attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module): ln_hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) mlp_output = self.mlp(ln_hidden_states) @@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) hidden_states, residual = 
self.post_attention_layernorm( @@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): ln_attn, residual = self.ln_attn(hidden_states, residual) ln_mlp, _ = self.ln_mlp(residual) @@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module): ln_attn, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) # MLP. @@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): @@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel): # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor + # We create a tensor of the same size as input_ids as we don't want to slice at every layer past_key_values = hidden_states.new_empty( ( + len(input_ids), len(self.h), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, *self.cache_size, ) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -620,25 +651,34 @@ class FlashRWModel(FlashRWPreTrainedModel): residual = None for i, layer in enumerate(self.h): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + torch.select(past_key_values, dim=1, index=i), + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.h), + *self.cache_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states, past_key_values @@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): hidden_states, present = self.transformer( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index fcf6be68..00cc47b6 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -7,6 +7,7 @@ from typing import Optional # Flash attention imports import flash_attn_cuda + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module): def forward( self, hidden_states, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.c_attn(hidden_states) @@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module): key_value = key_value.view(-1, 2, 1, self.head_size) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = key_value # Expand from 1 to num_heads @@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - key_value[:, 0], - key_value[:, 1], + torch.select(key_value, dim=1, index=0), + torch.select(key_value, dim=1, index=1), attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module): # Decode else: # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = key_value + layer_past[past_present_indices] = key_value # Expand from 1 to num_heads key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size) @@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - key_value[:, 0], - key_value[:, 1], + torch.select(key_value, dim=1, index=0), + torch.select(key_value, dim=1, index=1), attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -277,21 +285,27 @@ class Block(nn.Module): self, hidden_states, residual, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.attn( hidden_states, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - past_key_values: Optional[torch.Tensor] = None, + past_present_indices, + past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -352,45 +369,43 @@ class FlashSantacoderModel(nn.Module): # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor - past_key_values = hidden_states.new_empty( - ( - len(self.h), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, - 2, - 1, - self.head_size, - ) + # We create a tensor of the same size as input_ids as we don't want to slice at every layer + past_key_values = hidden_states.new_zeros( + (len(input_ids), len(self.h), 2, 1, self.head_size) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # 
Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False residual = None for i, layer in enumerate(self.h): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + torch.select(past_key_values, dim=1, index=i), + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + (pre_allocate_past_size, len(self.h), 2, 1, self.head_size) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states, past_key_values @@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module): hidden_states, present = self.transformer( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, )
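The modeling diffs above all adopt the same cache scheme: during prefill the present key/values stay densely packed, one row per input token, and only after the layer loop are they scattered once through past_present_indices into a padded past tensor in which each request reserves input_length + max_new_tokens - 1 slots (the final generated token's KV is never stored, hence the minus one). A minimal self-contained sketch of that prefill bookkeeping, with toy sizes and local names rather than the repository's actual API:

    import torch

    num_layers, kv_dim, num_kv_heads, head_size = 2, 2, 1, 4  # kv_dim stacks K and V

    # Two packed sequences of lengths 3 and 2 -> 5 rows in this prefill forward.
    input_lengths = [3, 2]
    max_new_tokens = [4, 4]

    # Present KVs are packed densely during prefill, one row per input token.
    present = torch.randn(
        sum(input_lengths), num_layers, kv_dim, num_kv_heads, head_size
    )

    # Each request reserves input_length + max_new_tokens - 1 slots in the past.
    pre_allocate_past_size = sum(
        l + n - 1 for l, n in zip(input_lengths, max_new_tokens)
    )

    # past_present_indices maps every packed row to its slot in the padded past.
    past_present_indices, offset = [], 0
    for l, n in zip(input_lengths, max_new_tokens):
        past_present_indices.extend(range(offset, offset + l))
        offset += l + n - 1
    past_present_indices = torch.tensor(past_present_indices, dtype=torch.int64)

    # One scatter after the layer loop instead of a slice in every layer.
    past_key_values = present.new_empty(
        (pre_allocate_past_size, num_layers, kv_dim, num_kv_heads, head_size)
    )
    past_key_values[past_present_indices] = present

start_seq and end_seq then record each request's [offset, offset + input_length) window into this padded tensor; those windows are what the new forward signatures thread through every attention layer.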
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a2ad2d5e..ecea998e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -3,8 +3,6 @@ import torch.distributed import numpy as np -from torch.nn import functional as F - from dataclasses import dataclass from opentelemetry import trace from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel @@ -34,10 +32,21 @@ class FlashCausalLMBatch(Batch): input_ids: torch.Tensor position_ids: torch.Tensor - # cumulative sequence lengths - cu_seqlens: torch.Tensor - # cumulative query sequence lengths, only used in decode - cu_seqlens_q: Optional[torch.Tensor] + # Indices to copy present to the correct indices in the pre-allocated past key values + past_present_indices: torch.Tensor + + # tensor of length b holding starting offset of each sequence + start_seq: torch.Tensor + # tensor of length b holding ending offset of each sequence + end_seq: torch.Tensor + # tensor of length b holding starting offset of each sequence, only used in prefill + start_seq_prefill: Optional[torch.Tensor] + # tensor of length b holding ending offset of each sequence, only used in prefill + end_seq_prefill: Optional[torch.Tensor] + # tensor of length b holding starting offset of each query sequence, only used in decode + start_seq_q: Optional[torch.Tensor] + # tensor of length b holding ending offset of each query sequence, only used in decode + end_seq_q: Optional[torch.Tensor] # past key values, only used in decode past_key_values: Optional[torch.Tensor] max_seqlen: int @@ -90,7 +99,11 @@ class FlashCausalLMBatch(Batch): )["input_ids"] position_ids = [] - cu_seqlens = [0] + past_present_indices = [] + start_seq = [] + end_seq = [] + start_seq_prefill = [] + end_seq_prefill = [] max_seqlen = 0 input_lengths = [] @@ -110,9 +123,9 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 + cumulative_max_length = 0 prefill_out_cumulative_length = 0 - max_tokens = 0 max_length = 0 # Parse batch @@ -138,7 +151,10 @@ class FlashCausalLMBatch(Batch): position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs - cu_seqlens.append(cumulative_length + input_length) + start_seq_prefill.append(cumulative_length) + end_seq_prefill.append(cumulative_length + input_length) + start_seq.append(cumulative_max_length) + end_seq.append(cumulative_max_length + input_length) next_token_chooser_parameters.append(r.parameters) @@ -168,9 +184,17 @@ prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 + request_past_present_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + past_present_indices.append(request_past_present_indices) + # Update + # Remove one as the first token does not have a past cumulative_length += input_length - max_tokens += input_length + max_new_tokens + cumulative_max_length += input_length + max_new_tokens - 1 max_length = max(max_length, input_length + max_new_tokens) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -184,26 +208,45 @@ for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids + # Create tensors on device + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) + start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32) + end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32) + if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) + + past_present_indices = np.concatenate(past_present_indices, dtype=np.int64) + + start_seq_prefill = torch.tensor( + start_seq_prefill, device=device, dtype=torch.int32 + ) + end_seq_prefill = torch.tensor( + end_seq_prefill, device=device, dtype=torch.int32 + ) else: input_ids = all_input_ids[0] position_ids = position_ids[0] - # Create tensors on device + past_present_indices = past_present_indices[0] + + start_seq_prefill = start_seq + end_seq_prefill = end_seq + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) - cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) + past_present_indices = torch.tensor( + past_present_indices, device=device, dtype=torch.int64 + ) if all_prefill_logprobs: prefill_head_indices = None - prefill_next_token_indices = cu_seqlens[1:] - 1 + prefill_next_token_indices = end_seq_prefill - 1 elif no_prefill_logprobs: - prefill_head_indices = cu_seqlens[1:] - 1 + prefill_head_indices = end_seq_prefill - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.tensor( @@ -219,8 +262,13 @@ requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=None, + past_present_indices=past_present_indices, + start_seq=start_seq, + end_seq=end_seq, + start_seq_prefill=start_seq_prefill, + end_seq_prefill=end_seq_prefill, + 
start_seq_q=None, + end_seq_q=None, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -233,7 +281,7 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + max_tokens=cumulative_max_length, ) @tracer.start_as_current_span("filter") @@ -244,10 +292,10 @@ class FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - single_request = len(request_ids) == 1 + device = self.input_ids.device # Cumulative length - cumulative_length = 0 + cumulative_max_length = 0 # New values after filtering requests_idx_mapping = {} @@ -255,11 +303,17 @@ class FlashCausalLMBatch(Batch): # Used to index into tensors indices = [] + # past indices to keep + past_indices = torch.zeros( + self.past_key_values.shape[0], dtype=torch.bool, device=device + ) + # Create on CPU to only move to GPU once instead of at every copy - cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32) - cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1] + start_seq = torch.empty(len(request_ids), dtype=torch.int32) + end_seq = torch.empty(len(request_ids), dtype=torch.int32) + start_seq_q = self.start_seq_q[: len(request_ids)] + end_seq_q = self.end_seq_q[: len(request_ids)] max_seqlen = 0 - past_key_values = [] requests = [] all_input_ids = [] @@ -270,8 +324,6 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] - max_tokens = 0 - for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) @@ -281,16 +333,8 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] - - # Copy to tensor (CPU) - cu_seqlens[i + 1] = cumulative_length + request_input_length max_seqlen = max(max_seqlen, request_input_length) - # Slice from past - past_key_values.append( - self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]] - ) - all_input_ids.append(self.all_input_ids[idx]) input_lengths.append(request_input_length) @@ -300,39 +344,32 @@ class FlashCausalLMBatch(Batch): stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - cumulative_length += request_input_length - max_tokens += request_input_length + ( + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) - if single_request: - # Preallocate tensor for bs = 1 case - past_key_values = F.pad( - past_key_values[0], - ( - 0, - 0, - 0, - 0, - 0, - 0, - 0, - stopping_criterias[0].max_new_tokens - - stopping_criterias[0].current_tokens, - ), - ) - else: - # Cat all past - past_key_values = torch.cat(past_key_values, dim=1) + # Copy to tensor (CPU) + start_seq[i] = cumulative_max_length + end_seq[i] = cumulative_max_length + request_input_length + + # Set slice + past_indices[ + self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 + ] = True + + cumulative_max_length += request_input_length + remaining_tokens - 1 # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) + past_key_values = self.past_key_values[past_indices] # Move to GPU now that we have the whole tensor - cu_seqlens = cu_seqlens.to(self.cu_seqlens.device) + start_seq = start_seq.to(device) + end_seq = end_seq.to(device) + past_present_indices = end_seq - 1 
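# Illustrative aside (local names, not patch content): with these windows, a
# decode step appends exactly one token per request, so its fresh KV row lands
# in the last slot of the window, which is exactly the index computed above:
#
#     past_present_indices = end_seq - 1      # one write index per request
#     layer_past[past_present_indices] = kv   # scatter this step's KV rows
#     end_seq += 1                            # each window grows by one slot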
return FlashCausalLMBatch( batch_id=self.batch_id, @@ -340,8 +377,13 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=cu_seqlens_q, + past_present_indices=past_present_indices, + start_seq=start_seq, + end_seq=end_seq, + start_seq_prefill=None, + end_seq_prefill=None, + start_seq_q=start_seq_q, + end_seq_q=end_seq_q, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -354,7 +396,7 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + max_tokens=cumulative_max_length, ) @classmethod @@ -371,10 +413,12 @@ class FlashCausalLMBatch(Batch): input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - cu_seqlens = [0] - cu_seqlens_q = torch.arange( - 0, total_batch_size + 1, device=device, dtype=torch.int32 + start_seq = batches[0].start_seq.new_empty(total_batch_size) + end_seq = batches[0].end_seq.new_empty(total_batch_size) + start_seq_q = torch.arange( + 0, total_batch_size, device=device, dtype=torch.int32 ) + end_seq_q = start_seq_q + 1 max_seqlen = 0 past_key_values = [] @@ -389,7 +433,6 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 - cumulative_length = 0 max_tokens = 0 max_length = 0 @@ -410,18 +453,10 @@ class FlashCausalLMBatch(Batch): input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - # Add cumulative lengths of all previous inputs - cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]]) - max_seqlen = max(max_seqlen, batch.max_seqlen) + start_seq[start_index:end_index] = batch.start_seq + max_tokens + end_seq[start_index:end_index] = batch.end_seq + max_tokens - if len(batch) != 1: - past_key_values.append(batch.past_key_values) - else: - # past was pre-allocated for this batch - # We need to slice to remove the padding - past_key_values.append( - batch.past_key_values[:, : batch.input_lengths[0]] - ) + max_seqlen = max(max_seqlen, batch.max_seqlen) all_input_ids.extend(batch.all_input_ids) @@ -431,9 +466,9 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) + past_key_values.append(batch.past_key_values) # Update - cumulative_length += batch.cu_seqlens[-1] cumulative_batch_size += len(batch) max_tokens += batch.max_tokens max_length = max( @@ -448,6 +483,9 @@ class FlashCausalLMBatch(Batch): ), ) + past_key_values = torch.cat(past_key_values, dim=0) + past_present_indices = end_seq - 1 + all_input_ids_tensor = torch.zeros( (total_batch_size, max_length), dtype=torch.int64, device=device ) @@ -463,11 +501,6 @@ class FlashCausalLMBatch(Batch): cumulative_batch_size += len(batch) - # Cat past - past_key_values = torch.cat(past_key_values, dim=1) - # Create final tensor on GPU - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype=dtype, device=device ) @@ -478,8 +511,13 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=cu_seqlens_q, + past_present_indices=past_present_indices, + 
start_seq=start_seq, + end_seq=end_seq, + start_seq_prefill=None, + end_seq_prefill=None, + start_seq_q=start_seq_q, + end_seq_q=end_seq_q, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -550,9 +588,12 @@ class FlashCausalLM(Model): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - cu_seqlens: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], + start_seq: torch.Tensor, + end_seq: torch.Tensor, + start_seq_q: Optional[torch.Tensor], + end_seq_q: Optional[torch.Tensor], max_s: int, + past_present_indices: torch.Tensor, past_key_values: Optional = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -561,9 +602,12 @@ class FlashCausalLM(Model): return self.model.forward( input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=cu_seqlens_q, + start_seq=start_seq, + end_seq=end_seq, + start_seq_q=start_seq_q, + end_seq_q=end_seq_q, max_s=max_s, + past_present_indices=past_present_indices, past_key_values=past_key_values, pre_allocate_past_size=pre_allocate_past_size, lm_head_indices=lm_head_indices, @@ -575,23 +619,27 @@ class FlashCausalLM(Model): ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None prefill_logprobs = batch.prefill_next_token_indices is not None - single_request = len(batch) == 1 - if prefill and single_request: + if prefill: # Ask to pre-allocate kv to its max size - # == number of tokens + max_new_tokens - pre_allocate_past_size = ( - batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens - ) + # == Sum over batch size (number of tokens + max_new_tokens) - batch size + pre_allocate_past_size = batch.max_tokens + start_seq = batch.start_seq_prefill + end_seq = batch.end_seq_prefill else: pre_allocate_past_size = None + start_seq = batch.start_seq + end_seq = batch.end_seq out, present = self.forward( batch.input_ids, batch.position_ids, - batch.cu_seqlens, - batch.cu_seqlens_q, + start_seq, + end_seq, + batch.start_seq_q, + batch.end_seq_q, batch.max_seqlen, + batch.past_present_indices, batch.past_key_values, pre_allocate_past_size, batch.prefill_head_indices, @@ -614,55 +662,19 @@ class FlashCausalLM(Model): # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - # Create batch.cu_seqlens_q for decode - batch.cu_seqlens_q = torch.arange( - 0, len(batch) + 1, device=self.device, dtype=torch.int32 + # Create batch.start_seq_q and batch.end_seq_q for decode + batch.start_seq_q = torch.arange( + 0, len(batch), device=self.device, dtype=torch.int32 ) + batch.end_seq_q = batch.start_seq_q + 1 next_position_ids = batch.position_ids.new_empty(len(batch)) + # We do not need start_seq_prefill and end_seq_prefill anymore + batch.start_seq_prefill = None + batch.end_seq_prefill = None else: prefill_logprobs = None next_position_ids = batch.position_ids - # Prepare past for next decode - if len(batch) > 1: - # Used to slice next batch past - past_indices = torch.empty( - present.shape[1], dtype=torch.int64, device=self.device - ) - batch.past_key_values = present.new_empty( - ( - present.shape[0], - present.shape[1] + len(batch.requests), - *present.shape[2:], - ) - ) - - # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow - # and will run asynchronously while we do the next for loop - cumulative_length = 0 - for i, input_length in enumerate(batch.input_lengths): 
- # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - # Indices to copy present at the correct place in past_key_values - torch.arange( - start_index + i, - end_index + i, - dtype=torch.int64, - device=self.device, - out=past_indices[start_index:end_index], - ) - cumulative_length += input_length - - # Copy from present to past_key_values - batch.past_key_values[:, past_indices] = present - - # Initialize past_key_values in prefill for len(batch) == 1 - elif prefill: - # present is already pre-padded - batch.past_key_values = present - # Cumulative length cumulative_length = 0 @@ -685,6 +697,7 @@ class FlashCausalLM(Model): input_length, all_input_ids, ) in enumerate(iterator): + # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length @@ -718,7 +731,8 @@ class FlashCausalLM(Model): # Set values in batch batch.input_ids = next_input_ids batch.position_ids = next_position_ids + 1 - batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q + batch.past_present_indices = batch.end_seq + batch.end_seq = batch.end_seq + 1 if prefill and prefill_logprobs: # Get prefill logprobs @@ -843,6 +857,7 @@ class FlashCausalLM(Model): batch.prefill_head_indices = None batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 + batch.past_key_values = present # No need to return a batch if we know that all requests stopped return generations, batch if not stopped else None From f59fb8b630844c2ad2cd80e689202de89d45c37e Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 16 Jun 2023 16:25:11 +0200 Subject: [PATCH 15/22] feat(router): add ngrok integration (#453) --- Cargo.lock | 657 ++++++++++++------ launcher/src/main.rs | 44 ++ router/Cargo.toml | 5 + router/src/main.rs | 20 + router/src/queue.rs | 6 +- router/src/server.rs | 69 +- .../flash_santacoder_modeling.py | 1 - 7 files changed, 571 insertions(+), 231 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bd5994a8..7a6f4ad2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,14 +10,13 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aes" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" +checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241" dependencies = [ "cfg-if", "cipher", "cpufeatures", - "opaque-debug", ] [[package]] @@ -41,10 +40,19 @@ dependencies = [ ] [[package]] -name = "anstream" -version = "0.3.0" +name = "aho-corasick" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e579a7752471abc2a8268df8b20005e3eadd975f585398f17efcfd8d4927371" +checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" dependencies = [ "anstyle", "anstyle-parse", @@ -81,9 +89,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcd8291a340dd8ac70e18878bc4501dd7b4ff970cfa21c207d36ece51ea88fd" +checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" dependencies = [ "anstyle", "windows-sys 0.48.0", @@ -91,9 +99,26 
@@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + +[[package]] +name = "arc-swap" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" + +[[package]] +name = "async-rustls" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93b21a03b7c21702a0110f9f8d228763a533570deb376119042dabf33c37a01a" +dependencies = [ + "futures-io", + "rustls", + "webpki", +] [[package]] name = "async-stream" @@ -114,7 +139,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -125,7 +150,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -146,10 +171,22 @@ dependencies = [ ] [[package]] -name = "axum" -version = "0.6.15" +name = "awaitdrop" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b32c5ea3aabaf4deb5f5ced2d688ec0844c881c9e6c696a8b769a05fc691e62" +checksum = "771051cdc7eec2dc1b23fbf870bb7fbb89136fe374227c875e377f1eed99a429" +dependencies = [ + "futures", + "generational-arena", + "parking_lot", + "slotmap", +] + +[[package]] +name = "axum" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", "axum-core", @@ -218,9 +255,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.0" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64ct" @@ -245,9 +282,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" [[package]] name = "bytecount" @@ -333,18 +370,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cipher" -version = "0.3.0" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ - "generic-array", + "crypto-common", + "inout", ] [[package]] name = "clap" -version = "4.2.2" +version = "4.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b802d85aaf3a1cdb02b224ba472ebdea62014fccfcb269b95a4d76443b5ee5a" +checksum = "80672091db20273a15cf9fdd4e47ed43b5091ec9841bf4c6145c9dfbbcae09ed" dependencies = [ "clap_builder", "clap_derive", @@ -353,9 +391,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.2.2" +version = "4.3.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "14a1a858f532119338887a4b8e1af9c60de8249cd7bafd68036a489e261e37b6" +checksum = "c1458a1df40e1e2afebb7ab60ce55c1fa8f431146205aa5f4887e0b111c27636" dependencies = [ "anstream", "anstyle", @@ -366,21 +404,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.2.0" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" +checksum = "b8cd2b2a819ad6eec39e8f1d6b53001af1e5469f8c177579cdaeb313115b825f" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] name = "clap_lex" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" +checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" [[package]] name = "colorchoice" @@ -390,15 +428,15 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "console" -version = "0.15.5" +version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d79fbe8970a77e3e34151cc13d3b3e248aa0faaecb9f6091fa07ebefe5ad60" +checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" dependencies = [ "encode_unicode", "lazy_static", "libc", "unicode-width", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -425,9 +463,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" -version = "0.2.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" +checksum = "03e69e28e9f7f77debdedbaafa2866e1de9ba56df55a8bd7cfc724c25a09987c" dependencies = [ "libc", ] @@ -464,9 +502,9 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.14" +version = "0.9.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" dependencies = [ "autocfg", "cfg-if", @@ -477,9 +515,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" dependencies = [ "cfg-if", ] @@ -502,9 +540,9 @@ dependencies = [ [[package]] name = "crossterm_winapi" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ae1b35a484aa10e07fe0638d02301c5ad24de82d310ccbd2f3693da5f09bf1c" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" dependencies = [ "winapi", ] @@ -521,12 +559,12 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.2.5" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbcf33c2a618cbe41ee43ae6e9f2e48368cd9f9db2896f10167d8d762679f639" +checksum = "2a011bbe2c35ce9c1f143b7af6f94f29a167beb4cd1d29e6740ce836f723120e" dependencies = [ "nix", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -610,9 +648,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.10.6" +version 
= "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", @@ -728,9 +766,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" dependencies = [ "crc32fast", "miniz_oxide", @@ -758,7 +796,7 @@ dependencies = [ "futures-sink", "nanorand", "pin-project", - "spin", + "spin 0.9.8", ] [[package]] @@ -784,9 +822,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" dependencies = [ "percent-encoding", ] @@ -857,7 +895,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -890,6 +928,15 @@ dependencies = [ "slab", ] +[[package]] +name = "generational-arena" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877e94aff08e743b651baaea359664321055749b398adff8740a7399af7796e7" +dependencies = [ + "cfg-if", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -902,9 +949,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "js-sys", @@ -931,9 +978,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f8a914c2987b688368b5138aa05321db91f4090cf26118185672ad588bce21" +checksum = "d357c7ae988e7d2182f7d7871d0b963962420b0678b0997ce7de72001aeab782" dependencies = [ "bytes", "fnv", @@ -987,6 +1034,17 @@ dependencies = [ "digest", ] +[[package]] +name = "hostname" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +dependencies = [ + "libc", + "match_cfg", + "winapi", +] + [[package]] name = "http" version = "0.2.9" @@ -1084,9 +1142,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1127,6 +1185,15 @@ dependencies = [ "regex", ] +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.12" @@ -1138,9 +1205,9 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi 0.3.1", "libc", @@ -1209,9 +1276,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.61" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" +checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" dependencies = [ "wasm-bindgen", ] @@ -1224,27 +1291,27 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.141" +version = "0.2.146" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" +checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" [[package]] name = "libm" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "linux-raw-sys" -version = "0.3.1" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "lock_api" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" dependencies = [ "autocfg", "scopeguard", @@ -1252,12 +1319,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.17" +version = "0.4.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" [[package]] name = "mach" @@ -1284,6 +1348,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" +[[package]] +name = "match_cfg" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" + [[package]] name = "matchers" version = "0.1.0" @@ -1307,9 +1377,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "memoffset" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" dependencies = [ "autocfg", ] @@ -1322,7 +1392,7 @@ checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849" 
dependencies = [ "ahash", "metrics-macros", - "portable-atomic", + "portable-atomic 0.3.20", ] [[package]] @@ -1337,7 +1407,7 @@ dependencies = [ "metrics", "metrics-util", "parking_lot", - "portable-atomic", + "portable-atomic 0.3.20", "quanta", "thiserror", "tokio", @@ -1367,7 +1437,7 @@ dependencies = [ "metrics", "num_cpus", "parking_lot", - "portable-atomic", + "portable-atomic 0.3.20", "quanta", "sketches-ddsketch", ] @@ -1396,23 +1466,23 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.6.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.6" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -1433,7 +1503,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -1442,6 +1512,25 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "muxado" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e92b89ac3127251efde6f5a9586e5aae99468d06fcf9f133b377f58d5ed66446" +dependencies = [ + "async-trait", + "awaitdrop", + "bitflags", + "bytes", + "futures", + "pin-project", + "rand", + "thiserror", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "nanorand" version = "0.7.0" @@ -1469,6 +1558,37 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ngrok" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98ce3514eec7338e2d4663e3efb4429e08d8f3662996be4b9585350e7d8ad728" +dependencies = [ + "arc-swap", + "async-rustls", + "async-trait", + "awaitdrop", + "axum", + "base64 0.13.1", + "bytes", + "futures", + "hostname", + "hyper", + "muxado", + "once_cell", + "parking_lot", + "regex", + "rustls-pemfile", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-retry", + "tokio-util", + "tracing", + "windows-sys 0.45.0", +] + [[package]] name = "nix" version = "0.26.2" @@ -1550,9 +1670,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "onig" @@ -1576,17 +1696,11 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "openssl" -version = "0.10.50" +version = "0.10.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7e30d8bc91859781f0a943411186324d580f2bbeb71b452fe91ae344806af3f1" +checksum = "69b3f656a17a6cbc115b5c7a40c616947d213ba182135b014d6051b73ab6f019" dependencies = [ "bitflags", "cfg-if", @@ -1605,7 +1719,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -1616,9 +1730,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.85" +version = "0.9.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d3d193fb1488ad46ffe3aaabc912cc931d02ee8518fe2959aea8ef52718b0c0" +checksum = "c2ce0f250f34a308dcfdbb351f511359857d4ed2134ba715a4eadd46e1ffd617" dependencies = [ "cc", "libc", @@ -1714,9 +1828,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "papergrid" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fdfe703c51ddc52887ad78fc69cd2ea78d895ffcd6e955c9d03566db8ab5bb1" +checksum = "ae7891b22598926e4398790c8fe6447930c72a67d36d983a49d6ce682ce83290" dependencies = [ "bytecount", "fnv", @@ -1735,15 +1849,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.2.16", + "redox_syscall 0.3.5", "smallvec", - "windows-sys 0.45.0", + "windows-targets 0.48.0", ] [[package]] @@ -1777,9 +1891,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "petgraph" @@ -1793,22 +1907,22 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.0.12" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.0.12" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.18", ] [[package]] @@ -1825,15 +1939,24 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "portable-atomic" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b" +checksum = 
"e30165d31df606f5726b090ec7592c308a0eaf61721ff64c9a3018e344a8753e" +dependencies = [ + "portable-atomic 1.3.3", +] + +[[package]] +name = "portable-atomic" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "767eb9f07d4a5ebcb39bbf2d452058a93c011373abf6832e24194a1c3f004794" [[package]] name = "ppv-lite86" @@ -1877,9 +2000,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.56" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" +checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" dependencies = [ "unicode-ident", ] @@ -1956,9 +2079,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" dependencies = [ "proc-macro2", ] @@ -2079,13 +2202,13 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.3" +version = "1.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" dependencies = [ - "aho-corasick", + "aho-corasick 1.0.2", "memchr", - "regex-syntax", + "regex-syntax 0.7.2", ] [[package]] @@ -2094,7 +2217,7 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" dependencies = [ - "regex-syntax", + "regex-syntax 0.6.29", ] [[package]] @@ -2104,12 +2227,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] -name = "reqwest" -version = "0.11.16" +name = "regex-syntax" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b71749df584b7f4cac2c426c127a7c785a5106cc98f7a8feb044115f0fa254" +checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" + +[[package]] +name = "reqwest" +version = "0.11.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ - "base64 0.21.0", + "base64 0.21.2", "bytes", "encoding_rs", "futures-core", @@ -2141,10 +2270,25 @@ dependencies = [ ] [[package]] -name = "rust-embed" -version = "6.6.1" +name = "ring" +version = "0.16.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b68543d5527e158213414a92832d2aab11a84d2571a5eb021ebe22c43aab066" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted", + "web-sys", + "winapi", +] + +[[package]] +name = "rust-embed" +version = "6.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73e721f488c353141288f223b599b4ae9303ecf3e62923f40a492f0634a4dc3" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -2153,15 +2297,15 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "6.5.0" +version = "6.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4e0f0ced47ded9a68374ac145edd65a6c1fa13a96447b873660b2a568a0fd7" 
+checksum = "e22ce362f5561923889196595504317a4372b84210e6e335da529a65ea5452b5" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", "shellexpand", - "syn 1.0.109", + "syn 2.0.18", "walkdir", ] @@ -2186,9 +2330,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.11" +version = "0.37.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85597d61f83914ddeba6a47b3b8ffe7365107221c2e557ed94426489fefb5f77" +checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0" dependencies = [ "bitflags", "errno", @@ -2198,6 +2342,27 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "rustls" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" +dependencies = [ + "log", + "ring", + "sct", + "webpki", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +dependencies = [ + "base64 0.21.2", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -2235,10 +2400,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] -name = "security-framework" -version = "2.8.2" +name = "sct" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "security-framework" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" dependencies = [ "bitflags", "core-foundation", @@ -2249,9 +2424,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31c9bb296072e961fcbd8853511dd39c2d8be2deb1e17c6860b1d30732b323b4" +checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" dependencies = [ "core-foundation-sys", "libc", @@ -2265,29 +2440,29 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.160" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.160" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a" dependencies = [ "itoa", "ryu", @@ -2328,9 +2503,9 @@ dependencies = [ [[package]] name = 
"sha2" -version = "0.10.6" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" dependencies = [ "cfg-if", "cpufeatures", @@ -2400,6 +2575,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slotmap" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +dependencies = [ + "version_check", +] + [[package]] name = "smallvec" version = "1.10.0" @@ -2416,6 +2600,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spin" version = "0.9.8" @@ -2461,9 +2651,9 @@ dependencies = [ [[package]] name = "subtle" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" @@ -2478,9 +2668,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.15" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822" +checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" dependencies = [ "proc-macro2", "quote", @@ -2495,9 +2685,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sysinfo" -version = "0.28.4" +version = "0.29.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c2f3ca6693feb29a89724516f016488e9aafc7f37264f898593ee4b942f31b" +checksum = "9557d0845b86eea8182f7b10dff120214fb6cd9fd937b6f4917714e546a38695" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2509,9 +2699,9 @@ dependencies = [ [[package]] name = "tabled" -version = "0.12.0" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da1a2e56bbf7bfdd08aaa7592157a742205459eff774b73bc01809ae2d99dc2a" +checksum = "0ce69a5028cd9576063ec1f48edb2c75339fd835e6094ef3e05b3a079bf594a6" dependencies = [ "papergrid", "tabled_derive", @@ -2544,15 +2734,16 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.5.0" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" dependencies = [ + "autocfg", "cfg-if", "fastrand", "redox_syscall 0.3.5", "rustix", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -2619,6 +2810,7 @@ dependencies = [ "futures", "metrics", "metrics-exporter-prometheus", + "ngrok", "nohash-hasher", "opentelemetry", "opentelemetry-otlp", @@ -2656,7 +2848,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -2671,9 +2863,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" 
+checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" dependencies = [ "itoa", "serde", @@ -2683,15 +2875,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36" +checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" dependencies = [ "time-core", ] @@ -2717,7 +2909,7 @@ version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5cf49017523bf0bc01c9966f172c5f120bbb7b96cccd1708772dd42e767fb9f5" dependencies = [ - "aho-corasick", + "aho-corasick 0.7.20", "cached-path", "clap", "derive_builder", @@ -2736,7 +2928,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax", + "regex-syntax 0.6.29", "reqwest", "serde", "serde_json", @@ -2749,9 +2941,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.27.0" +version = "1.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0de47a4eecbe11f498978a9b29d792f0d2692d1dd003650c24c76510e3bc001" +checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" dependencies = [ "autocfg", "bytes", @@ -2763,7 +2955,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -2778,13 +2970,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -2798,10 +2990,21 @@ dependencies = [ ] [[package]] -name = "tokio-stream" -version = "0.1.12" +name = "tokio-retry" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fb52b74f05dbf495a8fba459fdc331812b96aa086d9eb78101fa0d4569c3313" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" dependencies = [ "futures-core", "pin-project-lite", @@ -2810,12 +3013,13 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.7" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" +checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", @@ -2951,20 +3155,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" +checksum = 
"0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.18", ] [[package]] name = "tracing-core" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" +checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" dependencies = [ "once_cell", "valuable", @@ -3017,9 +3221,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.16" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" +checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" dependencies = [ "matchers", "nu-ansi-term", @@ -3065,9 +3269,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" [[package]] name = "unicode-normalization" @@ -3106,10 +3310,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" [[package]] -name = "url" -version = "2.3.1" +name = "untrusted" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + +[[package]] +name = "url" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" dependencies = [ "form_urlencoded", "idna", @@ -3143,7 +3353,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.18", ] [[package]] @@ -3176,9 +3386,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.1.1" +version = "8.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236" +checksum = "8b3c89c2c7e50f33e4d35527e5bf9c11d6d132226dbbd1753f0fbe9f19ef88c6" dependencies = [ "anyhow", "rustc_version", @@ -3205,11 +3415,10 @@ dependencies = [ [[package]] name = "want" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" dependencies = [ - "log", "try-lock", ] @@ -3227,9 +3436,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.84" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" +checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3237,24 +3446,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.84" +version = "0.2.87" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" +checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.18", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.34" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" dependencies = [ "cfg-if", "js-sys", @@ -3264,9 +3473,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.84" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" +checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3274,33 +3483,43 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.84" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" +checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.18", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.84" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" +checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "web-sys" -version = "0.3.61" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" dependencies = [ "js-sys", "wasm-bindgen", ] +[[package]] +name = "webpki" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "which" version = "4.4.0" @@ -3510,9 +3729,9 @@ dependencies = [ [[package]] name = "zip" -version = "0.6.4" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" dependencies = [ "aes", "byteorder", diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 36f6f6b6..2e2bc7a5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -229,6 +229,26 @@ struct Args { #[clap(long, env)] watermark_delta: Option, + /// Enable ngrok tunneling + #[clap(long, env)] + ngrok: bool, + + /// ngrok authentication token + #[clap(long, env)] + ngrok_authtoken: Option, + + /// ngrok domain name where the axum webserver will be available at + #[clap(long, env)] + ngrok_domain: Option, + + /// ngrok basic auth username + #[clap(long, env)] + ngrok_username: Option, + + /// ngrok basic auth password + #[clap(long, env)] + ngrok_password: Option, + /// Display a lot of information about your runtime 
environment #[clap(long, short, action)] env: bool, @@ -845,6 +865,30 @@ fn spawn_webserver( argv.push(origin); } + // Ngrok + if args.ngrok { + let authtoken = args.ngrok_authtoken.ok_or_else(|| { + tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling"); + LauncherError::WebserverCannotStart + })?; + + argv.push("--ngrok".to_string()); + argv.push("--ngrok-authtoken".to_string()); + argv.push(authtoken); + + if let Some(domain) = args.ngrok_domain { + argv.push("--ngrok-domain".to_string()); + argv.push(domain); + } + + if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) { + argv.push("--ngrok-username".to_string()); + argv.push(username); + argv.push("--ngrok-password".to_string()); + argv.push(password); + } + } + // Copy current process env let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); diff --git a/router/Cargo.toml b/router/Cargo.toml index 6503e1bd..c1e665b1 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -40,6 +40,11 @@ tracing-opentelemetry = "0.18.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } +ngrok = { version = "0.12.3", features = ["axum"], optional = true } [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } + +[features] +default = ["ngrok"] +ngrok = ["dep:ngrok"] \ No newline at end of file diff --git a/router/src/main.rs b/router/src/main.rs index 82bf6ba8..7bbb6477 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -56,6 +56,16 @@ struct Args { otlp_endpoint: Option, #[clap(long, env)] cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_domain: Option, + #[clap(long, env)] + ngrok_username: Option, + #[clap(long, env)] + ngrok_password: Option, } fn main() -> Result<(), std::io::Error> { @@ -80,6 +90,11 @@ fn main() -> Result<(), std::io::Error> { json_output, otlp_endpoint, cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_domain, + ngrok_username, + ngrok_password, } = args; if validation_workers == 0 { @@ -198,6 +213,11 @@ fn main() -> Result<(), std::io::Error> { validation_workers, addr, cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_domain, + ngrok_username, + ngrok_password, ) .await; Ok(()) diff --git a/router/src/queue.rs b/router/src/queue.rs index 03807933..0586083d 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -49,7 +49,7 @@ impl Queue { // Send append command to the background task managing the state // Unwrap is safe here self.queue_sender - .send(QueueCommand::Append(entry, Span::current())) + .send(QueueCommand::Append(Box::new(entry), Span::current())) .unwrap(); } @@ -85,7 +85,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver { - span.in_scope(|| state.append(entry)); + span.in_scope(|| state.append(*entry)); metrics::increment_gauge!("tgi_queue_size", 1.0); } QueueCommand::NextBatch { @@ -256,7 +256,7 @@ type NextBatch = (IntMap, Batch, Span); #[derive(Debug)] enum QueueCommand { - Append(Entry, Span), + Append(Box, Span), NextBatch { min_size: Option, token_budget: u32, diff --git a/router/src/server.rs b/router/src/server.rs index 10c0ba3c..b8c67b2c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,5 +1,5 @@ -use crate::health::Health; /// HTTP Server logic +use crate::health::Health; use 
crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ @@ -520,6 +520,11 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, + ngrok: bool, + ngrok_authtoken: Option, + ngrok_domain: Option, + ngrok_username: Option, + ngrok_password: Option, ) { // OpenAPI documentation #[derive(OpenApi)] @@ -683,13 +688,61 @@ pub async fn run( .layer(opentelemetry_tracing_layer()) .layer(cors_layer); - // Run server - axum::Server::bind(&addr) - .serve(app.into_make_service()) - // Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) - .await - .unwrap(); + if ngrok { + #[cfg(feature = "ngrok")] + { + use ngrok::config::TunnelBuilder; + use ngrok::tunnel::UrlTunnel; + + let _ = addr; + + let authtoken = + ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); + + let mut tunnel = ngrok::Session::builder() + .authtoken(authtoken) + .connect() + .await + .unwrap() + .http_endpoint(); + + if let Some(domain) = ngrok_domain { + tunnel = tunnel.domain(domain); + } + + if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) { + tunnel = tunnel.basic_auth(username, password); + } + + let listener = tunnel.listen().await.unwrap(); + + // Run server + tracing::info!("Ingress URL: {:?}", listener.url()); + axum::Server::builder(listener) + .serve(app.into_make_service()) + //Wait until all requests are finished to shut down + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap(); + } + #[cfg(not(feature = "ngrok"))] + { + let _ngrok_authtoken = ngrok_authtoken; + let _ngrok_domain = ngrok_domain; + let _ngrok_username = ngrok_username; + let _ngrok_password = ngrok_password; + + panic!("`text-generation-router` was compiled without the `ngrok` feature"); + } + } else { + // Run server + axum::Server::bind(&addr) + .serve(app.into_make_service()) + // Wait until all requests are finished to shut down + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap(); + } } /// Shutdown signal handler diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 00cc47b6..b01d752a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -47,7 +47,6 @@ def load_multi_mqa( shape = slice_.get_shape() block_size = (shape[0] - 2 * head_size) // world_size assert (shape[0] - 2 * head_size) % world_size == 0 - q_tensor = slice_[start:stop] start = rank * block_size stop = (rank + 1) * block_size q_tensor = slice_[start:stop] From ece7ffa40a7e167400e57d89b8a73751d095184c Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 19 Jun 2023 09:53:45 +0200 Subject: [PATCH 16/22] feat(server): improve flash attention import errors (#465) @lewtun, is this enough? 
Closes #458 Closes #456 --- .../text_generation_server/models/__init__.py | 73 ++++++++++--------- .../text_generation_server/utils/convert.py | 4 +- server/text_generation_server/utils/hub.py | 5 +- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f1b84a53..9540d99e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,11 +18,43 @@ from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +# The flag below controls whether to allow TF32 on matmul. This flag defaults to False +# in PyTorch 1.12 and later. +torch.backends.cuda.matmul.allow_tf32 = True + +# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. +torch.backends.cudnn.allow_tf32 = True + +# Disable gradients +torch.set_grad_enabled(False) + +__all__ = [ + "Model", + "BLOOMSharded", + "CausalLM", + "FlashCausalLM", + "GalacticaSharded", + "Seq2SeqLM", + "SantaCoder", + "OPTSharded", + "T5Sharded", + "get_model", +] + +FLASH_ATT_ERROR_MESSAGE = ( + "{} requires CUDA and Flash Attention kernels to be installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" +) + try: - if ( - torch.cuda.is_available() - and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false" - ): + if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + if not torch.cuda.is_available(): + FLASH_ATT_ERROR_MESSAGE = ( + "{} requires CUDA. No compatible CUDA devices found." + ) + raise ImportError("CUDA is not available") + major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 is_sm8x = major == 8 and minor >= 0 @@ -30,6 +62,10 @@ try: supported = is_sm75 or is_sm8x or is_sm90 if not supported: + FLASH_ATT_ERROR_MESSAGE = ( + "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. " + "No compatible CUDA device found." + ) raise ImportError( f"GPU with CUDA capability {major} {minor} is not supported" ) @@ -52,41 +88,12 @@ except ImportError: ) FLASH_ATTENTION = False -__all__ = [ - "Model", - "BLOOMSharded", - "CausalLM", - "FlashCausalLM", - "GalacticaSharded", - "Seq2SeqLM", - "SantaCoder", - "OPTSharded", - "T5Sharded", - "get_model", -] - if FLASH_ATTENTION: __all__.append(FlashNeoXSharded) __all__.append(FlashRWSharded) __all__.append(FlashSantacoderSharded) __all__.append(FlashLlama) -FLASH_ATT_ERROR_MESSAGE = ( - "{} requires Flash Attention CUDA kernels to be installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" -) - -# The flag below controls whether to allow TF32 on matmul. This flag defaults to False -# in PyTorch 1.12 and later. -torch.backends.cuda.matmul.allow_tf32 = True - -# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
-torch.backends.cudnn.allow_tf32 = True - -# Disable gradients -torch.set_grad_enabled(False) - def get_model( model_id: str, diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index c43a4464..c4e79432 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -16,9 +16,9 @@ def check_file_size(source_file: Path, target_file: Path): source_file_size = source_file.stat().st_size target_file_size = target_file.stat().st_size - if (source_file_size - target_file_size) / source_file_size > 0.01: + if (source_file_size - target_file_size) / source_file_size > 0.05: raise RuntimeError( - f"""The file size different is more than 1%: + f"""The file size different is more than 5%: - {source_file}: {source_file_size} - {target_file}: {target_file_size} """ diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 2ed7673c..fbb570a6 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -26,7 +26,10 @@ def weight_hub_files( filenames = [ s.rfilename for s in info.siblings - if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1 + if s.rfilename.endswith(extension) + and len(s.rfilename.split("/")) == 1 + and "arguments" not in s.rfilename + and "args" not in s.rfilename ] if not filenames: From 53aa9194c8c070afd19fa4660305dab2b280adf3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 20 Jun 2023 11:06:10 +0200 Subject: [PATCH 17/22] fix(server): fix warpers on CPU (#472) Closes #471 --- .../text_generation_server/models/__init__.py | 20 ++++------- .../utils/logits_process.py | 36 +++++++++++-------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 9540d99e..3fdc23b2 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -237,20 +237,12 @@ def get_model( ) elif model_type == "t5": - if sharded: - return T5Sharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - else: - return Seq2SeqLM( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) + return T5Sharded( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) if sharded: raise ValueError("sharded is not supported for AutoModel") diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index faa94516..0cbbf8b0 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -42,25 +42,31 @@ class StaticWarper: self.static_next_logprob = None def __call__(self, scores): - if self.cuda_graph is None: - self.static_scores = scores - self.cuda_graph = torch.cuda.CUDAGraph() + if torch.cuda.is_available(): + if self.cuda_graph is None: + self.static_scores = scores + self.cuda_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.cuda_graph, pool=mempool): - local_scores = self.static_scores - for warper in self.warpers: - local_scores = warper(None, local_scores) + with torch.cuda.graph(self.cuda_graph, pool=mempool): + local_scores = self.static_scores + for warper in self.warpers: + local_scores = warper(None, local_scores) - self.static_warped_scores = local_scores - # Compute logprobs - 
self.static_next_logprob = torch.log_softmax( - self.static_warped_scores, -1 - ) + self.static_warped_scores = local_scores + # Compute logprobs + self.static_next_logprob = torch.log_softmax( + self.static_warped_scores, -1 + ) - self.static_scores.copy_(scores) - self.cuda_graph.replay() + self.static_scores.copy_(scores) + self.cuda_graph.replay() - return self.static_warped_scores, self.static_next_logprob + return self.static_warped_scores, self.static_next_logprob + + # CPU branch + for warper in self.warpers: + scores = warper(None, scores) + return scores, torch.log_softmax(scores, -1) @lru_cache(10) From c9c65ab323f48731e1fc2f7087547a7bd8b753f2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 20 Jun 2023 18:03:36 +0200 Subject: [PATCH 18/22] fix(server): Fixing T5 in case the names are mixed up. (#475) --- .../models/custom_modeling/t5_modeling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 51862e3c..12679e9d 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel): super().__init__(config) self.model_dim = config.d_model - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) + try: + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) + except RuntimeError: + self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False From 49b4b33e805d0ffee62688fe2607120b0c759e3d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Jun 2023 12:40:46 +0200 Subject: [PATCH 19/22] feat(server): Update convert logic. (#483) Should be more robust to shared tensors (ok when using `from_pretrained`). But it forces us to add new checks in our loading code (since the chosen key to keep might be different from `transformers`).
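For context, the failure mode is weight tying: two state-dict keys point at the same storage, which safetensors refuses to serialize, so one key is dropped at conversion time and must be re-resolved through aliases at load time. A minimal sketch (not part of this patch; the key names are illustrative, mirroring the `{"transformer.wte.weight": ["lm_head.weight"]}` alias the diff below adds for Santacoder):

```python
# Sketch of why duplicate tensors must be dropped before safetensors
# serialization, and why loading then needs an alias table.
# Key names are illustrative, not the actual model's.
import torch

embed = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # weight tying: two keys, one storage

state_dict = {
    "transformer.wte.weight": embed.weight,
    "lm_head.weight": lm_head.weight,
}
# Both keys alias the same memory:
assert (
    state_dict["transformer.wte.weight"].data_ptr()
    == state_dict["lm_head.weight"].data_ptr()
)

# After deduplication only one of the two keys survives in the
# .safetensors file, and which one is kept may not match the
# `transformers` convention. A loader asked for "transformer.wte.weight"
# must therefore retry through its aliases, e.g.:
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
```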
--------- Co-authored-by: Ubuntu --- .../models/flash_santacoder.py | 3 +- .../text_generation_server/utils/convert.py | 89 ++++++------------- .../text_generation_server/utils/weights.py | 22 +++-- 3 files changed, 46 insertions(+), 68 deletions(-) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 54634e4a..a71c0061 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, device=device, dtype=dtype, process_group=self.process_group, + aliases = {"transformer.wte.weight": ["lm_head.weight"]} ) model = FlashSantacoderForCausalLM(config, weights) diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index c4e79432..0e4adaba 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -1,76 +1,45 @@ import datetime import torch +import os -from collections import defaultdict from loguru import logger from pathlib import Path -from safetensors.torch import save_file -from safetensors import safe_open -from typing import Dict, List - - -def check_file_size(source_file: Path, target_file: Path): - """ - Check that two files are close in size - """ - source_file_size = source_file.stat().st_size - target_file_size = target_file.stat().st_size - - if (source_file_size - target_file_size) / source_file_size > 0.05: - raise RuntimeError( - f"""The file size different is more than 5%: - - {source_file}: {source_file_size} - - {target_file}: {target_file_size} - """ - ) - - -def remove_shared_pointers(tensors: Dict[str, torch.Tensor]): - """ - For a Dict of tensors, check if two or more tensors point to the same underlying memory and - remove them - """ - ptrs = defaultdict(list) - for k, v in tensors.items(): - ptrs[v.data_ptr()].append(k) - - # Iterate over all found memory addresses - for ptr, names in ptrs.items(): - if len(names) > 1: - # Multiple tensors are point to the same memory - # Only keep the first tensor - for name in names[1:]: - tensors.pop(name) +from safetensors.torch import save_file, _remove_duplicate_names, load_file +from typing import List def convert_file(pt_file: Path, sf_file: Path): """ Convert a pytorch file to a safetensors file + This will remove duplicate tensors from the file. + + Unfortunately, this might not respect *transformers* convention. + Forcing us to check for potentially different keys during load when looking + for specific tensors (making tensor sharing explicit). 
""" - logger.info(f"Convert {pt_file} to {sf_file}.") + loaded = torch.load(pt_file, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded) - pt_state = torch.load(pt_file, map_location="cpu") - if "state_dict" in pt_state: - pt_state = pt_state["state_dict"] + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} - remove_shared_pointers(pt_state) - - # Tensors need to be contiguous - pt_state = {k: v.contiguous() for k, v in pt_state.items()} - - sf_file.parent.mkdir(parents=True, exist_ok=True) - save_file(pt_state, str(sf_file), metadata={"format": "pt"}) - - # Check that both files are close in size - check_file_size(pt_file, sf_file) - - # Load safetensors state - for k in pt_state: - pt_tensor = pt_state[k] - with safe_open(sf_file, framework="pt") as f: - sf_tensor = f.get_tensor(k) - if not torch.equal(pt_tensor, sf_tensor): - raise RuntimeError(f"The output tensors do not match for key {k}") + dirname = os.path.dirname(sf_file) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_file, metadata=metadata) + reloaded = load_file(sf_file) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") def convert_files(pt_files: List[Path], sf_files: List[Path]): diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 76a4f65a..88347a6a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import List +from typing import List, Dict, Optional from safetensors import safe_open class Weights: - def __init__(self, filenames: List[Path], device, dtype, process_group): + def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): routing = {} for filename in filenames: with safe_open(filename, framework="pytorch") as f: @@ -14,6 +14,9 @@ class Weights: f"Key {k} was found in multiple files: {filename} and {routing[k]}" ) routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases self.routing = routing self.device = device self.dtype = dtype @@ -27,14 +30,19 @@ class Weights: return self._handles[filename] - def get_filename(self, tensor_name: str) -> str: + def get_filename(self, tensor_name: str) -> (str, str): filename = self.routing.get(tensor_name, None) if filename is None: + aliases = self.aliases.get(tensor_name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias raise RuntimeError(f"weight {tensor_name} does not exist") - return str(filename) + return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename = self.get_filename(tensor_name) + filename, tensor_name= self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -43,7 +51,7 @@ class Weights: return self._get_slice(tensor_name).get_shape() def get_tensor(self, tensor_name: str): - filename = self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f 
= self._get_handle(filename) tensor = f.get_tensor(tensor_name) tensor = tensor.to(dtype=self.dtype) @@ -51,7 +59,7 @@ class Weights: return tensor def get_sharded(self, tensor_name: str, dim: int): - filename = self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) world_size = self.process_group.size() rank = self.process_group.rank() From 776d150c55d3d973a457a80b05440d593942752a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Jun 2023 12:41:13 +0200 Subject: [PATCH 20/22] feat(server): Adding new ignore_rule for conversion. (#485) --- server/text_generation_server/utils/hub.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index fbb570a6..23743c9b 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -30,6 +30,7 @@ def weight_hub_files( and len(s.rfilename.split("/")) == 1 and "arguments" not in s.rfilename and "args" not in s.rfilename + and "training" not in s.rfilename ] if not filenames: From bd3a9d8e856cb7e2122f1a09d2fb0f44b7649dad Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 23 Jun 2023 14:58:28 +0200 Subject: [PATCH 21/22] fix(router): add timeout on flume sends (#488) --- router/src/infer.rs | 45 ++++++++++++++++++++++++++++++--------------- router/src/queue.rs | 2 +- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 00fa2818..f738f986 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -3,7 +3,7 @@ use crate::validation::{Validation, ValidationError}; use crate::{Entry, Queue, Token}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; -use flume::SendError; +use flume::SendTimeoutError; use futures::future::try_join_all; use futures::stream::StreamExt; use nohash_hasher::IntMap; @@ -11,6 +11,7 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use std::time::Duration; use text_generation_client::{ Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, }; @@ -472,6 +473,10 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); err }).unwrap_or(true); @@ -485,14 +490,20 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap Result>> { +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_disconnected() { + return Ok(true); + } + let mut stopped = false; if let Some(prefill_tokens) = generation.prefill_tokens { // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + entry.response_tx.send_timeout( + Ok(InferStreamResponse::Prefill(prefill_tokens)), + Duration::from_millis(10), + )?; } // Create last Token @@ -507,17 +518,21 @@ fn send_responses( // Generation has ended stopped = true; // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - generated_text, - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; + entry.response_tx.send_timeout( + Ok(InferStreamResponse::End { + token, + generated_text, + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }), + Duration::from_millis(10), + )?; } else { // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Token(token)))?; + entry.response_tx.send_timeout( + Ok(InferStreamResponse::Token(token)), + Duration::from_millis(10), + )?; } Ok(stopped) } @@ -535,7 +550,7 @@ fn send_errors(error: ClientError, 
entries: &mut IntMap) { // unwrap_or is valid here as we don't care if the receiver is gone. entry .response_tx - .send(Err(err)) + .send_timeout(Err(err), Duration::from_millis(10)) .unwrap_or(()); }); } diff --git a/router/src/queue.rs b/router/src/queue.rs index 0586083d..6d1d4d12 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -95,7 +95,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver span.in_scope(|| { let next_batch = state.next_batch(min_size, token_budget); - response_sender.send(next_batch).unwrap_or(()); + response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size", state.entries.len() as f64); }), } From aefde28b45cad8c2e93f91f690ecc60eee7bd75c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Jun 2023 12:27:01 +0200 Subject: [PATCH 22/22] feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438) Let's start discussing implementation. - Need to expose the quantization scripts (either included here, or add docs on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa) - Make sure GPTQ works for multiple models (priority to Falcon). Currently it means that every place we use `get_{tensor|sharded}` needs to check for quantization. My idea is to reintegrate as much as possible into `utils/layer.py` by expanding `load_multi` to be a bit more generic. This might require some thinking, but ultimately the `qweight,qzeros,scales,g_idx` should be in a single place, and independent of bias presence. # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
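To make the discussion concrete, here is a rough sketch of how those four tensors fit together for a 4-bit quantized linear layer. This is illustrative only: the packing order and zero-point conventions vary between GPTQ implementations, and the PR runs this as a fused kernel rather than in plain PyTorch.

```python
# Sketch of GPTQ-style dequantization under an assumed layout: 4-bit values,
# 8 packed per int32. Real kernels differ in packing and zero-point details.
import torch

BITS = 4
PACK = 32 // BITS  # 8 nibbles per int32

def dequantize(qweight, qzeros, scales, g_idx):
    # Assumed shapes:
    # qweight: (in_features // PACK, out_features) int32, packed along rows
    # qzeros:  (n_groups, out_features // PACK)    int32, packed along columns
    # scales:  (n_groups, out_features)            float
    # g_idx:   (in_features,)                      long, row -> group mapping
    shifts = torch.arange(PACK, dtype=torch.int32) * BITS
    # Unpack 8 nibbles from every int32 of the weight matrix.
    w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    w = w.reshape(-1, qweight.shape[1])            # (in_features, out_features)
    # Unpack the per-group zero points the same way.
    z = (qzeros.unsqueeze(2) >> shifts.view(1, 1, -1)) & 0xF
    z = z.reshape(qzeros.shape[0], -1)             # (n_groups, out_features)
    # g_idx makes the row -> group mapping explicit.
    return (w.float() - z[g_idx].float()) * scales[g_idx]
```

Keeping these four tensors behind a single loading path (the proposed, more generic `load_multi`) is what would avoid sprinkling quantization checks over every `get_{tensor|sharded}` call site.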
--------- Co-authored-by: Ubuntu Co-authored-by: OlivierDehaene --- Dockerfile | 5 + server/poetry.lock | 1964 +++++++++-------- server/requirements.txt | 29 +- server/text_generation_server/cli.py | 32 + .../text_generation_server/models/__init__.py | 4 + .../custom_modeling/flash_neox_modeling.py | 29 +- .../custom_modeling/flash_rw_modeling.py | 3 +- .../flash_santacoder_modeling.py | 81 +- .../models/flash_llama.py | 25 +- .../utils/gptq/custom_autotune.py | 261 +++ .../utils/gptq/quant_linear.py | 359 +++ .../utils/gptq/quantize.py | 866 ++++++++ server/text_generation_server/utils/layers.py | 47 +- .../text_generation_server/utils/weights.py | 51 +- 14 files changed, 2776 insertions(+), 980 deletions(-) create mode 100644 server/text_generation_server/utils/gptq/custom_autotune.py create mode 100644 server/text_generation_server/utils/gptq/quant_linear.py create mode 100644 server/text_generation_server/utils/gptq/quantize.py diff --git a/Dockerfile b/Dockerfile index 576dab8d..2a313c25 100644 --- a/Dockerfile +++ b/Dockerfile @@ -159,6 +159,11 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi # Install launcher COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + g++ \ + && rm -rf /var/lib/apt/lists/* + # AWS Sagemaker compatbile image FROM base as sagemaker diff --git a/server/poetry.lock b/server/poetry.lock index 5d853ce2..9a6900bc 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,10 +1,15 @@ +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. + [[package]] name = "accelerate" version = "0.19.0" description = "Accelerate" -category = "main" optional = true python-versions = ">=3.7.0" +files = [ + {file = "accelerate-0.19.0-py3-none-any.whl", hash = "sha256:2866b0bf9fff08f51e6384c95fa96725838b70f1988d1cce42e56b820d8a91dd"}, + {file = "accelerate-0.19.0.tar.gz", hash = "sha256:84920226b9e642e453ef37593ee55b956b08d8200dea4087c546c34e26157e76"}, +] [package.dependencies] numpy = ">=1.17" @@ -18,737 +23,51 @@ dev = ["black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-buil quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)", "urllib3 (<2.0.0)"] rich = ["rich"] sagemaker = ["sagemaker"] -test_dev = ["datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] -test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] -test_trackers = ["comet-ml", "tensorboard", "wandb"] +test-dev = ["datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] +test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] +test-trackers = ["comet-ml", "tensorboard", "wandb"] testing = ["datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"] [[package]] name = "backoff" version = "2.2.1" description = "Function decoration for backoff and retry" -category = "main" optional = false python-versions = ">=3.7,<4.0" +files = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] [[package]] name = "bitsandbytes" version = "0.38.1" description = 
"8-bit optimizers and matrix multiplication routines." -category = "main" optional = true python-versions = "*" +files = [ + {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"}, + {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"}, +] [[package]] name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" +files = [ + {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, + {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, +] [[package]] name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "main" optional = false python-versions = ">=3.7.0" - -[[package]] -name = "click" -version = "8.1.3" -description = "Composable command line interface toolkit" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[[package]] -name = "colorama" -version = "0.4.6" -description = "Cross-platform colored terminal text." -category = "main" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" - -[[package]] -name = "Deprecated" -version = "1.2.13" -description = "Python @deprecated decorator to deprecate old python classes, functions or methods." -category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" - -[package.dependencies] -wrapt = ">=1.10,<2" - -[package.extras] -dev = ["PyTest", "PyTest (<5)", "PyTest-Cov", "PyTest-Cov (<2.6)", "bump2version (<1)", "configparser (<5)", "importlib-metadata (<3)", "importlib-resources (<4)", "sphinx (<2)", "sphinxcontrib-websupport (<2)", "tox", "zipp (<2)"] - -[[package]] -name = "exceptiongroup" -version = "1.1.1" -description = "Backport of PEP 654 (exception groups)" -category = "dev" -optional = false -python-versions = ">=3.7" - -[package.extras] -test = ["pytest (>=6)"] - -[[package]] -name = "filelock" -version = "3.12.0" -description = "A platform independent file lock." 
-category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -docs = ["furo (>=2023.3.27)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] - -[[package]] -name = "fsspec" -version = "2023.5.0" -description = "File-system specification" -category = "main" -optional = false -python-versions = ">=3.8" - -[package.extras] -abfs = ["adlfs"] -adl = ["adlfs"] -arrow = ["pyarrow (>=1)"] -dask = ["dask", "distributed"] -devel = ["pytest", "pytest-cov"] -dropbox = ["dropbox", "dropboxdrivefs", "requests"] -full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] -fuse = ["fusepy"] -gcs = ["gcsfs"] -git = ["pygit2"] -github = ["requests"] -gs = ["gcsfs"] -gui = ["panel"] -hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] -libarchive = ["libarchive-c"] -oci = ["ocifs"] -s3 = ["s3fs"] -sftp = ["paramiko"] -smb = ["smbprotocol"] -ssh = ["paramiko"] -tqdm = ["tqdm"] - -[[package]] -name = "googleapis-common-protos" -version = "1.59.0" -description = "Common protobufs used in Google APIs" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" - -[package.extras] -grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] - -[[package]] -name = "grpc-interceptor" -version = "0.15.2" -description = "Simplifies gRPC interceptors" -category = "main" -optional = false -python-versions = ">=3.7,<4.0" - -[package.dependencies] -grpcio = ">=1.49.1,<2.0.0" - -[package.extras] -testing = ["protobuf (>=4.21.9)"] - -[[package]] -name = "grpcio" -version = "1.55.0" -description = "HTTP/2-based RPC framework" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -protobuf = ["grpcio-tools (>=1.55.0)"] - -[[package]] -name = "grpcio-reflection" -version = "1.55.0" -description = "Standard Protobuf Reflection Service for gRPC" -category = "main" -optional = false -python-versions = ">=3.6" - -[package.dependencies] -grpcio = ">=1.55.0" -protobuf = ">=4.21.6" - -[[package]] -name = "grpcio-status" -version = "1.55.0" -description = "Status proto mapping for gRPC" -category = "main" -optional = false -python-versions = ">=3.6" - -[package.dependencies] -googleapis-common-protos = ">=1.5.5" -grpcio = ">=1.55.0" -protobuf = ">=4.21.6" - -[[package]] -name = "grpcio-tools" -version = "1.55.0" -description = "Protobuf code generator for gRPC" -category = "dev" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -grpcio = ">=1.55.0" -protobuf = ">=4.21.6,<5.0dev" -setuptools = "*" - -[[package]] -name = "hf-transfer" -version = "0.1.3" -description = "" -category = "main" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "huggingface-hub" -version = "0.14.0" -description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -category = "main" -optional = false -python-versions = ">=3.7.0" - -[package.dependencies] -filelock = "*" -fsspec = "*" -packaging = ">=20.9" -pyyaml = ">=5.1" -requests = "*" -tqdm = ">=4.42.1" 
-typing-extensions = ">=3.7.4.3" - -[package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] -cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] -fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] -tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "gradio", "jedi", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile"] -torch = ["torch"] -typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] - -[[package]] -name = "idna" -version = "3.4" -description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" -optional = false -python-versions = ">=3.5" - -[[package]] -name = "iniconfig" -version = "2.0.0" -description = "brain-dead simple config-ini parsing" -category = "dev" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "Jinja2" -version = "3.1.2" -description = "A very fast and expressive template engine." -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -MarkupSafe = ">=2.0" - -[package.extras] -i18n = ["Babel (>=2.7)"] - -[[package]] -name = "loguru" -version = "0.6.0" -description = "Python logging made (stupidly) simple" -category = "main" -optional = false -python-versions = ">=3.5" - -[package.dependencies] -colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} -win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} - -[package.extras] -dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] - -[[package]] -name = "MarkupSafe" -version = "2.1.2" -description = "Safely add untrusted strings to HTML/XML markup." 
-category = "main" -optional = true -python-versions = ">=3.7" - -[[package]] -name = "mpmath" -version = "1.3.0" -description = "Python library for arbitrary-precision floating-point arithmetic" -category = "main" -optional = true -python-versions = "*" - -[package.extras] -develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] -docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4)"] -tests = ["pytest (>=4.6)"] - -[[package]] -name = "networkx" -version = "3.1" -description = "Python package for creating and manipulating graphs and networks" -category = "main" -optional = true -python-versions = ">=3.8" - -[package.extras] -default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] -developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] -doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] -test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] - -[[package]] -name = "numpy" -version = "1.24.3" -description = "Fundamental package for array computing in Python" -category = "main" -optional = true -python-versions = ">=3.8" - -[[package]] -name = "opentelemetry-api" -version = "1.15.0" -description = "OpenTelemetry Python API" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -deprecated = ">=1.2.6" -setuptools = ">=16.0" - -[[package]] -name = "opentelemetry-exporter-otlp" -version = "1.15.0" -description = "OpenTelemetry Collector Exporters" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -opentelemetry-exporter-otlp-proto-grpc = "1.15.0" -opentelemetry-exporter-otlp-proto-http = "1.15.0" - -[[package]] -name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.15.0" -description = "OpenTelemetry Collector Protobuf over gRPC Exporter" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} -googleapis-common-protos = ">=1.52,<2.0" -grpcio = ">=1.0.0,<2.0.0" -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-proto = "1.15.0" -opentelemetry-sdk = ">=1.12,<2.0" - -[package.extras] -test = ["pytest-grpc"] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-http" -version = "1.15.0" -description = "OpenTelemetry Collector Protobuf over HTTP Exporter" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} -googleapis-common-protos = ">=1.52,<2.0" -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-proto = "1.15.0" -opentelemetry-sdk = ">=1.12,<2.0" -requests = ">=2.7,<3.0" - -[package.extras] -test = ["responses (==0.22.0)"] - -[[package]] -name = "opentelemetry-instrumentation" -version = "0.36b0" -description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -opentelemetry-api = ">=1.4,<2.0" -setuptools = ">=16.0" -wrapt = ">=1.0.0,<2.0.0" - -[[package]] -name = "opentelemetry-instrumentation-grpc" -version = "0.36b0" -description = "OpenTelemetry gRPC instrumentation" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = 
"0.36b0" -opentelemetry-sdk = ">=1.12,<2.0" -opentelemetry-semantic-conventions = "0.36b0" -wrapt = ">=1.0.0,<2.0.0" - -[package.extras] -instruments = ["grpcio (>=1.27,<2.0)"] -test = ["opentelemetry-instrumentation-grpc[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.36b0)", "protobuf (>=3.13,<4.0)"] - -[[package]] -name = "opentelemetry-proto" -version = "1.15.0" -description = "OpenTelemetry Python Proto" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -protobuf = ">=3.19,<5.0" - -[[package]] -name = "opentelemetry-sdk" -version = "1.15.0" -description = "OpenTelemetry Python SDK" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -opentelemetry-api = "1.15.0" -opentelemetry-semantic-conventions = "0.36b0" -setuptools = ">=16.0" -typing-extensions = ">=3.7.4" - -[[package]] -name = "opentelemetry-semantic-conventions" -version = "0.36b0" -description = "OpenTelemetry Semantic Conventions" -category = "main" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "packaging" -version = "23.1" -description = "Core utilities for Python packages" -category = "main" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "pluggy" -version = "1.0.0" -description = "plugin and hook calling mechanisms for python" -category = "dev" -optional = false -python-versions = ">=3.6" - -[package.extras] -dev = ["pre-commit", "tox"] -testing = ["pytest", "pytest-benchmark"] - -[[package]] -name = "protobuf" -version = "4.23.1" -description = "" -category = "main" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "psutil" -version = "5.9.5" -description = "Cross-platform lib for process and system monitoring in Python." -category = "main" -optional = true -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" - -[package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] - -[[package]] -name = "pytest" -version = "7.3.1" -description = "pytest: simple powerful testing with Python" -category = "dev" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} -iniconfig = "*" -packaging = "*" -pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} - -[package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] - -[[package]] -name = "PyYAML" -version = "6.0" -description = "YAML parser and emitter for Python" -category = "main" -optional = false -python-versions = ">=3.6" - -[[package]] -name = "requests" -version = "2.31.0" -description = "Python HTTP for Humans." 
-category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" -idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<3" - -[package.extras] -socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] - -[[package]] -name = "safetensors" -version = "0.3.1" -description = "Fast and Safe Tensor serialization" -category = "main" -optional = false -python-versions = "*" - -[package.extras] -all = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] -dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] -jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)"] -numpy = ["numpy (>=1.21.6)"] -paddlepaddle = ["paddlepaddle (>=2.4.1)"] -quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] -tensorflow = ["tensorflow (>=2.11.0)"] -testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] -torch = ["torch (>=1.10)"] - -[[package]] -name = "sentencepiece" -version = "0.1.99" -description = "SentencePiece python wrapper" -category = "main" -optional = false -python-versions = "*" - -[[package]] -name = "setuptools" -version = "67.8.0" -description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] - -[[package]] -name = "sympy" -version = "1.12" -description = "Computer algebra system (CAS) in Python" -category = "main" -optional = true -python-versions = ">=3.8" - -[package.dependencies] -mpmath = ">=0.19" - -[[package]] -name = "tokenizers" -version = "0.13.3" -description = "Fast and Customizable Tokenizers" -category = "main" -optional = false -python-versions = "*" - -[package.extras] -dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] -testing = ["black (==22.3)", 
"datasets", "numpy", "pytest", "requests"] - -[[package]] -name = "tomli" -version = "2.0.1" -description = "A lil' TOML parser" -category = "dev" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "torch" -version = "2.0.1" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "main" -optional = true -python-versions = ">=3.8.0" - -[package.dependencies] -filelock = "*" -jinja2 = "*" -networkx = "*" -sympy = "*" -typing-extensions = "*" - -[package.extras] -opt-einsum = ["opt-einsum (>=3.3)"] - -[[package]] -name = "tqdm" -version = "4.65.0" -description = "Fast, Extensible Progress Meter" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -dev = ["py-make (>=0.1.0)", "twine", "wheel"] -notebook = ["ipywidgets (>=6)"] -slack = ["slack-sdk"] -telegram = ["requests"] - -[[package]] -name = "typer" -version = "0.6.1" -description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -category = "main" -optional = false -python-versions = ">=3.6" - -[package.dependencies] -click = ">=7.1.1,<9.0.0" - -[package.extras] -all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] -dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] -doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] -test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] - -[[package]] -name = "typing-extensions" -version = "4.6.0" -description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" -optional = false -python-versions = ">=3.7" - -[[package]] -name = "urllib3" -version = "2.0.2" -description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] -socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["zstandard (>=0.18.0)"] - -[[package]] -name = "win32-setctime" -version = "1.1.0" -description = "A small Python utility to set file creation time on Windows" -category = "main" -optional = false -python-versions = ">=3.5" - -[package.extras] -dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] - -[[package]] -name = "wrapt" -version = "1.15.0" -description = "Module for decorators, wrappers and monkey patching." 
-category = "main" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" - -[extras] -accelerate = ["accelerate"] -bnb = ["bitsandbytes"] - -[metadata] -lock-version = "1.1" -python-versions = "^3.9" -content-hash = "152c8b82717e2b802aee3427152ef3e37417fb95b73c3af2f448a381f51a6a8d" - -[metadata.files] -accelerate = [ - {file = "accelerate-0.19.0-py3-none-any.whl", hash = "sha256:2866b0bf9fff08f51e6384c95fa96725838b70f1988d1cce42e56b820d8a91dd"}, - {file = "accelerate-0.19.0.tar.gz", hash = "sha256:84920226b9e642e453ef37593ee55b956b08d8200dea4087c546c34e26157e76"}, -] -backoff = [ - {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, - {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, -] -bitsandbytes = [ - {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"}, - {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"}, -] -certifi = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, -] -charset-normalizer = [ +files = [ {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"}, {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b"}, {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60"}, @@ -825,141 +144,301 @@ charset-normalizer = [ {file = "charset_normalizer-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b"}, {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, ] -click = [ + +[[package]] +name = "click" +version = "8.1.3" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] -colorama = [ + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -Deprecated = [ - {file = "Deprecated-1.2.13-py2.py3-none-any.whl", hash = "sha256:64756e3e14c8c5eea9795d93c524551432a0be75629f8f29e67ab8caf076c76d"}, - {file = "Deprecated-1.2.13.tar.gz", hash = "sha256:43ac5335da90c31c24ba028af536a91d41d53f9e6901ddb021bcc572ce44e38d"}, + +[[package]] +name = "deprecated" +version = "1.2.14" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, + {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"}, ] -exceptiongroup = [ + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] + +[[package]] +name = "exceptiongroup" +version = "1.1.1" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"}, {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"}, ] -filelock = [ - {file = "filelock-3.12.0-py3-none-any.whl", hash = "sha256:ad98852315c2ab702aeb628412cbf7e95b7ce8c3bf9565670b4eaecf1db370a9"}, - {file = "filelock-3.12.0.tar.gz", hash = "sha256:fc03ae43288c013d2ea83c8597001b1129db351aad9c57fe2409327916b8e718"}, + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.12.2" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, + {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, ] -fsspec = [ - {file = "fsspec-2023.5.0-py3-none-any.whl", hash = "sha256:51a4ad01a5bb66fcc58036e288c0d53d3975a0df2a5dc59a93b59bade0391f2a"}, - {file = "fsspec-2023.5.0.tar.gz", hash = "sha256:b3b56e00fb93ea321bc9e5d9cf6f8522a0198b20eb24e02774d329e9c6fb84ce"}, + +[package.extras] +docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "fsspec" +version = "2023.6.0" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2023.6.0-py3-none-any.whl", hash = "sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a"}, + {file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"}, ] -googleapis-common-protos = [ - {file = "googleapis-common-protos-1.59.0.tar.gz", hash = "sha256:4168fcb568a826a52f23510412da405abd93f4d23ba544bb68d943b14ba3cb44"}, - {file = "googleapis_common_protos-1.59.0-py2.py3-none-any.whl", hash = "sha256:b287dc48449d1d41af0c69f4ea26242b5ae4c3d7249a38b0984c86a4caffff1f"}, + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + +[[package]] +name = "googleapis-common-protos" +version = "1.59.1" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis-common-protos-1.59.1.tar.gz", hash = "sha256:b35d530fe825fb4227857bc47ad84c33c809ac96f312e13182bdeaa2abe1178a"}, + {file = "googleapis_common_protos-1.59.1-py2.py3-none-any.whl", hash = "sha256:0cbedb6fb68f1c07e18eb4c48256320777707e7d0c55063ae56c15db3224a61e"}, ] -grpc-interceptor = [ + +[package.dependencies] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + +[[package]] +name = "grpc-interceptor" +version = "0.15.2" +description = "Simplifies gRPC interceptors" +optional = false +python-versions = ">=3.7,<4.0" +files = [ {file = "grpc-interceptor-0.15.2.tar.gz", hash = "sha256:5c984110af4fb77d03472ec0468f9c77ddaf798e190410fb7b7f1e76c60c96a4"}, {file = "grpc_interceptor-0.15.2-py3-none-any.whl", hash = "sha256:596dac3cb709ffb6178a4873f5148e254c871c9069f0b11040189b257969490a"}, ] -grpcio = [ - {file = 
"grpcio-1.55.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:7b38e028a7bbc97a9ae5e418712452f298618b9d0493390770bf2de785251ae7"}, - {file = "grpcio-1.55.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:054b7164b25712ec71339e139875a66708a2ab09be36ac75e73b2d337ab2dc1b"}, - {file = "grpcio-1.55.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:1982c99c7091d1b7e3e78b1173097f705feef233e253a27e99746b11815ac897"}, - {file = "grpcio-1.55.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8bd4f4932ef63ed32a725065aebb8585e4118a523d923db896e85c09429a36e6"}, - {file = "grpcio-1.55.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70de2b73cf22241173cb21d308786ba4ea443e4c88441a2ce445829aa638dda8"}, - {file = "grpcio-1.55.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2d25d7fcb528a40578b3d0428d401745fd5c0eeeda81f35ce2f40a10d79afd19"}, - {file = "grpcio-1.55.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1173a05117798aca4834d3edd504e6adc25ae9967df0f44b91a612884fb2707a"}, - {file = "grpcio-1.55.0-cp310-cp310-win32.whl", hash = "sha256:7c00263d792a244bef67a8d3b357ccbcdae6341c5961dbee494d8f967f9aee69"}, - {file = "grpcio-1.55.0-cp310-cp310-win_amd64.whl", hash = "sha256:ab784204d9923368e0e5877d7795584b9606a51b128ee199ad8b5888d0c66592"}, - {file = "grpcio-1.55.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:c97cfae0b7a17dc1a0a3e4333f4f46daa114d85f950a67f39cc141b5425182e4"}, - {file = "grpcio-1.55.0-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:8a910fa9b95a286f4bc1879dcf8d5ccb95b5e33bb63323fc4414d157f23afef1"}, - {file = "grpcio-1.55.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:3ab9bf80c19c91847f45ff32af94c85d282545a62db39d797838244d57831d78"}, - {file = "grpcio-1.55.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4370d2cca37301bcc69453d3dd3c1576d06d6b3e337bfec55b3aab2fe106b25c"}, - {file = "grpcio-1.55.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dad999423b33ad5409e986587593b6062a8260b74ae8fc8162ce231c6b7a929e"}, - {file = "grpcio-1.55.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d396ec4d520b58f43142958cff071e5ad1c50ac87d29d086a9c6a990a09ea536"}, - {file = "grpcio-1.55.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b2a3b837d5837b9069783026b57aa0ff12e34d3218fdeda3f9c06d3950266d8e"}, - {file = "grpcio-1.55.0-cp311-cp311-win32.whl", hash = "sha256:ee0de9cb6813704969e53743e0969fd95225ff24bd686c89ed12a18147f6566c"}, - {file = "grpcio-1.55.0-cp311-cp311-win_amd64.whl", hash = "sha256:9a11b1dd4b1572e85fba5911309c15980a1ff77c555fad0ecdbe3711ef741908"}, - {file = "grpcio-1.55.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:d0209fb3cb55c5288a1dec72dcaae2c1b501edceca10d22c0f0baa5e60e2b22c"}, - {file = "grpcio-1.55.0-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:322d4ebc37cbc8d8596b1da6055e3e81e8cfd36816ab4b285c1163c3042e6067"}, - {file = "grpcio-1.55.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:60efab181c32e029e0960f238508396dd001ba2064168f8148e6356db093967c"}, - {file = "grpcio-1.55.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f6088d898e1e987d761d58dc4cd724e7457a7a86d11561fa95c3b826d025dc"}, - {file = "grpcio-1.55.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29ab0e879b1585be41cfbb02faed67913700ced8015da4763f1f0bdd7dfb4ab7"}, - {file = "grpcio-1.55.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = 
"sha256:157f5615c7b5d0968727472f6394dee01555ef4246d2f2cfb6555be857936d74"}, - {file = "grpcio-1.55.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:67c4fda71f92225c5e74fa15bffa6be022c07111f674fe1f234c1ef4c1bb7927"}, - {file = "grpcio-1.55.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a202dcf0c512292fd7a2154e4044c70400212eaa726685ebf8af105e25693c5a"}, - {file = "grpcio-1.55.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:ce82d06cdfb8a9292fb857f00bee11a2430e4ac2742e07b46c1a3072d683256a"}, - {file = "grpcio-1.55.0-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:51b7a27a129f743d68394f94029f88ef3da090fc13776b9dfa3c79c5f4b30525"}, - {file = "grpcio-1.55.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:7c32f87bec58a8a0d4f4d5387bd61a383bd32b2caffb2de3cd579e47490b7e19"}, - {file = "grpcio-1.55.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89107071b5f14af6bbb855183d338a0fa94136bbeb3989c9773c6184e51a95e9"}, - {file = "grpcio-1.55.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1041cad23f00943d8889ad15427d87bbdacbbe2df5cec951c314f2f3967d4691"}, - {file = "grpcio-1.55.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:56631cc0bdf86d15ea1599b9697ace65e6b52c6b136d3666bf7769d3d6d087a8"}, - {file = "grpcio-1.55.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:10af4774da9c0665a1bf519333694ac40d72d83cb514534b99db0a5e3d5c3593"}, - {file = "grpcio-1.55.0-cp38-cp38-win32.whl", hash = "sha256:7b8665da31b5bd701b338a581de7b9631d50b4b7ee67125c2d1dc2228cc119d8"}, - {file = "grpcio-1.55.0-cp38-cp38-win_amd64.whl", hash = "sha256:74780f570c76feb8e62a8c019b495fea435b60218682fce513ff2c71262c346c"}, - {file = "grpcio-1.55.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:6b8dbb151b116825c10f01e5b7b75e14edd0e60736a65311d0d98a4cd0489303"}, - {file = "grpcio-1.55.0-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:a82283d6e0403d3e2e7eebb99cb0d2783e20b6791c8c94bd8d4a4233b58b1ea0"}, - {file = "grpcio-1.55.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:ba32a8e9bc3eecc6bab6824b905f04c3fdc31659c3e6e06841b774e7cb4410af"}, - {file = "grpcio-1.55.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1e2b705d524e780998218cf429d30b6ffc54cb6e54812c9597bc5df12dbcb5b"}, - {file = "grpcio-1.55.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe78365c64b2c7470d31c4941e10c6654042bcbb53897b9b1e2c96d6d0da9ef9"}, - {file = "grpcio-1.55.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8b440ccc434c1ad5874465bfae40c0a27f562ae5f7c5b468b6689bc55e8bf1c1"}, - {file = "grpcio-1.55.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0d3d5c644d523dee82ffcc44ad50cd66e3bf66e7fa60ad3cdb1eb868228e4ab0"}, - {file = "grpcio-1.55.0-cp39-cp39-win32.whl", hash = "sha256:c33dbeecc14f1a413e8af8ae1208cb383b063fa2ff2e1f309b4d3d7739b0927e"}, - {file = "grpcio-1.55.0-cp39-cp39-win_amd64.whl", hash = "sha256:2663741acc117370fd53336267cfb24c965e9d3ea1e4933a3e4411712d3091fb"}, - {file = "grpcio-1.55.0.tar.gz", hash = "sha256:dd15027a171ff93c97f9c704fa120bc5d0691dc7e71ae450e2ecade1a2799b53"}, + +[package.dependencies] +grpcio = ">=1.49.1,<2.0.0" + +[package.extras] +testing = ["protobuf (>=4.21.9)"] + +[[package]] +name = "grpcio" +version = "1.54.2" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-1.54.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c"}, + {file = 
"grpcio-1.54.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea"}, + {file = "grpcio-1.54.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c"}, + {file = "grpcio-1.54.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148"}, + {file = "grpcio-1.54.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8"}, + {file = "grpcio-1.54.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08"}, + {file = "grpcio-1.54.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee"}, + {file = "grpcio-1.54.2-cp310-cp310-win32.whl", hash = "sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8"}, + {file = "grpcio-1.54.2-cp310-cp310-win_amd64.whl", hash = "sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3"}, + {file = "grpcio-1.54.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821"}, + {file = "grpcio-1.54.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7"}, + {file = "grpcio-1.54.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e"}, + {file = "grpcio-1.54.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5"}, + {file = "grpcio-1.54.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a"}, + {file = "grpcio-1.54.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796"}, + {file = "grpcio-1.54.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610"}, + {file = "grpcio-1.54.2-cp311-cp311-win32.whl", hash = "sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e"}, + {file = "grpcio-1.54.2-cp311-cp311-win_amd64.whl", hash = "sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b"}, + {file = "grpcio-1.54.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c"}, + {file = "grpcio-1.54.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94"}, + {file = "grpcio-1.54.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2"}, + {file = "grpcio-1.54.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b"}, + {file = "grpcio-1.54.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe"}, + {file = "grpcio-1.54.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb"}, + {file = "grpcio-1.54.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb"}, + {file = "grpcio-1.54.2-cp37-cp37m-win_amd64.whl", hash = "sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474"}, + {file = "grpcio-1.54.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a"}, + {file = "grpcio-1.54.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf"}, + {file = "grpcio-1.54.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3"}, + {file = "grpcio-1.54.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05"}, + {file = "grpcio-1.54.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598"}, + {file = "grpcio-1.54.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb"}, + {file = "grpcio-1.54.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4"}, + {file = "grpcio-1.54.2-cp38-cp38-win32.whl", hash = "sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771"}, + {file = "grpcio-1.54.2-cp38-cp38-win_amd64.whl", hash = "sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12"}, + {file = "grpcio-1.54.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98"}, + {file = "grpcio-1.54.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1"}, + {file = "grpcio-1.54.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f"}, + {file = "grpcio-1.54.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f"}, + {file = "grpcio-1.54.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c"}, + {file = "grpcio-1.54.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a"}, + {file = "grpcio-1.54.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c"}, + {file = "grpcio-1.54.2-cp39-cp39-win32.whl", hash = "sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa"}, + {file = "grpcio-1.54.2-cp39-cp39-win_amd64.whl", hash = "sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b"}, + {file = "grpcio-1.54.2.tar.gz", hash = "sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019"}, ] -grpcio-reflection = [ - {file = "grpcio-reflection-1.55.0.tar.gz", hash = "sha256:46fc5e68ce7ae9bff0c0577f9e42bbb038a5afb26290fdf04943285e9db3c193"}, - {file = "grpcio_reflection-1.55.0-py3-none-any.whl", hash = "sha256:44e0dbfbfdcf1ac8646f1d32e4be72f0c633fd4b469e8e58b3a86e37b5a72756"}, + +[package.extras] +protobuf = ["grpcio-tools (>=1.54.2)"] + +[[package]] +name = "grpcio-reflection" +version = "1.54.2" +description = "Standard Protobuf Reflection Service for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-reflection-1.54.2.tar.gz", hash = 
"sha256:b2e021e1ce4f075615411edfbbd6fdcc485ba474dd6e5a3f559690582959a673"}, + {file = "grpcio_reflection-1.54.2-py3-none-any.whl", hash = "sha256:e7759addebbd90768f3a0278320278145758c4687d9e2cd7d76e7cbd0e329274"}, ] -grpcio-status = [ - {file = "grpcio-status-1.55.0.tar.gz", hash = "sha256:beeca8d5d3783e155676beaade0dae9eaea12cd9701498905dca0d35bd6b36f8"}, - {file = "grpcio_status-1.55.0-py3-none-any.whl", hash = "sha256:6da36bab11bb252b6854b86578f484c4fed9f8169816b490b6d3a32ec2a971fe"}, + +[package.dependencies] +grpcio = ">=1.54.2" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-status" +version = "1.54.2" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-status-1.54.2.tar.gz", hash = "sha256:3255cbec5b7c706caa3d4dd584606c080e6415e15631bb2f6215e2b70055836d"}, + {file = "grpcio_status-1.54.2-py3-none-any.whl", hash = "sha256:2a7cb4838225f1b53bd0448a3008c5b5837941e1f3a0b13fa38768f08a7b68c2"}, ] -grpcio-tools = [ - {file = "grpcio-tools-1.55.0.tar.gz", hash = "sha256:d796f5d7cea260ef2afed12d13ec34b13e09dd74d7f292d7428c506fa8c17a74"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:52e34e9b6496f4c1e3289ada7bc41d759e4a8ec5f2679e187067cab8532ffbf4"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:b131b2bbf25198d9e508dfa588cb215580629b514e293d5609eeee98c8941dbc"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:9933a1f18f780c42214b126ef27e273b54c9c28de3fae5b1887b413ceb374c4c"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfc82c11ce51de6ed5836fbafbc188d9eac0737abc116978f151c40271783817"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a18bd5f994b7911d3e70e0abb05bea9f1b084a1725d404a8e231bf9727613b"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:87152893c7c3bef58a6a9b548db290aa318cc314c700ae7d7f2970aa567f875e"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bc23034b1959d6cda27347b2207fee0fb0fb0aff242da228a6b7c1a18fce4116"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-win32.whl", hash = "sha256:ab64f9d6f5e3636ae6298e2d795225daa83aacb057105943728ed50a8a582237"}, - {file = "grpcio_tools-1.55.0-cp310-cp310-win_amd64.whl", hash = "sha256:b197de69ca0431b718ffa47b32a733703fa5503da49f49dd315c866842b6cfbd"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:98ff3129ff7134a95f4d2857663625771f6838ac44b7799c34259b7ea87ebe5c"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c3c7b7eb89f963b87922ecc0c0ab2485fff05997ada66dffd53597b507a83bc8"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:51a1ccab6f67edd1a3768a75ac495907fe0cd6d6617af2f9f2033400b5858a11"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4dea66623548b52429fb03495f2c76f4c993bf9a56267c6b3d0fb62573dd52c2"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e59fd4a58688117cb5128d7785909d45a6e5f8212efeb65b6fd74bb9b8b9512"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:734ede84d613b044f72e7d9c190bd2388ebb83e85bcd3aa75afa9f30c096dbc7"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:f87d99aa3826aa20c3b89493984cf278f4c9d20b3418534a46239c804fee506c"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-win32.whl", hash = "sha256:4580df5a9867f7bcbb828a5485c030ca232c1578e615caf751333c7a7980d838"}, - {file = "grpcio_tools-1.55.0-cp311-cp311-win_amd64.whl", hash = "sha256:b674de79571357c5381bc5fa12e3b89fefef74c164ab9077ed22158c3529aa8e"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:0dead7fb37bfe7c7eb8294143015645297f4affa683783b8bbf2cd4d7f7036d4"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:89f6ed47415a22568bbf4a62336bfde7cafb53492a5a9f33a22243411b00f443"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:946266cbd639847548c9f97e38da0682746c2eadea790ceb4320b1f85387bd6d"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36745762689df18f83273a9a004848897793f63a10a30acd18acb2d170c663a9"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:053bbdfb74f76511db47e1e18a1962432468ae9f356cc00f15d1f1353eaf32a1"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2eacb0b1e8e5cfd0b40e12e62bd5adebbbae8c73cdf6e04fad9ddd37e32d98a4"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9395c4fdee6b22137e878ebd461444854a3cd9c6c260c43f4a4c4a4023700129"}, - {file = "grpcio_tools-1.55.0-cp37-cp37m-win_amd64.whl", hash = "sha256:bcf5e1858137cbe13ef10a7931a7edc745c77f8b39f032f52072443f0dd681e1"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:e76f35e5e65600a75c3547855e8c9ab935c55c473f5409e7746cca8f1f7c8f4a"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:a61567f27661ab9327dc060615dc22d2bde80c56731f1e856008f1fd8ee83311"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:2ba87592f2cd689e127cd4fce76ec23b19562e230fa41ea089af8b15120aea78"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:41005002cbfa0ad39972486bde8116b2a042804119e5b998086a4dc26e625d6a"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a6db1494955d2a5531575b5fcdc08094ea4a331a94b9cdf864d78e801c5fa23"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:07c23ed940e046c9dd471bc870eb5db4d93e518f90011cf9aebf8bfda6cd68a5"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f900bce944b5777effecb9078e5fd3a224e42b1ca33c7546c3d043f9ef9eb4e8"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-win32.whl", hash = "sha256:3724e48c3db499b2d212c5a89d7cc4b49ccd476dc26bf8a9b855d59b6cc00796"}, - {file = "grpcio_tools-1.55.0-cp38-cp38-win_amd64.whl", hash = "sha256:416a8b61ed4223715755b4519858419e1f4653d64572a28029f2ac63e677e3d2"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:73ef9e0e0ee8ab055a621e7b42e5fb32753b0b6607900887dba6d55df5947be8"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:4a41130c97775bb0dfaf87e34b492f2eca448d02d213410005544c534f3f7c26"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:87dbc98528f88faa3f8f56a47d41dc6fda382928abbdb5537b5444eb8bb1ac1b"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:7f084cd619cf66d8620a99f8586018f19b918ffb2ddb92d3e5943a06038bead8"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:350303ef3a2b25ed1b90e42764923e40b664d9f10840f7a0f06117c4dc414aff"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:fbbe2bee4af93c03ba064d40199dbf38067d2aa6ae98dfa0687a08ee980ebfd5"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b00a67a1230968c1a0424915922d17983d824ed45e8db06f9f17be6d5571faee"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-win32.whl", hash = "sha256:632364ffbd4fb0338cb03c590a2ddc258d9cd59bff0bf4199c02e3e581f802d7"}, - {file = "grpcio_tools-1.55.0-cp39-cp39-win_amd64.whl", hash = "sha256:95428be2db12412ff23f0969386fc51d2aa6de38a57cc54c57363352f1d7a832"}, + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.54.2" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-tools" +version = "1.54.2" +description = "Protobuf code generator for gRPC" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-tools-1.54.2.tar.gz", hash = "sha256:e11c2c2aee53f340992e8e4d6a59172cbbbd0193f1351de98c4f810a5041d5ca"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:2b96f5f17d3156058be247fd25b062b4768138665694c00b056659618b8fb418"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:11939c9a8a39bd4815c7e88cb2fee48e1948775b59dbb06de8fcae5991e84f9e"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:129de5579f95d6a55dde185f188b4cbe19d1e2f1471425431d9930c31d300d70"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4128c01cd6f5ea8f7c2db405dbfd8582cd967d36e6fa0952565436633b0e591"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5c7292dd899ad8fa09a2be96719648cee37b17909fe8c12007e3bff58ebee61"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5ef30c2dbc63c1e0a462423ca4f95001814d26ef4fe66208e53fcf220ea3b717"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4abfc1892380abe6cef381eab86f9350cbd703bfe5d834095aa66fd91c886b6d"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-win32.whl", hash = "sha256:9acf443dcf6f68fbea3b7fb519e1716e014db1a561939f5aecc4abda74e4015d"}, + {file = "grpcio_tools-1.54.2-cp310-cp310-win_amd64.whl", hash = "sha256:21b9d2dee80f3f77e4097252e7f0db89772335a7300b72ab3d2e5c280872b1db"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:7b24fbab9e7598518ce4549e066df00aab79c2bf9bedcdde23fb5ef6a3cf532f"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:7baa210c20f71a242d9ae0e02734628f6948e8bee3bf538647894af427d28800"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:e3d0e5188ff8dbaddac2ee44731d36f09c4eccd3eac7328e547862c44f75cacd"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27671c68c7e0e3c5ff9967f5500799f65a04e7b153b8ce10243c87c43199039d"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f39d8e8806b8857fb473ca6a9c7bd800b0673dfdb7283ff569af0345a222f32c"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-musllinux_1_1_i686.whl", hash = 
"sha256:8e4c5a48f7b2e8798ce381498ee7b9a83c65b87ae66ee5022387394e5eb51771"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f285f8ef3de422717a36bd372239ae778b8cc112ce780ca3c7fe266dadc49fb"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-win32.whl", hash = "sha256:0f952c8a5c47e9204fe8959f7e9add149e660f6579d67cf65024c32736d34caf"}, + {file = "grpcio_tools-1.54.2-cp311-cp311-win_amd64.whl", hash = "sha256:3237149beec39e897fd62cef4aa1e1cd9422d7a95661d24bd0a79200b167e730"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:0ab1b323905d449298523db5d34fa5bf5fffd645bd872b25598e2f8a01f0ea39"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:7d7e6e8d62967b3f037f952620cb7381cc39a4bd31790c75fcfba56cc975d70b"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7f4624ef2e76a3a5313c4e61a81be38bcc16b59a68a85d30758b84cd2102b161"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e543f457935ba7b763b121f1bf893974393b4d30065042f947f85a8d81081b80"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0239b929eb8b3b30b2397eef3b9abb245087754d77c3721e3be43c44796de87d"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:0de05c7698c655e9a240dc34ae91d6017b93143ac89e5b20046d7ca3bd09c27c"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a3ce0b98fb581c471424d2cda45120f57658ed97677c6fec4d6decf5d7c1b976"}, + {file = "grpcio_tools-1.54.2-cp37-cp37m-win_amd64.whl", hash = "sha256:37393ef90674964175923afe3859fc5a208e1ece565f642b4f76a8c0224a0993"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:8e4531267736d88fde1022b36dd42ed8163e3575bcbd12bfed96662872aa93fe"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:a0b7049814442f918b522d66b1d015286afbeb9e6d141af54bbfafe31710a3c8"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:b80585e06c4f0082327eb5c9ad96fbdb2b0e7c14971ea5099fe78c22f4608451"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39fd530cfdf58dc05125775cc233b05554d553d27478f14ae5fd8a6306f0cb28"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb9ec4aea0f2b3006fb002fa59e5c10f92b48fc374619fbffd14d2b0e388c3e"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d512de051342a576bb89777476d13c5266d9334cf4badb6468aed9dc8f5bdec1"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1b8ee3099c51ce987fa8a08e6b93fc342b10228415dd96b5c0caa0387f636a6f"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-win32.whl", hash = "sha256:6037f123905dc0141f7c8383ca616ef0195e79cd3b4d82faaee789d4045e891b"}, + {file = "grpcio_tools-1.54.2-cp38-cp38-win_amd64.whl", hash = "sha256:10dd41862f579d185c60f629b5ee89103e216f63b576079d258d974d980bad87"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:f6787d07fdab31a32c433c1ba34883dea6559d8a3fbe08fb93d834ca34136b71"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:21b1467e31e44429d2a78b50135c9cdbd4b8f6d3b5cd548bc98985d3bdc352d0"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = 
"sha256:30a49b8b168aced2a4ff40959e6c4383ad6cfd7a20839a47a215e9837eb722dc"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8742122782953d2fd038f0a199f047a24e941cc9718b1aac90876dbdb7167739"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503ef1351c62fb1d6747eaf74932b609d8fdd4345b3591ef910adef8fa9969d0"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:72d15de4c4b6a764a76c4ae69d99c35f7a0751223688c3f7e62dfa95eb4f61be"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:df079479fb1b9e488334312e35ebbf30cbf5ecad6c56599f1a961800b33ab7c1"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-win32.whl", hash = "sha256:49c2846dcc4803476e839d8bd4db8845e928f19130e0ea86121f2d1f43d2b452"}, + {file = "grpcio_tools-1.54.2-cp39-cp39-win_amd64.whl", hash = "sha256:b82ca472db9c914c44e39a41e9e8bd3ed724523dd7aff5ce37592b8d16920ed9"}, ] -hf-transfer = [ + +[package.dependencies] +grpcio = ">=1.54.2" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + +[[package]] +name = "hf-transfer" +version = "0.1.3" +description = "" +optional = false +python-versions = ">=3.7" +files = [ {file = "hf_transfer-0.1.3-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:862b6ddba8e236bdc73408c20d020cfe5069cac3fd0b6de901c46f031df2b7d9"}, {file = "hf_transfer-0.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:569ef1ec6fec182e706ade4ea0c63f8510fd618ed7ced7c772efaafac7245b07"}, {file = "hf_transfer-0.1.3-cp310-none-win_amd64.whl", hash = "sha256:c9faa88b3491c50d4aa75faf18ae24040cd91aa0565c7f7ba2357dbcbf8372f6"}, @@ -980,87 +459,196 @@ hf-transfer = [ {file = "hf_transfer-0.1.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efb8b41360c7e3d7700c147b70688aed0a03e86fbe5bcfdee079b0e634f026f9"}, {file = "hf_transfer-0.1.3.tar.gz", hash = "sha256:7afd7eb03efad7812a48591b639b2e3f3d1f93c1e9060c18cc63ebf08d7e193c"}, ] -huggingface-hub = [ - {file = "huggingface_hub-0.14.0-py3-none-any.whl", hash = "sha256:fa6a6139fe4a8a164bfd0cda90c225fe8471b47c12811738b6db8348a2f703a0"}, - {file = "huggingface_hub-0.14.0.tar.gz", hash = "sha256:42eeab833284e3fc1d39263cf9c3d1bb36b129acdd8195838694d165e8dd6cae"}, + +[[package]] +name = "huggingface-hub" +version = "0.14.1" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "huggingface_hub-0.14.1-py3-none-any.whl", hash = "sha256:9fc619170d800ff3793ad37c9757c255c8783051e1b5b00501205eb43ccc4f27"}, + {file = "huggingface_hub-0.14.1.tar.gz", hash = "sha256:9ab899af8e10922eac65e290d60ab956882ab0bf643e3d990b1394b6b47b7fbc"}, ] -idna = [ + +[package.dependencies] +filelock = "*" +fsspec = "*" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "ruff 
(>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "gradio", "jedi", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile"] +torch = ["torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.5" +files = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] -iniconfig = [ + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -Jinja2 = [ + +[[package]] +name = "jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." +optional = true +python-versions = ">=3.7" +files = [ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, ] -loguru = [ + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "loguru" +version = "0.6.0" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ {file = "loguru-0.6.0-py3-none-any.whl", hash = "sha256:4e2414d534a2ab57573365b3e6d0234dfb1d84b68b7f3b948e6fb743860a77c3"}, {file = "loguru-0.6.0.tar.gz", hash = "sha256:066bd06758d0a513e9836fd9c6b5a75bfb3fd36841f4b996bc60b547a309d41c"}, ] -MarkupSafe = [ - {file = "MarkupSafe-2.1.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:665a36ae6f8f20a4676b53224e33d456a6f5a72657d9c83c2aa00765072f31f7"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:340bea174e9761308703ae988e982005aedf427de816d1afe98147668cc03036"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22152d00bf4a9c7c83960521fc558f55a1adbc0631fbb00a9471e097b19d72e1"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28057e985dace2f478e042eaa15606c7efccb700797660629da387eb289b9323"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca244fa73f50a800cf8c3ebf7fd93149ec37f5cb9596aa8873ae2c1d23498601"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d9d971ec1e79906046aa3ca266de79eac42f1dbf3612a05dc9368125952bd1a1"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7e007132af78ea9df29495dbf7b5824cb71648d7133cf7848a2a5dd00d36f9ff"}, - {file = 
"MarkupSafe-2.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7313ce6a199651c4ed9d7e4cfb4aa56fe923b1adf9af3b420ee14e6d9a73df65"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-win32.whl", hash = "sha256:c4a549890a45f57f1ebf99c067a4ad0cb423a05544accaf2b065246827ed9603"}, - {file = "MarkupSafe-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:835fb5e38fd89328e9c81067fd642b3593c33e1e17e2fdbf77f5676abb14a156"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2ec4f2d48ae59bbb9d1f9d7efb9236ab81429a764dedca114f5fdabbc3788013"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:608e7073dfa9e38a85d38474c082d4281f4ce276ac0010224eaba11e929dd53a"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65608c35bfb8a76763f37036547f7adfd09270fbdbf96608be2bead319728fcd"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2bfb563d0211ce16b63c7cb9395d2c682a23187f54c3d79bfec33e6705473c6"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da25303d91526aac3672ee6d49a2f3db2d9502a4a60b55519feb1a4c7714e07d"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9cad97ab29dfc3f0249b483412c85c8ef4766d96cdf9dcf5a1e3caa3f3661cf1"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:085fd3201e7b12809f9e6e9bc1e5c96a368c8523fad5afb02afe3c051ae4afcc"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1bea30e9bf331f3fef67e0a3877b2288593c98a21ccb2cf29b74c581a4eb3af0"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-win32.whl", hash = "sha256:7df70907e00c970c60b9ef2938d894a9381f38e6b9db73c5be35e59d92e06625"}, - {file = "MarkupSafe-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e55e40ff0cc8cc5c07996915ad367fa47da6b3fc091fdadca7f5403239c5fec3"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a6e40afa7f45939ca356f348c8e23048e02cb109ced1eb8420961b2f40fb373a"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf877ab4ed6e302ec1d04952ca358b381a882fbd9d1b07cccbfd61783561f98a"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63ba06c9941e46fa389d389644e2d8225e0e3e5ebcc4ff1ea8506dce646f8c8a"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1cd098434e83e656abf198f103a8207a8187c0fc110306691a2e94a78d0abb2"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:55f44b440d491028addb3b88f72207d71eeebfb7b5dbf0643f7c023ae1fba619"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:a6f2fcca746e8d5910e18782f976489939d54a91f9411c32051b4aab2bd7c513"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0b462104ba25f1ac006fdab8b6a01ebbfbce9ed37fd37fd4acd70c67c973e460"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-win32.whl", hash = "sha256:7668b52e102d0ed87cb082380a7e2e1e78737ddecdde129acadb0eccc5423859"}, - {file = "MarkupSafe-2.1.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6d6607f98fcf17e534162f0709aaad3ab7a96032723d8ac8750ffe17ae5a0666"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a806db027852538d2ad7555b203300173dd1b77ba116de92da9afbc3a3be3eed"}, - {file 
= "MarkupSafe-2.1.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a4abaec6ca3ad8660690236d11bfe28dfd707778e2442b45addd2f086d6ef094"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f03a532d7dee1bed20bc4884194a16160a2de9ffc6354b3878ec9682bb623c54"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4cf06cdc1dda95223e9d2d3c58d3b178aa5dacb35ee7e3bbac10e4e1faacb419"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22731d79ed2eb25059ae3df1dfc9cb1546691cc41f4e3130fe6bfbc3ecbbecfa"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f8ffb705ffcf5ddd0e80b65ddf7bed7ee4f5a441ea7d3419e861a12eaf41af58"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8db032bf0ce9022a8e41a22598eefc802314e81b879ae093f36ce9ddf39ab1ba"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2298c859cfc5463f1b64bd55cb3e602528db6fa0f3cfd568d3605c50678f8f03"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-win32.whl", hash = "sha256:50c42830a633fa0cf9e7d27664637532791bfc31c731a87b202d2d8ac40c3ea2"}, - {file = "MarkupSafe-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:bb06feb762bade6bf3c8b844462274db0c76acc95c52abe8dbed28ae3d44a147"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:99625a92da8229df6d44335e6fcc558a5037dd0a760e11d84be2260e6f37002f"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8bca7e26c1dd751236cfb0c6c72d4ad61d986e9a41bbf76cb445f69488b2a2bd"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40627dcf047dadb22cd25ea7ecfe9cbf3bbbad0482ee5920b582f3809c97654f"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40dfd3fefbef579ee058f139733ac336312663c6706d1163b82b3003fb1925c4"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:090376d812fb6ac5f171e5938e82e7f2d7adc2b629101cec0db8b267815c85e2"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2e7821bffe00aa6bd07a23913b7f4e01328c3d5cc0b40b36c0bd81d362faeb65"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c0a33bc9f02c2b17c3ea382f91b4db0e6cde90b63b296422a939886a7a80de1c"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b8526c6d437855442cdd3d87eede9c425c4445ea011ca38d937db299382e6fa3"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-win32.whl", hash = "sha256:137678c63c977754abe9086a3ec011e8fd985ab90631145dfb9294ad09c102a7"}, - {file = "MarkupSafe-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:0576fe974b40a400449768941d5d0858cc624e3249dfd1e0c33674e5c7ca7aed"}, - {file = "MarkupSafe-2.1.2.tar.gz", hash = "sha256:abcabc8c2b26036d62d4c746381a6f7cf60aafcc653198ad678306986b09450d"}, + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] + +[[package]] +name = "markupsafe" 
+version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." +optional = true +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, + {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] -mpmath = [ + +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = true +python-versions = "*" +files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, ] -networkx = [ + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "networkx" +version = "3.1" +description = "Python package for creating and manipulating graphs and networks" +optional = true +python-versions = ">=3.8" +files = [ {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, ] -numpy = [ + +[package.extras] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] + +[[package]] +name = "numpy" +version = "1.24.3" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.8" +files = [ {file = "numpy-1.24.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c1104d3c036fb81ab923f507536daedc718d0ad5a8707c6061cdfd6d184e570"}, {file = "numpy-1.24.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:202de8f38fc4a45a3eea4b63e2f376e5f2dc64ef0fa692838e31a808520efaf7"}, {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8535303847b89aa6b0f00aa1dc62867b5a32923e4d1681a35b5eef2d9591a463"}, @@ -1090,66 +678,216 @@ numpy = [ {file = "numpy-1.24.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:35400e6a8d102fd07c71ed7dcadd9eb62ee9a6e84ec159bd48c28235bbb0f8e4"}, {file = "numpy-1.24.3.tar.gz", hash = "sha256:ab344f1bf21f140adab8e47fdbc7c35a477dc01408791f8ba00d018dd0bc5155"}, ] -opentelemetry-api = [ + +[[package]] +name = "opentelemetry-api" +version = "1.15.0" +description = "OpenTelemetry Python API" +optional = false +python-versions = ">=3.7" +files = [ {file 
= "opentelemetry_api-1.15.0-py3-none-any.whl", hash = "sha256:e6c2d2e42140fd396e96edf75a7ceb11073f4efb4db87565a431cc9d0f93f2e0"}, {file = "opentelemetry_api-1.15.0.tar.gz", hash = "sha256:79ab791b4aaad27acc3dc3ba01596db5b5aac2ef75c70622c6038051d6c2cded"}, ] -opentelemetry-exporter-otlp = [ + +[package.dependencies] +deprecated = ">=1.2.6" +setuptools = ">=16.0" + +[[package]] +name = "opentelemetry-exporter-otlp" +version = "1.15.0" +description = "OpenTelemetry Collector Exporters" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_exporter_otlp-1.15.0-py3-none-any.whl", hash = "sha256:79f22748b6a54808a0448093dfa189c8490e729f67c134d4c992533d9393b33e"}, {file = "opentelemetry_exporter_otlp-1.15.0.tar.gz", hash = "sha256:4f7c49751d9720e2e726e13b0bb958ccade4e29122c305d92c033da432c8d2c5"}, ] -opentelemetry-exporter-otlp-proto-grpc = [ + +[package.dependencies] +opentelemetry-exporter-otlp-proto-grpc = "1.15.0" +opentelemetry-exporter-otlp-proto-http = "1.15.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-grpc" +version = "1.15.0" +description = "OpenTelemetry Collector Protobuf over gRPC Exporter" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0-py3-none-any.whl", hash = "sha256:c2a5492ba7d140109968135d641d06ce3c5bd73c50665f787526065d57d7fd1d"}, {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0.tar.gz", hash = "sha256:844f2a4bb9bcda34e4eb6fe36765e5031aacb36dc60ed88c90fc246942ea26e7"}, ] -opentelemetry-exporter-otlp-proto-http = [ + +[package.dependencies] +backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +googleapis-common-protos = ">=1.52,<2.0" +grpcio = ">=1.0.0,<2.0.0" +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-proto = "1.15.0" +opentelemetry-sdk = ">=1.12,<2.0" + +[package.extras] +test = ["pytest-grpc"] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.15.0" +description = "OpenTelemetry Collector Protobuf over HTTP Exporter" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_exporter_otlp_proto_http-1.15.0-py3-none-any.whl", hash = "sha256:3ec2a02196c8a54bf5cbf7fe623a5238625638e83b6047a983bdf96e2bbb74c0"}, {file = "opentelemetry_exporter_otlp_proto_http-1.15.0.tar.gz", hash = "sha256:11b2c814249a49b22f6cca7a06b05701f561d577b747f3660dfd67b6eb9daf9c"}, ] -opentelemetry-instrumentation = [ + +[package.dependencies] +backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +googleapis-common-protos = ">=1.52,<2.0" +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-proto = "1.15.0" +opentelemetry-sdk = ">=1.12,<2.0" +requests = ">=2.7,<3.0" + +[package.extras] +test = ["responses (==0.22.0)"] + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.36b0" +description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_instrumentation-0.36b0-py3-none-any.whl", hash = "sha256:83ba4ae7d5292b5b33e0f851cc5c76d8f91196b9b3527800fc13855c33383ac2"}, {file = "opentelemetry_instrumentation-0.36b0.tar.gz", hash = "sha256:e3ddac9b3b93408ef26c8ecbf38f717042977e16381bb4cd329a5b4cf16998cf"}, ] -opentelemetry-instrumentation-grpc = [ + +[package.dependencies] +opentelemetry-api = ">=1.4,<2.0" +setuptools = ">=16.0" +wrapt = ">=1.0.0,<2.0.0" + +[[package]] +name = "opentelemetry-instrumentation-grpc" +version = "0.36b0" +description = "OpenTelemetry gRPC instrumentation" +optional 
= false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_instrumentation_grpc-0.36b0-py3-none-any.whl", hash = "sha256:eaa246ed2083c97b13bab2555cb9d170e8433230a31476c4cab8a17fa03380a4"}, {file = "opentelemetry_instrumentation_grpc-0.36b0.tar.gz", hash = "sha256:dc89447c9eb6ea868970f6c13b4ffdac182cdd5a41dd215a0f5393ca6375be55"}, ] -opentelemetry-proto = [ + +[package.dependencies] +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-instrumentation = "0.36b0" +opentelemetry-sdk = ">=1.12,<2.0" +opentelemetry-semantic-conventions = "0.36b0" +wrapt = ">=1.0.0,<2.0.0" + +[package.extras] +instruments = ["grpcio (>=1.27,<2.0)"] +test = ["opentelemetry-instrumentation-grpc[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.36b0)", "protobuf (>=3.13,<4.0)"] + +[[package]] +name = "opentelemetry-proto" +version = "1.15.0" +description = "OpenTelemetry Python Proto" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_proto-1.15.0-py3-none-any.whl", hash = "sha256:044b6d044b4d10530f250856f933442b8753a17f94ae37c207607f733fb9a844"}, {file = "opentelemetry_proto-1.15.0.tar.gz", hash = "sha256:9c4008e40ac8cab359daac283fbe7002c5c29c77ea2674ad5626a249e64e0101"}, ] -opentelemetry-sdk = [ + +[package.dependencies] +protobuf = ">=3.19,<5.0" + +[[package]] +name = "opentelemetry-sdk" +version = "1.15.0" +description = "OpenTelemetry Python SDK" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_sdk-1.15.0-py3-none-any.whl", hash = "sha256:555c533e9837766119bbccc7a80458c9971d853a6f1da683a2246cd5e53b4645"}, {file = "opentelemetry_sdk-1.15.0.tar.gz", hash = "sha256:98dbffcfeebcbff12c0c974292d6ea603180a145904cf838b1fe4d5c99078425"}, ] -opentelemetry-semantic-conventions = [ + +[package.dependencies] +opentelemetry-api = "1.15.0" +opentelemetry-semantic-conventions = "0.36b0" +setuptools = ">=16.0" +typing-extensions = ">=3.7.4" + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.36b0" +description = "OpenTelemetry Semantic Conventions" +optional = false +python-versions = ">=3.7" +files = [ {file = "opentelemetry_semantic_conventions-0.36b0-py3-none-any.whl", hash = "sha256:adc05635e87b9d3e007c9f530eed487fc3ef2177d02f82f674f28ebf9aff8243"}, {file = "opentelemetry_semantic_conventions-0.36b0.tar.gz", hash = "sha256:829dc221795467d98b773c04096e29be038d77526dc8d6ac76f546fb6279bf01"}, ] -packaging = [ + +[[package]] +name = "packaging" +version = "23.1" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] -pluggy = [ + +[[package]] +name = "pluggy" +version = "1.0.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.6" +files = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, ] -protobuf = [ - {file = "protobuf-4.23.1-cp310-abi3-win32.whl", hash = "sha256:410bcc0a5b279f634d3e16082ce221dfef7c3392fac723500e2e64d1806dd2be"}, - {file = "protobuf-4.23.1-cp310-abi3-win_amd64.whl", hash = "sha256:32e78beda26d7a101fecf15d7a4a792278a0d26a31bc327ff05564a9d68ab8ee"}, - {file 
= "protobuf-4.23.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f9510cac91e764e86acd74e2b7f7bc5e6127a7f3fb646d7c8033cfb84fd1176a"}, - {file = "protobuf-4.23.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:346990f634272caac1f09efbcfbbacb23098b1f606d172534c6fa2d9758bb436"}, - {file = "protobuf-4.23.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3ce113b3f3362493bddc9069c2163a38f240a9ed685ff83e7bcb756b05e1deb0"}, - {file = "protobuf-4.23.1-cp37-cp37m-win32.whl", hash = "sha256:2036a3a1e7fc27f973fa0a7888dce712393af644f4695385f117886abc792e39"}, - {file = "protobuf-4.23.1-cp37-cp37m-win_amd64.whl", hash = "sha256:3b8905eafe4439076e1f58e9d1fa327025fd2777cf90f14083092ae47f77b0aa"}, - {file = "protobuf-4.23.1-cp38-cp38-win32.whl", hash = "sha256:5b9cd6097e6acae48a68cb29b56bc79339be84eca65b486910bb1e7a30e2b7c1"}, - {file = "protobuf-4.23.1-cp38-cp38-win_amd64.whl", hash = "sha256:decf119d54e820f298ee6d89c72d6b289ea240c32c521f00433f9dc420595f38"}, - {file = "protobuf-4.23.1-cp39-cp39-win32.whl", hash = "sha256:91fac0753c3c4951fbb98a93271c43cc7cf3b93cf67747b3e600bb1e5cc14d61"}, - {file = "protobuf-4.23.1-cp39-cp39-win_amd64.whl", hash = "sha256:ac50be82491369a9ec3710565777e4da87c6d2e20404e0abb1f3a8f10ffd20f0"}, - {file = "protobuf-4.23.1-py3-none-any.whl", hash = "sha256:65f0ac96ef67d7dd09b19a46aad81a851b6f85f89725577f16de38f2d68ad477"}, - {file = "protobuf-4.23.1.tar.gz", hash = "sha256:95789b569418a3e32a53f43d7763be3d490a831e9c08042539462b6d972c2d7e"}, + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "protobuf" +version = "4.23.2" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "protobuf-4.23.2-cp310-abi3-win32.whl", hash = "sha256:384dd44cb4c43f2ccddd3645389a23ae61aeb8cfa15ca3a0f60e7c3ea09b28b3"}, + {file = "protobuf-4.23.2-cp310-abi3-win_amd64.whl", hash = "sha256:09310bce43353b46d73ba7e3bca78273b9bc50349509b9698e64d288c6372c2a"}, + {file = "protobuf-4.23.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b2cfab63a230b39ae603834718db74ac11e52bccaaf19bf20f5cce1a84cf76df"}, + {file = "protobuf-4.23.2-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:c52cfcbfba8eb791255edd675c1fe6056f723bf832fa67f0442218f8817c076e"}, + {file = "protobuf-4.23.2-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:86df87016d290143c7ce3be3ad52d055714ebaebb57cc659c387e76cfacd81aa"}, + {file = "protobuf-4.23.2-cp37-cp37m-win32.whl", hash = "sha256:281342ea5eb631c86697e1e048cb7e73b8a4e85f3299a128c116f05f5c668f8f"}, + {file = "protobuf-4.23.2-cp37-cp37m-win_amd64.whl", hash = "sha256:ce744938406de1e64b91410f473736e815f28c3b71201302612a68bf01517fea"}, + {file = "protobuf-4.23.2-cp38-cp38-win32.whl", hash = "sha256:6c081863c379bb1741be8f8193e893511312b1d7329b4a75445d1ea9955be69e"}, + {file = "protobuf-4.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:25e3370eda26469b58b602e29dff069cfaae8eaa0ef4550039cc5ef8dc004511"}, + {file = "protobuf-4.23.2-cp39-cp39-win32.whl", hash = "sha256:efabbbbac1ab519a514579ba9ec52f006c28ae19d97915951f69fa70da2c9e91"}, + {file = "protobuf-4.23.2-cp39-cp39-win_amd64.whl", hash = "sha256:54a533b971288af3b9926e53850c7eb186886c0c84e61daa8444385a4720297f"}, + {file = "protobuf-4.23.2-py3-none-any.whl", hash = "sha256:8da6070310d634c99c0db7df48f10da495cc283fd9e9234877f0cd182d43ab7f"}, + {file = "protobuf-4.23.2.tar.gz", hash = "sha256:20874e7ca4436f683b64ebdbee2129a5a2c301579a67d1a7dda2cdf62fb7f5f7"}, ] -psutil = [ + +[[package]] +name = "psutil" +version = "5.9.5" +description 
= "Cross-platform lib for process and system monitoring in Python." +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, @@ -1165,11 +903,39 @@ psutil = [ {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, ] -pytest = [ - {file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"}, - {file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"}, + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + +[[package]] +name = "pytest" +version = "7.3.2" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.3.2-py3-none-any.whl", hash = "sha256:cdcbd012c9312258922f8cd3f1b62a6580fdced17db6014896053d47cddf9295"}, + {file = "pytest-7.3.2.tar.gz", hash = "sha256:ee990a3cc55ba808b80795a79944756f315c67c12b56abd3ac993a7b8c17030b"}, ] -PyYAML = [ + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pyyaml" +version = "6.0" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, @@ -1211,11 +977,132 @@ PyYAML = [ {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] -requests = [ + +[[package]] +name = "regex" +version = "2023.6.3" +description = "Alternative regular expression module, to replace re." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "regex-2023.6.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:824bf3ac11001849aec3fa1d69abcb67aac3e150a933963fb12bda5151fe1bfd"}, + {file = "regex-2023.6.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:05ed27acdf4465c95826962528f9e8d41dbf9b1aa8531a387dee6ed215a3e9ef"}, + {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b49c764f88a79160fa64f9a7b425620e87c9f46095ef9c9920542ab2495c8bc"}, + {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e3f1316c2293e5469f8f09dc2d76efb6c3982d3da91ba95061a7e69489a14ef"}, + {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43e1dd9d12df9004246bacb79a0e5886b3b6071b32e41f83b0acbf293f820ee8"}, + {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4959e8bcbfda5146477d21c3a8ad81b185cd252f3d0d6e4724a5ef11c012fb06"}, + {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af4dd387354dc83a3bff67127a124c21116feb0d2ef536805c454721c5d7993d"}, + {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2239d95d8e243658b8dbb36b12bd10c33ad6e6933a54d36ff053713f129aa536"}, + {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:890e5a11c97cf0d0c550eb661b937a1e45431ffa79803b942a057c4fb12a2da2"}, + {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a8105e9af3b029f243ab11ad47c19b566482c150c754e4c717900a798806b222"}, + {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:25be746a8ec7bc7b082783216de8e9473803706723b3f6bef34b3d0ed03d57e2"}, + {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:3676f1dd082be28b1266c93f618ee07741b704ab7b68501a173ce7d8d0d0ca18"}, + {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:10cb847aeb1728412c666ab2e2000ba6f174f25b2bdc7292e7dd71b16db07568"}, + {file = "regex-2023.6.3-cp310-cp310-win32.whl", hash = "sha256:dbbbfce33cd98f97f6bffb17801b0576e653f4fdb1d399b2ea89638bc8d08ae1"}, + {file = "regex-2023.6.3-cp310-cp310-win_amd64.whl", hash = "sha256:c5f8037000eb21e4823aa485149f2299eb589f8d1fe4b448036d230c3f4e68e0"}, + {file = "regex-2023.6.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c123f662be8ec5ab4ea72ea300359023a5d1df095b7ead76fedcd8babbedf969"}, + {file = "regex-2023.6.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9edcbad1f8a407e450fbac88d89e04e0b99a08473f666a3f3de0fd292badb6aa"}, + {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcba6dae7de533c876255317c11f3abe4907ba7d9aa15d13e3d9710d4315ec0e"}, + {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29cdd471ebf9e0f2fb3cac165efedc3c58db841d83a518b082077e612d3ee5df"}, + {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12b74fbbf6cbbf9dbce20eb9b5879469e97aeeaa874145517563cca4029db65c"}, + {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c29ca1bd61b16b67be247be87390ef1d1ef702800f91fbd1991f5c4421ebae8"}, + {file = "regex-2023.6.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:d77f09bc4b55d4bf7cc5eba785d87001d6757b7c9eec237fe2af57aba1a071d9"}, + {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ea353ecb6ab5f7e7d2f4372b1e779796ebd7b37352d290096978fea83c4dba0c"}, + {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:10590510780b7541969287512d1b43f19f965c2ece6c9b1c00fc367b29d8dce7"}, + {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e2fbd6236aae3b7f9d514312cdb58e6494ee1c76a9948adde6eba33eb1c4264f"}, + {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:6b2675068c8b56f6bfd5a2bda55b8accbb96c02fd563704732fd1c95e2083461"}, + {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74419d2b50ecb98360cfaa2974da8689cb3b45b9deff0dcf489c0d333bcc1477"}, + {file = "regex-2023.6.3-cp311-cp311-win32.whl", hash = "sha256:fb5ec16523dc573a4b277663a2b5a364e2099902d3944c9419a40ebd56a118f9"}, + {file = "regex-2023.6.3-cp311-cp311-win_amd64.whl", hash = "sha256:09e4a1a6acc39294a36b7338819b10baceb227f7f7dbbea0506d419b5a1dd8af"}, + {file = "regex-2023.6.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0654bca0cdf28a5956c83839162692725159f4cda8d63e0911a2c0dc76166525"}, + {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:463b6a3ceb5ca952e66550a4532cef94c9a0c80dc156c4cc343041951aec1697"}, + {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87b2a5bb5e78ee0ad1de71c664d6eb536dc3947a46a69182a90f4410f5e3f7dd"}, + {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6343c6928282c1f6a9db41f5fd551662310e8774c0e5ebccb767002fcf663ca9"}, + {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6192d5af2ccd2a38877bfef086d35e6659566a335b1492786ff254c168b1693"}, + {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74390d18c75054947e4194019077e243c06fbb62e541d8817a0fa822ea310c14"}, + {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:742e19a90d9bb2f4a6cf2862b8b06dea5e09b96c9f2df1779e53432d7275331f"}, + {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:8abbc5d54ea0ee80e37fef009e3cec5dafd722ed3c829126253d3e22f3846f1e"}, + {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c2b867c17a7a7ae44c43ebbeb1b5ff406b3e8d5b3e14662683e5e66e6cc868d3"}, + {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:d831c2f8ff278179705ca59f7e8524069c1a989e716a1874d6d1aab6119d91d1"}, + {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:ee2d1a9a253b1729bb2de27d41f696ae893507c7db224436abe83ee25356f5c1"}, + {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:61474f0b41fe1a80e8dfa70f70ea1e047387b7cd01c85ec88fa44f5d7561d787"}, + {file = "regex-2023.6.3-cp36-cp36m-win32.whl", hash = "sha256:0b71e63226e393b534105fcbdd8740410dc6b0854c2bfa39bbda6b0d40e59a54"}, + {file = "regex-2023.6.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bbb02fd4462f37060122e5acacec78e49c0fbb303c30dd49c7f493cf21fc5b27"}, + {file = "regex-2023.6.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b862c2b9d5ae38a68b92e215b93f98d4c5e9454fa36aae4450f61dd33ff48487"}, + {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:976d7a304b59ede34ca2921305b57356694f9e6879db323fd90a80f865d355a3"}, + {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:83320a09188e0e6c39088355d423aa9d056ad57a0b6c6381b300ec1a04ec3d16"}, + {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9427a399501818a7564f8c90eced1e9e20709ece36be701f394ada99890ea4b3"}, + {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178bbc1b2ec40eaca599d13c092079bf529679bf0371c602edaa555e10b41c3"}, + {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:837328d14cde912af625d5f303ec29f7e28cdab588674897baafaf505341f2fc"}, + {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d44dc13229905ae96dd2ae2dd7cebf824ee92bc52e8cf03dcead37d926da019"}, + {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d54af539295392611e7efbe94e827311eb8b29668e2b3f4cadcfe6f46df9c777"}, + {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7117d10690c38a622e54c432dfbbd3cbd92f09401d622902c32f6d377e2300ee"}, + {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bb60b503ec8a6e4e3e03a681072fa3a5adcbfa5479fa2d898ae2b4a8e24c4591"}, + {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:65ba8603753cec91c71de423a943ba506363b0e5c3fdb913ef8f9caa14b2c7e0"}, + {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:271f0bdba3c70b58e6f500b205d10a36fb4b58bd06ac61381b68de66442efddb"}, + {file = "regex-2023.6.3-cp37-cp37m-win32.whl", hash = "sha256:9beb322958aaca059f34975b0df135181f2e5d7a13b84d3e0e45434749cb20f7"}, + {file = "regex-2023.6.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fea75c3710d4f31389eed3c02f62d0b66a9da282521075061ce875eb5300cf23"}, + {file = "regex-2023.6.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f56fcb7ff7bf7404becdfc60b1e81a6d0561807051fd2f1860b0d0348156a07"}, + {file = "regex-2023.6.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d2da3abc88711bce7557412310dfa50327d5769a31d1c894b58eb256459dc289"}, + {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a99b50300df5add73d307cf66abea093304a07eb017bce94f01e795090dea87c"}, + {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5708089ed5b40a7b2dc561e0c8baa9535b77771b64a8330b684823cfd5116036"}, + {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:687ea9d78a4b1cf82f8479cab23678aff723108df3edeac098e5b2498879f4a7"}, + {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3850beab9f527f06ccc94b446c864059c57651b3f911fddb8d9d3ec1d1b25d"}, + {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8915cc96abeb8983cea1df3c939e3c6e1ac778340c17732eb63bb96247b91d2"}, + {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:841d6e0e5663d4c7b4c8099c9997be748677d46cbf43f9f471150e560791f7ff"}, + {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9edce5281f965cf135e19840f4d93d55b3835122aa76ccacfd389e880ba4cf82"}, + {file = 
"regex-2023.6.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b956231ebdc45f5b7a2e1f90f66a12be9610ce775fe1b1d50414aac1e9206c06"}, + {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:36efeba71c6539d23c4643be88295ce8c82c88bbd7c65e8a24081d2ca123da3f"}, + {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:cf67ca618b4fd34aee78740bea954d7c69fdda419eb208c2c0c7060bb822d747"}, + {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b4598b1897837067a57b08147a68ac026c1e73b31ef6e36deeeb1fa60b2933c9"}, + {file = "regex-2023.6.3-cp38-cp38-win32.whl", hash = "sha256:f415f802fbcafed5dcc694c13b1292f07fe0befdb94aa8a52905bd115ff41e88"}, + {file = "regex-2023.6.3-cp38-cp38-win_amd64.whl", hash = "sha256:d4f03bb71d482f979bda92e1427f3ec9b220e62a7dd337af0aa6b47bf4498f72"}, + {file = "regex-2023.6.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccf91346b7bd20c790310c4147eee6ed495a54ddb6737162a36ce9dbef3e4751"}, + {file = "regex-2023.6.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b28f5024a3a041009eb4c333863d7894d191215b39576535c6734cd88b0fcb68"}, + {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0bb18053dfcfed432cc3ac632b5e5e5c5b7e55fb3f8090e867bfd9b054dbcbf"}, + {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5bfb3004f2144a084a16ce19ca56b8ac46e6fd0651f54269fc9e230edb5e4a"}, + {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c6b48d0fa50d8f4df3daf451be7f9689c2bde1a52b1225c5926e3f54b6a9ed1"}, + {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:051da80e6eeb6e239e394ae60704d2b566aa6a7aed6f2890a7967307267a5dc6"}, + {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a4c3b7fa4cdaa69268748665a1a6ff70c014d39bb69c50fda64b396c9116cf77"}, + {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:457b6cce21bee41ac292d6753d5e94dcbc5c9e3e3a834da285b0bde7aa4a11e9"}, + {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:aad51907d74fc183033ad796dd4c2e080d1adcc4fd3c0fd4fd499f30c03011cd"}, + {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0385e73da22363778ef2324950e08b689abdf0b108a7d8decb403ad7f5191938"}, + {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c6a57b742133830eec44d9b2290daf5cbe0a2f1d6acee1b3c7b1c7b2f3606df7"}, + {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3e5219bf9e75993d73ab3d25985c857c77e614525fac9ae02b1bebd92f7cecac"}, + {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e5087a3c59eef624a4591ef9eaa6e9a8d8a94c779dade95d27c0bc24650261cd"}, + {file = "regex-2023.6.3-cp39-cp39-win32.whl", hash = "sha256:20326216cc2afe69b6e98528160b225d72f85ab080cbdf0b11528cbbaba2248f"}, + {file = "regex-2023.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:bdff5eab10e59cf26bc479f565e25ed71a7d041d1ded04ccf9aee1d9f208487a"}, + {file = "regex-2023.6.3.tar.gz", hash = "sha256:72d1a25bf36d2050ceb35b517afe13864865268dfb45910e2e17a84be6cbfeb0"}, +] + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." 
+optional = false +python-versions = ">=3.7" +files = [ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, ] -safetensors = [ + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "safetensors" +version = "0.3.1" +description = "Fast and Safe Tensor serialization" +optional = false +python-versions = "*" +files = [ {file = "safetensors-0.3.1-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:2ae9b7dd268b4bae6624729dac86deb82104820e9786429b0583e5168db2f770"}, {file = "safetensors-0.3.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:08c85c1934682f1e2cd904d38433b53cd2a98245a7cc31f5689f9322a2320bbf"}, {file = "safetensors-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba625c7af9e1c5d0d91cb83d2fba97d29ea69d4db2015d9714d24c7f6d488e15"}, @@ -1257,7 +1144,25 @@ safetensors = [ {file = "safetensors-0.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f4f614b8e8161cd8a9ca19c765d176a82b122fa3d3387b77862145bfe9b4e93"}, {file = "safetensors-0.3.1.tar.gz", hash = "sha256:571da56ff8d0bec8ae54923b621cda98d36dcef10feb36fd492c4d0c2cd0e869"}, ] -sentencepiece = [ + +[package.extras] +all = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] +dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] +torch = ["torch (>=1.10)"] + +[[package]] +name = "sentencepiece" +version = "0.1.99" +description = "SentencePiece python wrapper" +optional = false +python-versions = "*" +files = [ {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73"}, {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7"}, {file = "sentencepiece-0.1.99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0"}, @@ -1304,15 +1209,44 @@ sentencepiece = [ {file = "sentencepiece-0.1.99-cp39-cp39-win_amd64.whl", hash = "sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd"}, {file = "sentencepiece-0.1.99.tar.gz", hash = 
"sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, ] -setuptools = [ + +[[package]] +name = "setuptools" +version = "67.8.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.7" +files = [ {file = "setuptools-67.8.0-py3-none-any.whl", hash = "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f"}, {file = "setuptools-67.8.0.tar.gz", hash = "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102"}, ] -sympy = [ + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +optional = true +python-versions = ">=3.8" +files = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, ] -tokenizers = [ + +[package.dependencies] +mpmath = ">=0.19" + +[[package]] +name = "tokenizers" +version = "0.13.3" +description = "Fast and Customizable Tokenizers" +optional = false +python-versions = "*" +files = [ {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, @@ -1354,11 +1288,30 @@ tokenizers = [ {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, ] -tomli = [ + +[package.extras] +dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] -torch = [ + +[[package]] +name = "torch" 
+version = "2.0.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = true +python-versions = ">=3.8.0" +files = [ {file = "torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ced00b3ba471856b993822508f77c98f48a458623596a4c43136158781e306a"}, {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"}, {file = "torch-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7c84e44d9002182edd859f3400deaa7410f5ec948a519cc7ef512c2f9b34d2c4"}, @@ -1380,27 +1333,175 @@ torch = [ {file = "torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c62df99352bd6ee5a5a8d1832452110435d178b5164de450831a3a8cc14dc680"}, {file = "torch-2.0.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:671a2565e3f63b8fe8e42ae3e36ad249fe5e567435ea27b94edaa672a7d0c416"}, ] -tqdm = [ + +[package.dependencies] +filelock = "*" +jinja2 = "*" +networkx = "*" +sympy = "*" +typing-extensions = "*" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] + +[[package]] +name = "tqdm" +version = "4.65.0" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, ] -typer = [ + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["py-make (>=0.1.0)", "twine", "wheel"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "transformers" +version = "4.30.2" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "transformers-4.30.2-py3-none-any.whl", hash = "sha256:c332e3a3097f9ed89ce556b403251235931c00237b8bc2d7adaa19d226c13f1d"}, + {file = "transformers-4.30.2.tar.gz", hash = "sha256:f4a8aac4e1baffab4033f4a345b0d7dc7957d12a4f1ba969afea08205a513045"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.14.1,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.3.1" +tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.20.2)"] +agents = ["Pillow", "accelerate (>=0.20.2)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] +all = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.20.2)", "deepspeed (>=0.8.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter 
(==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.8.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.6.9)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", 
"unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +docs-specific = ["hf-doc-builder"] +fairscale = ["fairscale (>0.3)"] +flax = ["flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8,<=0.1.4)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune]", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] +ray = ["ray[tune]"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf (<=3.20.3)", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] +torch = ["accelerate (>=0.20.2)", "torch (>=1.9,!=1.12.0)"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.14.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.3)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow"] + +[[package]] +name = "typer" +version = "0.6.1" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." 
+optional = false +python-versions = ">=3.6" +files = [ {file = "typer-0.6.1-py3-none-any.whl", hash = "sha256:54b19e5df18654070a82f8c2aa1da456a4ac16a2a83e6dcd9f170e291c56338e"}, {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, ] -typing-extensions = [ - {file = "typing_extensions-4.6.0-py3-none-any.whl", hash = "sha256:6ad00b63f849b7dcc313b70b6b304ed67b2b2963b3098a33efe18056b1a9a223"}, - {file = "typing_extensions-4.6.0.tar.gz", hash = "sha256:ff6b238610c747e44c268aa4bb23c8c735d665a63726df3f9431ce707f2aa768"}, + +[package.dependencies] +click = ">=7.1.1,<9.0.0" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + +[[package]] +name = "typing-extensions" +version = "4.6.3" +description = "Backported and Experimental Type Hints for Python 3.7+" +optional = false +python-versions = ">=3.7" +files = [ + {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"}, + {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, ] -urllib3 = [ - {file = "urllib3-2.0.2-py3-none-any.whl", hash = "sha256:d055c2f9d38dc53c808f6fdc8eab7360b6fdbbde02340ed25cfbcd817c62469e"}, - {file = "urllib3-2.0.2.tar.gz", hash = "sha256:61717a1095d7e155cdb737ac7bb2f4324a858a1e2e6466f6d03ff630ca68d3cc"}, + +[[package]] +name = "urllib3" +version = "2.0.3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.7" +files = [ + {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"}, + {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"}, ] -win32-setctime = [ + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, ] -wrapt = [ + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + +[[package]] +name = "wrapt" +version = "1.15.0" +description = "Module for decorators, wrappers and monkey patching." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +files = [ {file = "wrapt-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ca1cccf838cd28d5a0883b342474c630ac48cac5df0ee6eacc9c7290f76b11c1"}, {file = "wrapt-1.15.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:e826aadda3cae59295b95343db8f3d965fb31059da7de01ee8d1c40a60398b29"}, {file = "wrapt-1.15.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5fc8e02f5984a55d2c653f5fea93531e9836abbd84342c1d1e17abc4a15084c2"}, @@ -1477,3 +1578,12 @@ wrapt = [ {file = "wrapt-1.15.0-py3-none-any.whl", hash = "sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640"}, {file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"}, ] + +[extras] +accelerate = ["accelerate"] +bnb = ["bitsandbytes"] + +[metadata] +lock-version = "2.0" +python-versions = "^3.9" +content-hash = "54ecacb32d699cb1298c237c4661c1b707f119cf2c27bd54bad7a1ea2ffb8b10" diff --git a/server/requirements.txt b/server/requirements.txt index e8cee52b..a9bd441c 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,21 +1,21 @@ backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0" -bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0" certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0" charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0" click==8.1.3 ; python_version >= "3.9" and python_version < "4.0" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows" -deprecated==1.2.13 ; python_version >= "3.9" and python_version < "4.0" -filelock==3.12.0 ; python_version >= "3.9" and python_version < "4.0" -fsspec==2023.5.0 ; python_version >= "3.9" and python_version < "4.0" -googleapis-common-protos==1.59.0 ; python_version >= "3.9" and python_version < "4.0" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" +filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" +fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0" +googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0" grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0" -grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0" -grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0" +grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0" +grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0" +grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0" hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0" huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0" -idna==3.4 ; python_version >= "3.9" and python_version < "4" +idna==3.4 ; python_version >= "3.9" and python_version < "4.0" loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0" +numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and 
python_version < "4.0" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0" @@ -26,17 +26,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0" packaging==23.1 ; python_version >= "3.9" and python_version < "4.0" -protobuf==4.23.1 ; python_version >= "3.9" and python_version < "4.0" +protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0" pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0" +regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0" requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0" safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" -transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" +transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0" typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" -typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0" -urllib3==2.0.2 ; python_version >= "3.9" and python_version < "4.0" +typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0" +urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0" diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index c0e6c2dc..aeb1f13b 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -151,5 +151,37 @@ def download_weights( utils.convert_files(local_pt_files, local_st_files) +@app.command() +def quantize( + model_id: str, + output_dir: str, + revision: Optional[str] = None, + logger_level: str = "INFO", + json_output: bool = False, + trust_remote_code: bool = False, + upload_to_model_id: Optional[str] = None, + percdamp: float = 0.01, + act_order: bool = False, +): + download_weights( + model_id=model_id, + revision=revision, + logger_level=logger_level, + json_output=json_output, + ) + from text_generation_server.utils.gptq.quantize import quantize + + quantize( + model_id=model_id, + bits=4, + groupsize=128, + output_dir=output_dir, + trust_remote_code=trust_remote_code, + upload_to_model_id=upload_to_model_id, + percdamp=percdamp, + act_order=act_order, + ) + + if __name__ == "__main__": app() diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3fdc23b2..2abde685 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -246,6 +246,10 @@ def get_model( if sharded: raise ValueError("sharded is not supported for AutoModel") + if quantize == "gptq": + raise ValueError( + "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( diff --git 
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 3fdc23b2..2abde685 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -246,6 +246,10 @@ def get_model(
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
+    if quantize == "gptq":
+        raise ValueError(
+            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+        )
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 3586b85a..9c1020a5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -42,7 +42,8 @@ from text_generation_server.utils.layers import (
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
@@ -57,19 +58,21 @@ def load_row(config, prefix: str, weights, bias: bool):
 
 
 def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
-
-    weight = (
-        weight.view(
-            num_heads,
-            3,
-            head_size,
-            hidden_size,
+    weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
+    if isinstance(weight, torch.Tensor):
+        # Only on non quantized versions
+        weight = (
+            weight.view(
+                num_heads,
+                3,
+                head_size,
+                hidden_size,
+            )
+            .permute(1, 0, 2, 3)
+            .reshape(-1, hidden_size)
         )
-        .permute(1, 0, 2, 3)
-        .reshape(-1, hidden_size)
-    )
+
+    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
 
     linear = get_linear(weight, bias, config.quantize)
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 4a9063eb..fa35c359 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -21,7 +21,8 @@ from text_generation_server.utils.layers import (
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index b01d752a..4eb0034d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+
+    if config.quantize == "gptq":
+        return _load_multi_mqa_gptq(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+    else:
+        return _load_multi_mqa(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+
+
+def _load_multi_mqa_gptq(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        scales = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert 2 * head_size % (32 // 4) == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    else:
+        raise NotImplementedError("GPTQ loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
@@ -92,7 +167,9 @@ def load_col(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )
 
     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -105,7 +182,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
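The slicing arithmetic in `_load_multi_mqa_gptq` follows the multi-query layout of the fused `c_attn` tensor: the query columns are sharded across ranks, while the single shared key/value block at the end is kept whole on every rank. A small sketch of the same arithmetic on a plain tensor, with made-up dimensions:

    import torch

    head_size, world_size = 4, 2
    c_attn = torch.arange(24.0).repeat(3, 1)   # (3, 16 query cols + 2 * head_size KV cols)

    block_size = (c_attn.shape[1] - 2 * head_size) // world_size
    assert (c_attn.shape[1] - 2 * head_size) % world_size == 0

    for rank in range(world_size):
        start, stop = rank * block_size, (rank + 1) * block_size
        q_shard = c_attn[:, start:stop]        # this rank's slice of the query columns
        kv_full = c_attn[:, -2 * head_size:]   # shared KV block, replicated on every rank
        shard = torch.cat([q_shard, kv_full], dim=1)
        assert shard.shape == (3, block_size + 2 * head_size)
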
only available on GPU") - tokenizer = LlamaTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) + try: + tokenizer = LlamaTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + except Exception: + tokenizer = LlamaTokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code diff --git a/server/text_generation_server/utils/gptq/custom_autotune.py b/server/text_generation_server/utils/gptq/custom_autotune.py new file mode 100644 index 00000000..17dff02e --- /dev/null +++ b/server/text_generation_server/utils/gptq/custom_autotune.py @@ -0,0 +1,261 @@ +# https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + """ + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = ( + prune_configs_by["perf_model"], + prune_configs_by["top_k"], + ) + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+ ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current, + ) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench( + kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40 + ) + except triton.compiler.OutOfResources: + return (float("inf"), float("inf"), float("inf")) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune( + configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. 
+ To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + prune_configs_by, + nearest_power_of_two, + ) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] + + if ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) in used: + continue + + used.add( + ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) + ) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py new file mode 100644 index 00000000..54fa2014 --- /dev/null +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -0,0 +1,359 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import triton + import triton.language as tl + from . 
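Relative to stock Triton, the autotuner's main twists are the 40-rep benchmark and the `nearest_power_of_two` option, which buckets the `M`/`N`/`K` cache key so that nearby problem sizes reuse one tuning result. The rounding it applies is equivalent to this small sketch:

    import math

    def bucket(x: int) -> int:
        # Same rounding as the autotuner's cache key: nearest power of two.
        return 2 ** int(math.log2(x) + 0.5)

    assert [bucket(x) for x in (24, 32, 47, 48)] == [32, 32, 64, 64]
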
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
new file mode 100644
index 00000000..54fa2014
--- /dev/null
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -0,0 +1,359 @@
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+    import triton
+    import triton.language as tl
+    from . import custom_autotune
+
+    # code based on https://github.com/fpgaminer/GPTQ-triton
+    @custom_autotune.autotune(
+        configs=[
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 256,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 128,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 128,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 32,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 128,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=2,
+                num_warps=8,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 64,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=3,
+                num_warps=8,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 32,
+                    "BLOCK_SIZE_N": 32,
+                    "BLOCK_SIZE_K": 128,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=2,
+                num_warps=4,
+            ),
+        ],
+        key=["M", "N", "K"],
+        nearest_power_of_two=True,
+        prune_configs_by={
+            "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+            "perf_model": None,
+            "top_k": None,
+        },
+    )
+    @triton.jit
+    def matmul_248_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        scales_ptr,
+        zeros_ptr,
+        g_ptr,
+        M,
+        N,
+        K,
+        bits,
+        maxq,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_scales,
+        stride_zeros,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        GROUP_SIZE_M: tl.constexpr,
+    ):
+        """
+        Compute the matrix multiplication C = A x B.
+        A is of shape (M, K) float16
+        B is of shape (K//8, N) int32
+        C is of shape (M, N) float16
+        scales is of shape (G, N) float16
+        zeros is of shape (G, N//8) int32
+        g_ptr is of shape (K) int32
+        """
+        infeatures_per_bits = 32 // bits
+
+        pid = tl.program_id(axis=0)
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (
+            offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
+        )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+        a_mask = offs_am[:, None] < M
+        # b_ptrs is set up such that it repeats elements along the K axis 8 times
+        b_ptrs = b_ptr + (
+            (offs_k[:, None] // infeatures_per_bits) * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+        g_ptrs = g_ptr + offs_k
+        # shifter is used to extract the N bits of each element in the 32-bit word from B
+        scales_ptrs = scales_ptr + offs_bn[None, :]
+        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infeatures_per_bits)
+
+        shifter = (offs_k % infeatures_per_bits) * bits
+        zeros_shifter = (offs_bn % infeatures_per_bits) * bits
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+        for k in range(0, num_pid_k):
+            g_idx = tl.load(g_ptrs)
+
+            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+            scales = tl.load(
+                scales_ptrs + g_idx[:, None] * stride_scales
+            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+            zeros = tl.load(
+                zeros_ptrs + g_idx[:, None] * stride_zeros
+            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+            zeros = (zeros >> zeros_shifter[None, :]) & maxq
+            zeros = zeros + 1
+
+            a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+            # Now we need to unpack b (which is N-bit values) into 32-bit values
+            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
+            b = (b - zeros) * scales  # Scale and shift
+
+            accumulator += tl.dot(a, b)
+            a_ptrs += BLOCK_SIZE_K
+            b_ptrs += (BLOCK_SIZE_K // infeatures_per_bits) * stride_bk
+            g_ptrs += BLOCK_SIZE_K
+
+        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+        tl.store(c_ptrs, accumulator, mask=c_mask)
+
+except ImportError:
+    print("triton not installed.")
+
+
+def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+    with torch.cuda.device(input.device):
+        output = torch.empty(
+            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
+        )
+        grid = lambda META: (
+            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
+            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
+        )
+        matmul_248_kernel[grid](
+            input,
+            qweight,
+            output,
+            scales,
+            qzeros,
+            g_idx,
+            input.shape[0],
+            qweight.shape[1],
+            input.shape[1],
+            bits,
+            maxq,
+            input.stride(0),
+            input.stride(1),
+            qweight.stride(0),
+            qweight.stride(1),
+            output.stride(0),
+            output.stride(1),
+            scales.stride(0),
+            qzeros.stride(0),
+        )
+        return output
+
+
+class QuantLinearFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
+    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+        return output
+
+
+class QuantLinear(nn.Module):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+        super().__init__()
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
+        if bias is not None:
+            self.register_buffer("bias", bias)
+        else:
+            self.bias = None
+        if bits not in [2, 4, 8]:
+            raise NotImplementedError("Only 2,4,8 bits are supported.")
+        self.bits = bits
+        self.maxq = 2**self.bits - 1
+        self.groupsize = groupsize
+
+        self.outfeatures = qweight.shape[1]
+        # qweight packs 32 // bits input features per int32 along dim 0
+        self.infeatures = qweight.shape[0] * 32 // bits
+
+    @classmethod
+    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
+        if bits not in [2, 4, 8]:
+            raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
+        qzeros = torch.zeros(
+            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
+            dtype=torch.int32,
+        )
+        scales = torch.zeros(
+            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
+        )
+        g_idx = torch.tensor(
+            [i // groupsize for i in range(infeatures)], dtype=torch.int32
+        )
+        if bias:
+            bias = torch.zeros((outfeatures), dtype=torch.float16)
+        else:
+            bias = None
+        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+
+    def pack(self, linear, scales, zeros, g_idx=None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().half()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+                    / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros(
+            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+        )
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros(
+            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
+        )
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures,)
+        out = QuantLinearFunction.apply(
+            x.reshape(-1, x.shape[-1]),
+            self.qweight,
+            self.scales,
+            self.qzeros,
+            self.g_idx,
+            self.bits,
+            self.maxq,
+        )
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
diff --git a/server/text_generation_server/utils/gptq/quantize.py
b/server/text_generation_server/utils/gptq/quantize.py new file mode 100644 index 00000000..5a4ed8da --- /dev/null +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -0,0 +1,866 @@ +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +import math +import json +import os + +from texttable import Texttable +from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer +import transformers +from huggingface_hub import HfApi +import numpy as np +import torch +from text_generation_server.utils.gptq.quant_linear import QuantLinear +from loguru import logger +from typing import Optional + +DEV = torch.device("cuda:0") + + +class Quantizer(nn.Module): + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) + + def configure( + self, + bits, + perchannel=False, + sym=True, + mse=False, + norm=2.4, + grid=100, + maxshrink=0.8, + trits=False, + ): + + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + self.scale = torch.zeros_like(self.scale) + + def _quantize(self, x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = self._quantize( + x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq + ) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = 
self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return self._quantize(x, self.scale, self.zero, self.maxq) + + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +class GPTQ: + def __init__(self, layer, observe=False): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.quantizer = Quantizer() + self.observe = observe + + def add_batch(self, inp, out): + # Hessian H = 2 X XT + λ I + if self.observe: + self.inp1 = inp + self.out1 = out + else: + self.inp1 = None + self.out1 = None + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def print_loss(self, name, q_weight, weight_error, timecost): + table = Texttable() + length = 28 + name = ( + (name + " " * (length - len(name))) + if len(name) <= length + else name[:length] + ) + + table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) + + # assign weight + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) + + if self.inp1 is not None: + # quantize input to int8 + quantizer = Quantizer() + quantizer.configure(8, perchannel=False, sym=True, mse=False) + quantizer.find_params(self.inp1) + q_in = quantizer.quantize(self.inp1).type(torch.float16) + q_out = self.layer(q_in) + + # get kinds of SNR + q_SNR = torch_snr_error(q_out, self.out1).item() + fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() + else: + q_SNR = "-" + fp_SNR = "-" + + table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) + print(table.draw().split("\n")[-2]) + + def fasterquant( + self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name="" + ): + self.layer.to(self.dev) + + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + if not self.observe: + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if act_order: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = 
torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + try: + H = torch.linalg.cholesky(H, upper=True) + except Exception: + # Addition because Falcon fails on h_to_4h + H = torch.linalg.cholesky( + H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True + ) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + groupsize)], weight=True + ) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = self.quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + error = torch.sum(Losses).item() + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if act_order: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + + self.print_loss( + name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) + ) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, g_idx, error + + def free(self): + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() + + +def get_wikitext2(nsamples, seed, seqlen, model_id): + from datasets import load_dataset + + traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model_id): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + 
trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model_id): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + use_auth_token=False, + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + use_auth_token=False, + ) + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]["text"], return_tensors="pt") + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(nsamples, seed, seqlen, model_id): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(nsamples, seed, seqlen, model_id): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + 
trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") + valenc = valenc.input_ids[:, : (256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""): + if "wikitext2" in name: + return get_wikitext2(nsamples, seed, seqlen, model_id) + if "ptb" in name: + if "new" in name: + return get_ptb_new(nsamples, seed, seqlen, model_id) + return get_ptb(nsamples, seed, seqlen, model_id) + if "c4" in name: + if "new" in name: + return get_c4_new(nsamples, seed, seqlen, model_id) + return get_c4(nsamples, seed, seqlen, model_id) + + +def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""): + # Skip last lm_head linear + # Need isintance Falcon is inheriting Linear. + if isinstance(module, layers) and "lm_head" not in name: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) + return res + + +@torch.no_grad() +def sequential( + model, + dataloader, + dev, + nsamples, + bits, + groupsize, + percdamp=0.01, + sym: bool = False, + act_order: bool = False, +): + print("Starting ...") + + use_cache = model.config.use_cache + model.config.use_cache = False + try: + layers = model.model.layers + prefix = "model.layers" + except Exception: + layers = model.transformer.h + prefix = "transformer.h" + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + + cache = {"i": 0} + extra = {} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache["i"]] = inp + cache["i"] += 1 + extra.update(kwargs.copy()) + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0]) + except ValueError: + pass + layers[0] = layers[0].module + + # layers[0] = layers[0].cpu() + # model.model.embed_tokens = model.model.embed_tokens.cpu() + # model.model.norm = model.model.norm.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + + extra = { + k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items() + } + + print("Ready.") + + quantizers = {} + for i in range(len(layers)): + print(f"Quantizing layer {i+1}/{len(layers)}..") + print("+------------------+--------------+------------+-----------+-------+") + print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") + print("+==================+==============+============+===========+=======+") + + from accelerate.hooks import remove_hook_from_submodules + + layer = layers[i].to(dev) + remove_hook_from_submodules(layer) + full = find_layers(layer) + sequential = [list(full.keys())] + + for names in sequential: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer.configure( + bits, perchannel=True, sym=sym, mse=False + ) + + def add_batch(name): + def tmp(_, inp, out): + 
gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(nsamples): + + outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] + for h in handles: + h.remove() + + for name in subset: + scale, zero, g_idx, error = gptq[name].fasterquant( + percdamp=percdamp, + groupsize=groupsize, + act_order=act_order, + name=name, + ) + quantizers[f"{prefix}.{i}.{name}"] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + bits, + groupsize, + ) + + gptq[name].free() + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + print("+------------------+--------------+------------+-----------+-------+") + print("\n") + + model.config.use_cache = use_cache + + return quantizers + + +def make_quant_linear(module, names, bits, groupsize, name=""): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + "." + attr if name != "" else attr + if name1 in names: + delattr(module, attr) + setattr( + module, + attr, + QuantLinear.new( + bits, + groupsize, + tmp.in_features, + tmp.out_features, + tmp.bias is not None, + ), + ) + for name1, child in module.named_children(): + make_quant_linear( + child, names, bits, groupsize, name + "." + name1 if name != "" else name1 + ) + + +# TODO: perform packing on GPU +def pack(model, quantizers, bits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant_linear(model, quantizers, bits, groupsize) + qlayers = find_layers(model, (QuantLinear,)) + print("Packing ...") + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print("Done.") + return model + + +def quantize( + model_id: str, + bits: int, + groupsize: int, + output_dir: str, + trust_remote_code: bool, + upload_to_model_id: Optional[str], + percdamp: float, + act_order: bool, +): + print("loading model") + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + device_map="balanced_low_0", + trust_remote_code=trust_remote_code, + ) + print("LOADED model") + model.seqlen = 2048 + + dataset = "wikitext2" + nsamples = 128 + seed = None + + dataloader, testloader = get_loaders( + dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen + ) + + tick = time.time() + quantizers = sequential( + model, + dataloader, + DEV, + nsamples, + bits, + groupsize, + percdamp=percdamp, + act_order=act_order, + ) + print(time.time() - tick) + + pack(model, quantizers, bits, groupsize) + from safetensors.torch import save_file + from transformers.modeling_utils import shard_checkpoint + + state_dict = model.state_dict() + state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} + state_dict["gptq_bits"] = torch.LongTensor([bits]) + state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + + max_shard_size = "10GB" + shards, index = shard_checkpoint( + state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + ) + os.makedirs(output_dir, exist_ok=True) + for shard_file, shard in shards.items(): + save_file( + shard, + os.path.join(output_dir, shard_file), + metadata={ + "format": "pt", + "quantized": "gptq", + "origin": "text-generation-inference", + }, + ) + if 
index is None: + path_to_weights = os.path.join(output_dir, "model.safetensors") + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = "model.safetensors.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + config.save_pretrained(output_dir) + logger.info("Saved config") + logger.info("Saving tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + tokenizer.save_pretrained(output_dir) + logger.info("Saved tokenizer") + + if upload_to_model_id: + + api = HfApi() + + api.upload_folder( + folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model" + ) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 93865d52..a2b0c739 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -15,6 +15,8 @@ except ImportError: from accelerate import init_empty_weights +from text_generation_server.utils.gptq.quant_linear import QuantLinear + # Monkey patching @classmethod @@ -129,7 +131,22 @@ def get_linear(weight, bias, quantize): if bias is not None: linear.bias = nn.Parameter(bias) elif quantize == "gptq": - raise NotImplementedError("Soon") + try: + qweight, qzeros, scales, g_idx, bits, groupsize = weight + except Exception: + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." 
+            )
+
+        linear = QuantLinear(
+            qweight,
+            qzeros,
+            scales,
+            g_idx,
+            bias,
+            bits,
+            groupsize,
+        )
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -152,8 +169,14 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+
+        # GPTQ doesn't quantize heads (nor embeddings)
+        if config.quantize == "gptq":
+            quantize = None
+        else:
+            quantize = config.quantize
         return TensorParallelHead(
-            get_linear(weight, bias=None, quantize=config.quantize),
+            get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
         )
@@ -196,24 +219,21 @@ class TensorParallelHead(SuperLayer):
 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-        if bias:
-            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
-        else:
-            bias = None
-        return cls(get_linear(weight, bias, config.quantize))
+        return cls.load_multi(config, [prefix], weights, bias, dim=0)
 
     @classmethod
     def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        weight = torch.cat(w, dim=dim)
+        weight = weights.get_multi_weights_col(
+            prefixes, quantize=config.quantize, dim=dim
+        )
 
         if bias:
             b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
-            bias = torch.cat(b, dim=0)
+            bias = torch.cat(b, dim=dim)
         else:
             bias = None
-        return cls(get_linear(weight, bias, config.quantize))
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
 
 
 class TensorParallelRowLinear(SuperLayer):
@@ -223,7 +243,8 @@ class TensorParallelRowLinear(SuperLayer):
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
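With this change, `get_linear` accepts either a plain tensor or, for `gptq`, the 6-tuple `(qweight, qzeros, scales, g_idx, bits, groupsize)` that the `Weights` helpers below assemble (reading `bits` and `groupsize` from the `gptq_bits`/`gptq_groupsize` tensors stored in the checkpoint). For a 4-bit layer with groupsize 128, the shapes work out as in this sketch (sizes are illustrative):

    import torch
    from text_generation_server.utils.gptq.quant_linear import QuantLinear

    bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096
    qweight = torch.zeros(infeatures // 32 * bits, outfeatures, dtype=torch.int32)
    qzeros = torch.zeros(infeatures // groupsize, outfeatures // 32 * bits, dtype=torch.int32)
    scales = torch.ones(infeatures // groupsize, outfeatures, dtype=torch.float16)
    g_idx = torch.arange(infeatures, dtype=torch.int32) // groupsize

    # The same tuple layout get_multi_weights_col / get_multi_weights_row return for "gptq":
    weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
    linear = QuantLinear(qweight, qzeros, scales, g_idx, None, bits, groupsize)
    assert linear.infeatures == infeatures and linear.outfeatures == outfeatures
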
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 88347a6a..9d371834 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
+import torch
 
 
 class Weights:
@@ -54,7 +55,10 @@ class Weights:
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
-        tensor = tensor.to(dtype=self.dtype)
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype not in [torch.int32, torch.int64]:
+            tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
 
@@ -80,6 +84,49 @@ class Weights:
             tensor = slice_[:, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
-        tensor = tensor.to(dtype=self.dtype)
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
+
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+        if quantize == "gptq":
+            try:
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+            except RuntimeError:
+                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+
+            bits = self.get_tensor("gptq_bits").item()
+            groupsize = self.get_tensor("gptq_groupsize").item()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        else:
+            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+            weight = torch.cat(w, dim=dim)
+        return weight
+
+    def get_multi_weights_row(self, prefix: str, quantize: str):
+        if quantize == "gptq":
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            bits = self.get_tensor("gptq_bits").item()
+            groupsize = self.get_tensor("gptq_groupsize").item()
+
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        else:
+            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+        return weight
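
The packed layout these helpers hand back follows `QuantLinear.pack`: each int32 holds `32 // bits` consecutive values, written with the same shift the Triton kernel later reverses; at matmul time the kernel then dequantizes with `w = (q - (zero + 1)) * scale`. A standalone round trip of that packing scheme for 4-bit values (no Triton or GPU needed):

    bits = 4
    maxq = 2**bits - 1                 # the kernel's mask
    values = list(range(8))            # eight 4-bit values -> one packed int32

    packed = 0
    for j, v in enumerate(values):
        packed |= v << (bits * j)      # same shift as QuantLinear.pack
    assert packed == 0x76543210

    unpacked = [(packed >> (bits * j)) & maxq for j in range(8)]
    assert unpacked == values          # same shift-and-mask as matmul_248_kernel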