diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index fd612ee8..fead2cec 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -194,7 +194,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v4
         with:
           python-version: 3.9
       - name: Tailscale
@@ -213,6 +213,7 @@ jobs:
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
           make integration-tests
 
   stop-runner:
diff --git a/Makefile b/Makefile
index 0d4a2f73..29c318fa 100644
--- a/Makefile
+++ b/Makefile
@@ -25,7 +25,7 @@ rust-tests: install-router install-launcher
 	cargo test
 
 integration-tests: install-integration-tests
-	pytest -s -vv integration-tests
+	pytest -s -vv -m "not private" integration-tests
 
 update-integration-tests: install-integration-tests
 	pytest -s -vv --snapshot-update integration-tests
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 68528a2f..521c9a0a 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -59,7 +59,7 @@ def launcher(event_loop):
         process.terminate()
         process.wait(60)
 
-        launcher_output = process.stdout.read1().decode("utf-8")
+        launcher_output = process.stdout.read().decode("utf-8")
         print(launcher_output)
 
         process.stdout.close()
diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py
index 899a26bf..e1e23cd7 100644
--- a/integration-tests/models/test_flash_llama.py
+++ b/integration-tests/models/test_flash_llama.py
@@ -10,6 +10,7 @@ def flash_llama(launcher):
 
 
 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_llama(flash_llama, snapshot):
     await health_check(flash_llama, 120)
 
@@ -20,6 +21,7 @@ async def test_flash_llama(flash_llama, snapshot):
 
 
 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_llama_all_params(flash_llama, snapshot):
     await health_check(flash_llama, 120)
 
@@ -43,6 +45,7 @@ async def test_flash_llama_all_params(flash_llama, snapshot):
 
 
 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_llama_load(flash_llama, generate_load, snapshot):
     await health_check(flash_llama, 120)
 
diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py
index f5d2a47a..52e55296 100644
--- a/integration-tests/models/test_flash_starcoder.py
+++ b/integration-tests/models/test_flash_starcoder.py
@@ -10,6 +10,7 @@ def flash_starcoder(launcher):
 
 
 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, snapshot):
     await health_check(flash_starcoder, 240)
 
@@ -20,6 +21,7 @@ async def test_flash_starcoder(flash_starcoder, snapshot):
 
 
 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_starcoder_default_params(flash_starcoder, snapshot):
     await health_check(flash_starcoder, 240)
 
@@ -32,6 +34,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, snapshot):
 
 
 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
     await health_check(flash_starcoder, 240)
 
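Note on the `private` marker used above: the Makefile target now deselects these tests with `-m "not private"`, which assumes the marker is registered somewhere in the test suite's configuration. A minimal sketch of what that registration (plus an optional token-based skip) could look like in `integration-tests/conftest.py`; the hook bodies below are illustrative assumptions, not part of this diff:

```python
import os

import pytest


def pytest_configure(config):
    # Register the marker so `pytest -m "not private"` can deselect gated-model
    # tests without emitting PytestUnknownMarkWarning.
    config.addinivalue_line(
        "markers", "private: tests that need HUGGING_FACE_HUB_TOKEN to reach gated models"
    )


def pytest_collection_modifyitems(config, items):
    # Optional safety net: when no token is available, skip private tests even if
    # the -m filter was not passed on the command line.
    if os.getenv("HUGGING_FACE_HUB_TOKEN"):
        return
    skip_private = pytest.mark.skip(reason="HUGGING_FACE_HUB_TOKEN is not set")
    for item in items:
        if "private" in item.keywords:
            item.add_marker(skip_private)
```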
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index ed959291..9029e954 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -129,7 +129,7 @@ class BLOOMSharded(BLOOM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     full_name = f"transformer.{name}"
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 481fe8a6..54670b79 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -21,14 +21,13 @@
 import torch
 import torch.distributed
 
-from torch.nn import functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
+import dropout_layer_norm
 
 from text_generation_server.utils.layers import (
     FastLinear,
@@ -331,15 +330,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
 
-    def post_load_weights(self, load_in_8bit: bool = False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
-            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
-            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
-            layer.mlp.down_proj.prepare_weights(load_in_8bit)
+            layer.self_attn.query_key_value.prepare_weights(quantize)
+            layer.self_attn.o_proj.prepare_weights(quantize)
+            layer.mlp.gate_up_proj.prepare_weights(quantize)
+            layer.mlp.down_proj.prepare_weights(quantize)
 
     def forward(
         self,
@@ -428,8 +427,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.model.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.model.post_load_weights(quantize)
         self.lm_head.prepare_weights()
 
     def forward(
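The modeling changes above (and the matching ones below for NeoX and Santacoder) swap the `load_in_8bit: bool` argument for `quantize: Optional[str]` and forward it to each layer's `prepare_weights`. A toy illustration of that calling convention, using stand-in modules rather than the actual Flash* classes:

```python
from typing import Optional

from torch import nn


class ToyLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def prepare_weights(self, quantize: Optional[str] = None):
        # None keeps the dense fp16/fp32 weights; "bitsandbytes" would convert to int8.
        if quantize not in (None, "bitsandbytes"):
            raise ValueError(f"Unexpected quantize `{quantize}`")


class ToyModel(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(hidden_size) for _ in range(num_layers))

    def post_load_weights(self, quantize: Optional[str] = None):
        # The scheme name (not a boolean) is threaded down to every linear layer.
        for layer in self.layers:
            layer.prepare_weights(quantize)
```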
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 369e8d4f..1a4ad551 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -345,16 +345,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
-    def post_load_weights(self, load_in_8bit=False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.embed_in, TensorParallelEmbedding):
             self.embed_in.add_null_idx()
         for layer in self.layers:
             layer: FlashNeoXLayer
             layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.prepare_weights(load_in_8bit)
-            layer.attention.dense.prepare_weights(load_in_8bit)
-            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
-            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
+            layer.attention.query_key_value.prepare_weights(quantize)
+            layer.attention.dense.prepare_weights(quantize)
+            layer.mlp.dense_h_to_4h.prepare_weights(quantize)
+            layer.mlp.dense_4h_to_h.prepare_weights(quantize)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -457,8 +457,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )
 
-    def post_load_weights(self, load_in_8bit=False):
-        self.gpt_neox.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.gpt_neox.post_load_weights(quantize)
         self.embed_out.prepare_weights()
 
     @classmethod
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 9451b01a..7a301c1f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -261,16 +261,16 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads
 
-    def post_load_weights(self, load_in_8bit: bool = False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if self.tp_embeddings:
             self.wte.add_null_idx()
             self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.c_attn.prepare_weights(load_in_8bit)
-            layer.attn.c_proj.prepare_weights(load_in_8bit)
-            layer.mlp.c_fc.prepare_weights(load_in_8bit)
-            layer.mlp.c_proj.prepare_weights(load_in_8bit)
+            layer.attn.c_attn.prepare_weights(quantize)
+            layer.attn.c_proj.prepare_weights(quantize)
+            layer.mlp.c_fc.prepare_weights(quantize)
+            layer.mlp.c_proj.prepare_weights(quantize)
 
     def forward(
         self,
@@ -347,8 +347,8 @@ class FlashSantacoderForCausalLM(nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.transformer.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.transformer.post_load_weights(quantize)
         self.lm_head.prepare_weights()
 
     def forward(
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 156fed76..b775bd79 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -77,14 +77,14 @@ class FlashLlama(FlashCausalLM):
     def load_weights(
         model,
         filenames: List[Path],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device if not quantize else "cpu").to(dtype)
+                value = value.to(device if quantize is None else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])
 
@@ -204,7 +204,7 @@ class FlashLlamaSharded(FlashLlama):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -212,7 +212,7 @@
     ):
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     slice_ = f.get_slice(name)
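The non-sharded loaders above read PyTorch checkpoints with `torch.load` and keep every tensor on CPU while a quantization scheme is pending, only then casting to the serving dtype. A self-contained sketch of that pattern; the parameter-matching logic of the real `load_weights` is intentionally left out, and the function name is ours:

```python
from pathlib import Path
from typing import Dict, List, Optional

import torch


def load_checkpoint_tensors(
    filenames: List[Path],
    quantize: Optional[str],
    device: torch.device,
    dtype: torch.dtype,
) -> Dict[str, torch.Tensor]:
    tensors: Dict[str, torch.Tensor] = {}
    for filename in filenames:
        # Always deserialize on CPU, then move to the GPU only when no
        # quantization scheme is requested (conversion happens later on CPU).
        state_dict = torch.load(filename, map_location="cpu")
        for key, value in state_dict.items():
            tensors[key] = value.to(device if quantize is None else "cpu").to(dtype)
    return tensors
```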
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 7ae06036..0924f107 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -97,7 +97,7 @@ class FlashNeoXSharded(FlashNeoX):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 2b37bd0f..031a67eb 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -89,7 +89,7 @@ class FlashSantacoder(FlashCausalLM):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device if not quantize else "cpu").to(dtype)
+                value = value.to(device if quantize is None else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])
 
@@ -229,7 +229,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -238,7 +238,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     ):
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for key in f.keys():
                     slice_ = f.get_slice(key)
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a0111250..d1e5e841 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -255,7 +255,7 @@ class GalacticaSharded(Galactica):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 3e8557b2..f95e5be2 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -94,7 +94,7 @@ class GPTNeoxSharded(CausalLM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 87b64a45..093cf70a 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -110,7 +110,7 @@ class OPTSharded(OPT):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
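All of the sharded loaders in this diff share the same `safe_open` device selection: map the safetensors file straight onto the target GPU when `quantize is None`, otherwise read on CPU so weights can be converted before being moved. A hedged sketch of that pattern, including the per-rank slicing the sharded models rely on (the even-split math below is a simplification, not the per-module logic of the PR):

```python
from typing import Dict, Optional

import torch
from safetensors import safe_open


def load_sharded_tensors(
    filename: str,
    quantize: Optional[str],
    device: torch.device,
    rank: int,
    world_size: int,
) -> Dict[str, torch.Tensor]:
    shards: Dict[str, torch.Tensor] = {}
    with safe_open(
        filename, framework="pt", device=str(device) if quantize is None else "cpu"
    ) as f:
        for name in f.keys():
            slice_ = f.get_slice(name)
            # Simplified row-wise sharding: each rank keeps its contiguous block.
            size = slice_.get_shape()[0]
            block_size = size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            shards[name] = slice_[start:stop]
    return shards
```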
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index c8521dbf..8e3826a4 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -97,7 +97,7 @@ class T5Sharded(Seq2SeqLM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 3383bf4b..3386bc7d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,6 +1,8 @@
 import torch
 
 from torch import nn
+from torch.nn import functional as F
+from typing import Optional
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -22,7 +24,7 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
+    def prepare_weights(self, quantize: Optional[str] = None):
         if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
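The layers.py hunk is truncated right where `prepare_weights` starts handling `quantize == "bitsandbytes"`. For context, a rough sketch of what a bitsandbytes-backed branch of such a method typically looks like; this is an assumption for illustration, not the actual continuation of the file:

```python
from typing import Optional

from torch import nn

HAS_BITS_AND_BYTES = True
try:
    from bitsandbytes.nn import Linear8bitLt
except ImportError:
    HAS_BITS_AND_BYTES = False


class QuantizableLinear(nn.Linear):
    def prepare_weights(self, quantize: Optional[str] = None):
        if quantize == "bitsandbytes":
            if not HAS_BITS_AND_BYTES:
                raise ImportError("bitsandbytes is not installed")
            # Wrap the existing weights in an int8 linear module; forward() would
            # then route through self.bnb_linear instead of the dense weights.
            self.bnb_linear = Linear8bitLt(
                self.in_features,
                self.out_features,
                bias=self.bias is not None,
                has_fp16_weights=False,
                threshold=6.0,
            )
            self.bnb_linear.weight.data = self.weight.data
            if self.bias is not None:
                self.bnb_linear.bias = nn.Parameter(self.bias.data)
        elif quantize is not None:
            raise ValueError(f"Unexpected quantize `{quantize}`")
```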