From 0036084294ba3ed40f6ff8f6cf81d15305787cf0 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 13 Jul 2023 15:41:57 +0000 Subject: [PATCH] support all, test llama --- integration-tests/conftest.py | 6 +- .../test_flash_llama_gptq.json | 102 +++++ .../test_flash_llama_gptq_all_params.json | 97 +++++ .../test_flash_llama_gptq_load.json | 410 ++++++++++++++++++ .../models/test_flash_llama_gptq.py | 58 +++ .../models/causal_lm.py | 1 + .../custom_modeling/flash_llama_modeling.py | 2 +- .../flash_santacoder_modeling.py | 13 +- .../models/flash_causal_lm.py | 6 +- .../models/flash_llama.py | 1 + .../models/flash_neox.py | 1 + .../text_generation_server/models/flash_rw.py | 1 + .../models/flash_santacoder.py | 1 + .../models/galactica.py | 1 + .../text_generation_server/models/gpt_neox.py | 1 + server/text_generation_server/models/model.py | 42 +- server/text_generation_server/models/mpt.py | 1 + server/text_generation_server/models/opt.py | 1 + server/text_generation_server/models/rw.py | 1 + .../models/santacoder.py | 1 + .../models/seq2seq_lm.py | 1 + server/text_generation_server/models/t5.py | 1 + .../text_generation_server/utils/weights.py | 44 +- 23 files changed, 740 insertions(+), 53 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json create mode 100644 integration-tests/models/test_flash_llama_gptq.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 812b1d18..e6b577d4 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -230,15 +230,19 @@ def launcher(event_loop): shard_uds_path, ] + env = os.environ + if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) if quantize is not None: args.append("--quantize") args.append(quantize) + if quantize == "gptq": + env["GPTQ_GROUPSIZE"] = "128" + env["GPTQ_BITS"] = "4" if trust_remote_code: args.append("--trust-remote-code") - env = os.environ env["LOG_LEVEL"] = "info,text_generation_router=debug" if not use_flash_attention: diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json new file mode 100644 index 00000000..6bd6c729 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -0,0 +1,102 @@ +{ + "generated_text": ", and I am going to visit the Louvre", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "seed": null, + "prefill": [ + { + "id": 2, + "text": "", + "logprob": null + }, + { + "id": 20628, + "text": "Today", + "logprob": -11.2265625 + }, + { + "id": 306, + "text": "I", + "logprob": -4.1757812 + }, + { + "id": 626, + "text": "am", + "logprob": -1.9746094 + }, + { + "id": 297, + "text": "in", + "logprob": -5.4648438 + }, + { + "id": 3444, + "text": "France", + "logprob": -9.03125 + } + ], + "tokens": [ + { + "id": 29892, + "text": ",", + "logprob": -0.31298828, + "special": false + }, + { + "id": 322, + "text": " and", + "logprob": -1.4345703, + "special": false + }, + { + "id": 306, + "text": " I", + "logprob": -0.32080078, + "special": false + }, + { + "id": 626, + "text": " am", + "logprob": -1.3798828, + 
"special": false + }, + { + "id": 2675, + "text": " going", + "logprob": -1.2304688, + "special": false + }, + { + "id": 304, + "text": " to", + "logprob": -0.0014791489, + "special": false + }, + { + "id": 6493, + "text": " visit", + "logprob": -1.1503906, + "special": false + }, + { + "id": 278, + "text": " the", + "logprob": -0.41259766, + "special": false + }, + { + "id": 4562, + "text": " Lou", + "logprob": -1.8134766, + "special": false + }, + { + "id": 12675, + "text": "vre", + "logprob": -0.000767231, + "special": false + } + ] + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json new file mode 100644 index 00000000..687e1784 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -0,0 +1,97 @@ +{ + "generated_text": "The capital city of France isParis.\n The Best Way to Visit", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "seed": 0, + "prefill": [ + { + "id": 2, + "text": "", + "logprob": null + }, + { + "id": 4272, + "text": "city", + "logprob": -12.4453125 + }, + { + "id": 310, + "text": "of", + "logprob": -2.4023438 + }, + { + "id": 3444, + "text": "France", + "logprob": -12.515625 + }, + { + "id": 338, + "text": "is", + "logprob": -5.1914062 + } + ], + "tokens": [ + { + "id": 3681, + "text": " Paris", + "logprob": -0.22546387, + "special": false + }, + { + "id": 29889, + "text": ".", + "logprob": 0, + "special": false + }, + { + "id": 13, + "text": "\n", + "logprob": 0, + "special": false + }, + { + "id": 1, + "text": "", + "logprob": 0, + "special": false + }, + { + "id": 450, + "text": " The", + "logprob": 0, + "special": false + }, + { + "id": 6407, + "text": " Best", + "logprob": -0.5522461, + "special": false + }, + { + "id": 5307, + "text": " Way", + "logprob": 0, + "special": false + }, + { + "id": 304, + "text": " to", + "logprob": 0, + "special": false + }, + { + "id": 5741, + "text": " Vis", + "logprob": -2.3496094, + "special": false + }, + { + "id": 277, + "text": "it", + "logprob": 0, + "special": false + } + ] + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json new file mode 100644 index 00000000..4099be28 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -0,0 +1,410 @@ +[ + { + "generated_text": ", and I am going to visit the Louvre", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "seed": null, + "prefill": [ + { + "id": 2, + "text": "", + "logprob": null + }, + { + "id": 20628, + "text": "Today", + "logprob": -10.734375 + }, + { + "id": 306, + "text": "I", + "logprob": -4.2265625 + }, + { + "id": 626, + "text": "am", + "logprob": -1.9794922 + }, + { + "id": 297, + "text": "in", + "logprob": -5.4960938 + }, + { + "id": 3444, + "text": "France", + "logprob": -9.1171875 + } + ], + "tokens": [ + { + "id": 29892, + "text": ",", + "logprob": -0.30737305, + "special": false + }, + { + "id": 322, + "text": " and", + "logprob": -1.3701172, + "special": false + }, + { + "id": 306, + "text": " I", + "logprob": -0.31567383, + "special": false + }, + { + "id": 626, + "text": " am", + "logprob": -1.3886719, + "special": false + }, + { + "id": 2675, + "text": " going", + 
"logprob": -1.2070312, + "special": false + }, + { + "id": 304, + "text": " to", + "logprob": -0.0014028549, + "special": false + }, + { + "id": 6493, + "text": " visit", + "logprob": -1.1181641, + "special": false + }, + { + "id": 278, + "text": " the", + "logprob": -0.3942871, + "special": false + }, + { + "id": 4562, + "text": " Lou", + "logprob": -1.8789062, + "special": false + }, + { + "id": 12675, + "text": "vre", + "logprob": -0.00082969666, + "special": false + } + ] + } + }, + { + "generated_text": ", and I am going to visit the Louvre", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "seed": null, + "prefill": [ + { + "id": 2, + "text": "", + "logprob": null + }, + { + "id": 20628, + "text": "Today", + "logprob": -10.734375 + }, + { + "id": 306, + "text": "I", + "logprob": -4.2265625 + }, + { + "id": 626, + "text": "am", + "logprob": -1.9794922 + }, + { + "id": 297, + "text": "in", + "logprob": -5.4960938 + }, + { + "id": 3444, + "text": "France", + "logprob": -9.1171875 + } + ], + "tokens": [ + { + "id": 29892, + "text": ",", + "logprob": -0.30737305, + "special": false + }, + { + "id": 322, + "text": " and", + "logprob": -1.3720703, + "special": false + }, + { + "id": 306, + "text": " I", + "logprob": -0.31469727, + "special": false + }, + { + "id": 626, + "text": " am", + "logprob": -1.3916016, + "special": false + }, + { + "id": 2675, + "text": " going", + "logprob": -1.2050781, + "special": false + }, + { + "id": 304, + "text": " to", + "logprob": -0.0014019012, + "special": false + }, + { + "id": 6493, + "text": " visit", + "logprob": -1.1162109, + "special": false + }, + { + "id": 278, + "text": " the", + "logprob": -0.3959961, + "special": false + }, + { + "id": 4562, + "text": " Lou", + "logprob": -1.8847656, + "special": false + }, + { + "id": 12675, + "text": "vre", + "logprob": -0.0008392334, + "special": false + } + ] + } + }, + { + "generated_text": ", and I am going to visit the Louvre", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "seed": null, + "prefill": [ + { + "id": 2, + "text": "", + "logprob": null + }, + { + "id": 20628, + "text": "Today", + "logprob": -10.734375 + }, + { + "id": 306, + "text": "I", + "logprob": -4.2265625 + }, + { + "id": 626, + "text": "am", + "logprob": -1.9794922 + }, + { + "id": 297, + "text": "in", + "logprob": -5.4960938 + }, + { + "id": 3444, + "text": "France", + "logprob": -9.1171875 + } + ], + "tokens": [ + { + "id": 29892, + "text": ",", + "logprob": -0.30737305, + "special": false + }, + { + "id": 322, + "text": " and", + "logprob": -1.3710938, + "special": false + }, + { + "id": 306, + "text": " I", + "logprob": -0.31225586, + "special": false + }, + { + "id": 626, + "text": " am", + "logprob": -1.3994141, + "special": false + }, + { + "id": 2675, + "text": " going", + "logprob": -1.2060547, + "special": false + }, + { + "id": 304, + "text": " to", + "logprob": -0.0013828278, + "special": false + }, + { + "id": 6493, + "text": " visit", + "logprob": -1.1181641, + "special": false + }, + { + "id": 278, + "text": " the", + "logprob": -0.39135742, + "special": false + }, + { + "id": 4562, + "text": " Lou", + "logprob": -1.8808594, + "special": false + }, + { + "id": 12675, + "text": "vre", + "logprob": -0.00084352493, + "special": false + } + ] + } + }, + { + "generated_text": ", and I am going to visit the Louvre", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "seed": null, + "prefill": [ + { + "id": 2, + "text": "", + "logprob": null + }, + { + "id": 
20628, + "text": "Today", + "logprob": -11.203125 + }, + { + "id": 306, + "text": "I", + "logprob": -4.1757812 + }, + { + "id": 626, + "text": "am", + "logprob": -1.9697266 + }, + { + "id": 297, + "text": "in", + "logprob": -5.4609375 + }, + { + "id": 3444, + "text": "France", + "logprob": -9.046875 + } + ], + "tokens": [ + { + "id": 29892, + "text": ",", + "logprob": -0.3083496, + "special": false + }, + { + "id": 322, + "text": " and", + "logprob": -1.4228516, + "special": false + }, + { + "id": 306, + "text": " I", + "logprob": -0.32055664, + "special": false + }, + { + "id": 626, + "text": " am", + "logprob": -1.3847656, + "special": false + }, + { + "id": 2675, + "text": " going", + "logprob": -1.21875, + "special": false + }, + { + "id": 304, + "text": " to", + "logprob": -0.0014572144, + "special": false + }, + { + "id": 6493, + "text": " visit", + "logprob": -1.1542969, + "special": false + }, + { + "id": 278, + "text": " the", + "logprob": -0.41455078, + "special": false + }, + { + "id": 4562, + "text": " Lou", + "logprob": -1.8193359, + "special": false + }, + { + "id": 12675, + "text": "vre", + "logprob": -0.0007710457, + "special": false + } + ] + } + } +] diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py new file mode 100644 index 00000000..2a00a2b1 --- /dev/null +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -0,0 +1,58 @@ + +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_gptq_handle(launcher): + with launcher("TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_gptq(flash_llama_gptq_handle): + await flash_llama_gptq_handle.health(300) + return flash_llama_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): + response = await flash_llama_gptq.generate( + "Today I am in France", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): + response = await flash_llama_gptq.generate( + "The capital city of France is", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot): + responses = await generate_load(flash_llama_gptq, "Today I am in France", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cbdf4808..3bad2daf 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -500,6 +500,7 @@ class CausalLM(Model): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=model.config, requires_padding=True, dtype=dtype, device=device, diff --git 
a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d9f3c7b8..626404e6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -298,7 +298,6 @@ class FlashLlamaLayer(nn.Module): class FlashLlamaModel(torch.nn.Module): def __init__(self, config, weights): super().__init__() - self.config = config process_group = weights.process_group self.tp_rank = process_group.rank() @@ -368,6 +367,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, config, weights): super().__init__() + self.config = config self.model = FlashLlamaModel(config, weights) self.lm_head = TensorParallelHead.load( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index a0da1e20..43dc3606 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -73,17 +73,7 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - try: - bits = weights.get_tensor("gptq_bits").item() - groupsize = weights.get_tensor("gptq_groupsize").item() - except SafetensorError as e: - try: - import os - - bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) - except Exception: - raise e + bits, groupsize = weights.get_gptq_qparams() qweight = qweight.to(weights.device) qzeros = qzeros.to(weights.device) @@ -471,7 +461,6 @@ class FlashSantacoderForCausalLM(nn.Module): self.lm_head = TensorParallelHead.load( config, prefix="transformer.wte", weights=weights ) - self.config = config def forward( self, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4e5804f5..78db35f0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -6,9 +6,8 @@ import torch.distributed import numpy as np from dataclasses import dataclass -from loguru import logger from opentelemetry import trace -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase, PretrainedConfig from typing import Optional, Tuple, List, Type, Union, Dict from text_generation_server.models import Model @@ -21,6 +20,7 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser + tracer = trace.get_tracer(__name__) BLOCK_SIZE = 16 @@ -684,6 +684,7 @@ class FlashCausalLM(Model): self, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, + config: PretrainedConfig, num_layers: int, num_kv_heads: int, head_size: int, @@ -699,6 +700,7 @@ class FlashCausalLM(Model): super(FlashCausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=config, requires_padding=False, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 417ccabb..3fd17a73 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -68,6 
+68,7 @@ class FlashLlama(FlashCausalLM): super(FlashLlama, self).__init__( model=model, tokenizer=tokenizer, + config=config, num_layers=len(model.model.layers), num_kv_heads=model.model.num_heads, head_size=model.model.head_size, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 61004d8e..ce868904 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -59,6 +59,7 @@ class FlashNeoXSharded(FlashCausalLM): super(FlashNeoXSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, + config=config, num_layers=len(model.gpt_neox.layers), num_kv_heads=model.gpt_neox.num_heads, head_size=model.gpt_neox.head_size, diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 55d555fc..2006993f 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -65,6 +65,7 @@ class FlashRWSharded(FlashCausalLM): super(FlashRWSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, + config=config, num_layers=len(model.transformer.h), num_kv_heads=model.transformer.cache_size, head_size=model.transformer.head_size, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 415ec2df..ecbd6b31 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -66,6 +66,7 @@ class FlashSantacoderSharded(FlashCausalLM): super(FlashSantacoderSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, + config=config, num_layers=len(model.transformer.h), num_kv_heads=1, head_size=model.transformer.head_size, diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 01e58bad..d6885d5f 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -198,6 +198,7 @@ class GalacticaSharded(CausalLM): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 91877fa0..2b55a70b 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -63,6 +63,7 @@ class GPTNeoxSharded(CausalLM): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 95dd5447..c686506b 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -3,7 +3,7 @@ import torch from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase, PretrainedConfig from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.pb.generate_pb2 import InfoResponse @@ -23,6 +23,7 @@ class Model(ABC): self, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, + config: PretrainedConfig, requires_padding: bool, dtype: torch.dtype, device: torch.device, @@ -45,24 
+46,41 @@ class Model(ABC): inspect.signature(model.forward).parameters.get("position_ids", None) is not None ) + self.config = config - if model.config.quantize == "gptq": + if config.quantize == "gptq": # Buffers need to be persistent to avoid any bug. self.buffers = {} - max_dq_buffer_size = 0 - for name, submodule in self.model.named_modules(): + use_exllama_act_order = False + max_dq_buffer_size = 1 + max_inner_outer_dim = 1 + for name, submodule in model.named_modules(): if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear): - max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8) - - intermediate_size = model.config.n_inner - max_seq_len = 2048 # TODO: we should be able to set it - - self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device) - self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8) + + if submodule.linear.act_order: + max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width) + + use_exllama_act_order = True + + if use_exllama_act_order: + # TODO: this should be set to rust side `max_total_tokens`, but TGI + # does not offer an API to expose this variable to python, as this variable + # is handled by the client but it appears the model is initialized by the server. + # An alternative could be to initialize the buffers during warmup. + max_total_tokens = 2048 + else: + max_total_tokens = 1 + + # This temp_state buffer is required to reorder X in the act-order case. + self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device) + + # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
+ self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"]) - # TODO: ability to set them matmul_recons_thd = 8 matmul_fused_remap = False matmul_no_half2 = False diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index a4fe5105..8d7d68ab 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -86,6 +86,7 @@ class MPTSharded(CausalLM): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=config, requires_padding=False, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index d407b44a..8e64b1ab 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -61,6 +61,7 @@ class OPTSharded(CausalLM): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 92bb135b..796a9590 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -58,6 +58,7 @@ class RW(CausalLM): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=model.config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index a2b38737..f8f72e9c 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -63,6 +63,7 @@ class SantaCoder(CausalLM): super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, + config=model.config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 9e5c21d1..f6e744f5 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -542,6 +542,7 @@ class Seq2SeqLM(Model): super(Seq2SeqLM, self).__init__( model=model, tokenizer=tokenizer, + config=model.config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 133aafd8..890141e2 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM): super(Seq2SeqLM, self).__init__( model=model, tokenizer=tokenizer, + config=config, requires_padding=True, dtype=dtype, device=device, diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index ff18d656..4a9cb983 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,9 +1,8 @@ from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Tuple from safetensors import safe_open, SafetensorError import torch - class Weights: def __init__( self, @@ -127,17 +126,7 @@ class Weights: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - except SafetensorError as e: - try: - import os - 
- bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) - except Exception: - raise e + bits, groupsize = self.get_gptq_qparams() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -149,7 +138,7 @@ class Weights: use_triton_kernel = False if self.process_group.size() > 1: g_idx = self.get_tensor(f"{prefix}.g_idx") - groupsize = self.get_tensor("gptq_groupsize").item() + _, groupsize = self.get_gptq_qparams() if g_idx is not None: if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all(): @@ -180,19 +169,24 @@ class Weights: else: g_idx = None - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - except SafetensorError as e: - try: - import os - - bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) - except Exception: - raise e + bits, groupsize = self.get_gptq_qparams() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight + + def get_gptq_qparams(self) -> Tuple[int, int]: + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except (SafetensorError, RuntimeError) as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e + + return bits, groupsize
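
Note (illustrative, not part of the diff): the patch routes every GPTQ bits/groupsize lookup through the new Weights.get_gptq_qparams() helper, which prefers the gptq_bits / gptq_groupsize tensors stored in the checkpoint and falls back to the GPTQ_BITS / GPTQ_GROUPSIZE environment variables that the integration-test launcher now exports for quantize="gptq". The sketch below is a minimal, standalone illustration of that resolution order; the local SafetensorError class and the resolve_gptq_qparams / missing_tensor names are placeholders for this sketch only and do not appear in the patch.

import os

class SafetensorError(Exception):
    # Local stand-in for safetensors.SafetensorError so the sketch has no
    # dependency on the safetensors package.
    pass

def resolve_gptq_qparams(get_tensor):
    # Mirrors Weights.get_gptq_qparams(): prefer the gptq_bits / gptq_groupsize
    # tensors from the checkpoint, then fall back to the GPTQ_BITS /
    # GPTQ_GROUPSIZE environment variables.
    try:
        bits = get_tensor("gptq_bits").item()
        groupsize = get_tensor("gptq_groupsize").item()
    except (SafetensorError, RuntimeError) as err:
        try:
            bits = int(os.getenv("GPTQ_BITS"))
            groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
        except Exception:
            raise err
    return bits, groupsize

if __name__ == "__main__":
    # A checkpoint without the metadata tensors falls back to the env vars,
    # matching the values integration-tests/conftest.py sets for GPTQ runs.
    os.environ["GPTQ_BITS"] = "4"
    os.environ["GPTQ_GROUPSIZE"] = "128"

    def missing_tensor(name):
        raise RuntimeError(f"tensor {name!r} not found in checkpoint")

    print(resolve_gptq_qparams(missing_tensor))  # -> (4, 128)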