diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 53de36b2..42235315 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -23,6 +23,7 @@ mod env_runtime; enum Quantization { Bitsandbytes, Gptq, + Ct2, } impl std::fmt::Display for Quantization { @@ -35,6 +36,9 @@ impl std::fmt::Display for Quantization { Quantization::Gptq => { write!(f, "gptq") } + Quantization::Ct2 => { + write!(f, "ct2") + } } } } @@ -96,7 +100,7 @@ struct Args { num_shard: Option, /// Whether you want the model to be quantized. This will use `bitsandbytes` for - /// quantization on the fly, or `gptq`. + /// quantization on the fly, `gptq`, or ctranslate2. #[clap(long, env, value_enum)] quantize: Option, diff --git a/server/tests/models/test_ct2.py b/server/tests/models/test_ct2.py new file mode 100644 index 00000000..ab4b0587 --- /dev/null +++ b/server/tests/models/test_ct2.py @@ -0,0 +1,99 @@ +import pytest +import torch +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.models.ct2_causal_lm import CT2CausalLM + + +@pytest.fixture(scope="session") +def default_santacoder(): + return CT2CausalLM("bigcode/gpt_bigcode-santacoder", dtype=torch.float16) + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="def", + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="defworld", + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_fim_pb_batch(default_fim_pb_request): + return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1) + + +def test_ct2santa_generate_token_completion(default_santacoder, default_pb_batch): + batch = CausalLMBatch.from_pb( + default_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text in (" test_get_all_users_with_", ' test_get_all_users(client):') + assert generations[0].request_id == batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) + + +def test_fim_ct2santacoder_generate_token_completion( + default_santacoder, default_fim_pb_batch +): + batch = CausalLMBatch.from_pb( + default_fim_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert ( + generations[0].generated_text.text + == """ineProperty(exports, "__esModule", { value""" + ) + assert generations[0].request_id == batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index e74c0331..76cb9835 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -14,6 +14,7 @@ app = typer.Typer() class Quantization(str, Enum): bitsandbytes = "bitsandbytes" gptq = "gptq" + ct2 = "ct2" class Dtype(str, Enum): @@ -71,7 +72,7 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value - if dtype is not None and quantize is not None: + if dtype is not None and quantize is not None and quantize != Quantization.ct2: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ffc224cc..68092e29 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.ct2_causal_lm import CT2CausalLM # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -74,6 +75,7 @@ def get_model( dtype: Optional[str], trust_remote_code: bool, ) -> Model: + dtype_ct2 = dtype if dtype is None: dtype = torch.float16 elif dtype == "float16": @@ -83,6 +85,15 @@ def get_model( else: raise RuntimeError(f"Unknown dtype {dtype}") + if quantize is not None and "ct2" in quantize: + return CT2CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype_ct2, + trust_remote_code=trust_remote_code, + ) + if "facebook/galactica" in model_id: return GalacticaSharded( model_id, diff --git a/server/text_generation_server/models/ct2_causal_lm.py b/server/text_generation_server/models/ct2_causal_lm.py new file mode 100644 index 00000000..81bffb70 --- /dev/null +++ b/server/text_generation_server/models/ct2_causal_lm.py @@ -0,0 +1,349 @@ +import torch +import inspect +import numpy as np +import os + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig +from typing import Optional, Tuple, List, Type, Dict + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + PrefillTokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.models.causal_lm import CausalLMBatch +tracer = trace.get_tracer(__name__) +from timeit import default_timer as td +try: + import ctranslate2 + from ctranslate2.converters import TransformersConverter +except ImportError: + ctranslate2 = None + + +class CT2CausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if ctranslate2 is None: + raise ValueError( + "for your configuration, pip install ctranslate2>=3.16.0 is required.", + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + # Start CT2 + if torch.cuda.is_available(): + self.ct2_device = "cuda" + else: + self.ct2_device = "cpu" + + if dtype==torch.float16: + ct2_compute_type = "float16" + elif dtype==torch.float16: + ct2_compute_type = "bfloat16" + else: + # default, int8 quantization. + if "cuda" in self.ct2_device: + ct2_compute_type = "int8_float16" + else: + ct2_compute_type = "int8" + # raise ValueError("cpu is currently experimental due to" + # " sampling based / non-greedy next_token" + # " of code only working in float16.") + # Start CT2 - conversion + out_dir = f"./ct2-{model_id.replace('/','_')}-{ct2_compute_type}" + if not os.path.exists(os.path.join(out_dir, "model.bin")): + ex = "" + try: + converter = ctranslate2.converters.TransformersConverter( + model_id, + activation_scales=None, + load_as_float16=True, + revision=None, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + converter.convert( + output_dir=out_dir, + vmap = None, + quantization=ct2_compute_type, + force = True, + ) + except Exception as ex: + pass + if not os.path.exists(os.path.join(out_dir, "model.bin")) or ex: + raise ValueError(f"conversion for {model_id} failed with ctranslate2: Error {ex}") + + # Start CT2 + self.ct2_model = ctranslate2.Generator(out_dir, device=self.ct2_device, compute_type=ct2_compute_type) + + class DummyModel(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.config = AutoConfig.from_pretrained( + model_id, + revision=revision) + model = DummyModel() + self.vocab_size = model.config.vocab_size + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + super(CT2CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=torch.int32, + device=torch.device("cuda"), + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return CausalLMBatch + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + def forward_true_logits( + self, all_input_ids, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + tokens_in = [self.tokenizer.convert_ids_to_tokens(i) for i in all_input_ids] + logits = self.ct2_model.forward_batch( + tokens_in + ) + logits = torch.as_tensor(logits, device="cuda") + logits = logits.to("cuda").to(torch.float16) + return logits, None + + def forward_true_logits2( + self, all_input_ids, input_lengths, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + ids_input = torch.nested.to_padded_tensor(torch.nested.nested_tensor(all_input_ids), -1).flatten(1).to(torch.int32) + lengths = torch.from_numpy(np.array(input_lengths, dtype=np.int32)).to(ids_input.device) + if self.ct2_device == "cpu": + ids_input, lengths = ids_input.numpy(), lengths.numpy() + ids_input = ctranslate2.StorageView.from_array(ids_input) + lengths = ctranslate2.StorageView.from_array(lengths) + logits = self.ct2_model.forward_batch( + ids_input, lengths + ) + logits = torch.as_tensor(logits, device=self.ct2_device) + if self.ct2_device == "cpu": + logits = logits.to(self.ct2_device).to(torch.float32) + else: + logits = logits.to("cuda").to(torch.float16) + return logits, None + + def forward_patched_logits( + self, all_input_ids: List[List[int]], + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + tokens_in = [self.tokenizer.convert_ids_to_tokens(i) for i in all_input_ids] + ids = self.ct2_model.generate_batch( + tokens_in, + min_length=1, + max_length=1, + include_prompt_in_result=False, + sampling_temperature=0, + ) + logits = torch.full((len(tokens_in), 1, self.vocab_size), -10, dtype=torch.float16, device="cuda") + for i, seq in enumerate(ids): + token = seq.sequences_ids[0] + logits[i, 0, token] = 10 + + return logits, None + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: CausalLMBatch + ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + # slice the attention mask to the correct shape + # attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + + # one = -td() + logits, past = self.forward_true_logits2( + batch.all_input_ids, batch.input_lengths + ) + # one += td() + + + # one = -td() + # two = -td() + # logits2, past2 = self.forward_true_logits( + # batch.all_input_ids + # ) + # two += td() + + # diff = two - one + # if 1020 > batch.input_lengths[0] > 30: + # raise ValueError(f"one took {one}, two took {two}, {batch.input_lengths}") + # if sum := torch.isnan(logits).sum(): + # sum2 = torch.isnan(logits2).sum() + # raise ValueError(f"logits {sum}, {sum2}") + # if sum2 := torch.isnan(logits2).sum(): + # raise ValueError(f"logits2 {sum}") + # torch.testing.assert_close(logits, logits2) + # raise ValueError(f"all_input_ids={len(batch.all_input_ids)},{batch.all_input_ids[0].shape}, logits={logits.shape}, tokens_in={len(tokens_in)},{len(tokens_in[0])}") + + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :, 0] + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + next_token_id_squeezed.item() in self.all_special_ids, + generated_text, + ) + + generations.append(generation) + + # Update values + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + return generations, None + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past + + return generations, batch