diff --git a/interaction.ipynb b/interaction.ipynb
index 94c8b85b..ea45c44d 100644
--- a/interaction.ipynb
+++ b/interaction.ipynb
@@ -2095,6 +2095,33 @@
     "print(f\"{sequence}{pipeline(sequences=sequence).sequences[0]}\")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 279,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using pad_token, but it is not set yet.\n",
+      "2023-08-20 13:44:57 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
+      "2023-08-20 13:44:58 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+      "2023-08-20 13:45:23 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+     ]
+    }
+   ],
+   "source": [
+    "pipeline2 = deepsparse.Pipeline.create(\n",
+    "    task=\"text-generation\", \n",
+    "    model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+    "    use_deepsparse_cache=False,\n",
+    "    prompt_processing_sequence_length=4,\n",
+    "    max_generated_tokens=64,\n",
+    "    sequence_length=128\n",
+    ")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
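
Illustrative usage (outside the patch): the added cell compiles a second text-generation pipeline, pipeline2, with a prompt-processing length of 4 and a 128-token window. A minimal sketch of exercising it, following the call pattern of the existing pipeline cell visible in the surrounding context (the prompt string is only an example):

    sequence = "def fib(n):"
    # same access pattern as the earlier cell: generated text is in .sequences[0]
    print(f"{sequence}{pipeline2(sequences=sequence).sequences[0]}")
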
truncation = "longest", False + + # loop through items in the batch + for i, r in enumerate(pb.requests): + # get mapping + requests_idx_mapping[r.id] = i + + # setup inputs + tokenized_inputs = tokenizer( + r.inputs, + return_tensors="np", + padding=padding, + truncation=truncation, + return_token_type_ids=False, + max_length=DEEPSPARSE_SEQUENCE_LENGTH + ) + input_ids_list.append(tokenized_inputs["input_ids"]) + + # setup sequence generation helpers, capping at DEEPSPARSE_SEQUENCE_LENGTH + # cap stopping parameters to DeepSparse sequence length + input_len = tokenized_inputs["input_ids"].shape[1] + assert DEEPSPARSE_SEQUENCE_LENGTH - input_len > 0 + r.stopping_parameters.max_new_tokens = min( + r.stopping_parameters.max_new_tokens, + DEEPSPARSE_SEQUENCE_LENGTH - input_len + ) + stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + input_ids_list=input_ids_list, + past_key_values_list=None + ) + + def __len__(self): + return len(self.requests) + + def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]: + pass + + def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch": + pass + + +class DeepSparseCausalLM: + def __init__( + self, + deployment_path: str + ): + self.tokenizer = AutoTokenizer.from_pretrained(deployment_path) + self.tokenizer.padding_side = "left" + if not self.tokenizer.pad_token: + assert self.tokenizer.eos_token + self.tokenizer.pad_token = self.tokenizer.eos_token + + + + @property + def batch_type(self) -> Type[DeepSparseCausalLMBatch]: + return DeepSparseCausalLMBatch \ No newline at end of file diff --git a/server/text_generation_server/models/deepsparse_model.py b/server/text_generation_server/models/deepsparse_model.py new file mode 100644 index 00000000..03f73bad --- /dev/null +++ b/server/text_generation_server/models/deepsparse_model.py @@ -0,0 +1,212 @@ +import numpy as np +from typing import Optional, List, Dict + +from deepsparse import Context +from deepsparse.engine import LIB +from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine +from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask + +PAST_KEY_VALUES_NAME = "past_key_values" + +class DeepSparsePastKeyValues: + def __init__(self): + prev_num_tokens = 0 + num_frozen_tokens = 1 + self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens) + +class DeepSparseDecoderEngine: + def __init__ ( + self, + onnx_file_path: str, + sequence_length: int = 1024, + input_ids_length: int = 1, + engine_context: Optional[Context] = None, + ): + + # update ONNX graph + onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs( + onnx_file_path=onnx_file_path, + batch_size=1, + sequence_length=sequence_length, + input_ids_length=input_ids_length, + ) + + # compile engine + self.engine = create_engine( + onnx_file_path=onnx_file_path, + engine_type=DEEPSPARSE_ENGINE, + engine_args={"cached_outputs": cached_outputs}, + context=engine_context, + ) + print(self.engine) + + # save utilties + self.past_key_value_dtype = data_type + self.onnx_inputs = self.engine.input_names + self.empty_past_key_values = self.make_empty_past_key_values() + + # forward function + def __call__( + self, + 
diff --git a/server/text_generation_server/models/deepsparse_model.py b/server/text_generation_server/models/deepsparse_model.py
new file mode 100644
index 00000000..03f73bad
--- /dev/null
+++ b/server/text_generation_server/models/deepsparse_model.py
@@ -0,0 +1,214 @@
+import numpy as np
+from typing import Optional, List, Dict, Tuple
+
+from deepsparse import Context
+from deepsparse.engine import LIB
+from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
+from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask
+
+PAST_KEY_VALUES_NAME = "past_key_values"
+
+class DeepSparsePastKeyValues:
+    def __init__(self):
+        prev_num_tokens = 0
+        num_frozen_tokens = 1
+        self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
+
+class DeepSparseDecoderEngine:
+    def __init__(
+        self,
+        onnx_file_path: str,
+        sequence_length: int = 1024,
+        input_ids_length: int = 1,
+        engine_context: Optional[Context] = None,
+    ):
+
+        # update ONNX graph
+        onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
+            onnx_file_path=onnx_file_path,
+            batch_size=1,
+            sequence_length=sequence_length,
+            input_ids_length=input_ids_length,
+        )
+
+        # compile engine
+        self.engine = create_engine(
+            onnx_file_path=onnx_file_path,
+            engine_type=DEEPSPARSE_ENGINE,
+            engine_args={"cached_outputs": cached_outputs},
+            context=engine_context,
+        )
+        print(self.engine)
+
+        # save utilities
+        self.past_key_value_dtype = data_type
+        self.onnx_inputs = self.engine.input_names
+        self.empty_past_key_values = self.make_empty_past_key_values()
+
+    # forward function
+    def __call__(
+        self,
+        engine_inputs: Dict[str, np.ndarray],
+        past_key_values: DeepSparsePastKeyValues,
+        val_inputs: bool = True
+    ):
+        # format input into lists (we pass empty past key values)
+        inputs = [self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
+                  else engine_inputs[name] for name in self.engine.input_names]
+
+        # validate inputs formatted correctly
+        if val_inputs:
+            self.engine._validate_inputs(inputs)
+
+        # run inference, updates past_key_values internally
+        output = self.engine._eng_net.execute_list_out(
+            inputs,
+            past_key_values.internal_past_key_values
+        )
+        logits = output[0]
+        return logits, past_key_values
+
+    # empty past kvs (dummy values to be passed around)
+    def make_empty_past_key_values(self):
+        past_key_values = {}
+        for idx, name in enumerate(self.onnx_inputs):
+            if name.startswith(PAST_KEY_VALUES_NAME):
+                past_key_values[name] = np.zeros(
+                    self.engine.input_shapes[idx],
+                    dtype=self.past_key_value_dtype
+                )
+
+        return past_key_values
+
+class DeepSparseDecoderModel:
+    def __init__(
+        self,
+        onnx_file_path: str,
+        sequence_length: int = 1024,
+        multitoken_length: int = 16,
+        engine_context: Optional[Context] = None,
+        singletoken_engine = None,
+        multitoken_engine = None,
+    ):
+        self.sequence_length = sequence_length
+        self.multitoken_length = multitoken_length
+
+        if singletoken_engine is not None and multitoken_engine is not None:
+            self.singletoken_engine = singletoken_engine
+            self.multitoken_engine = multitoken_engine
+        else:
+            self.singletoken_engine = DeepSparseDecoderEngine(
+                onnx_file_path=onnx_file_path,
+                engine_context=engine_context,
+                sequence_length=sequence_length,
+                input_ids_length=1,
+            )
+
+            self.multitoken_engine = DeepSparseDecoderEngine(
+                onnx_file_path=onnx_file_path,
+                engine_context=engine_context,
+                sequence_length=sequence_length,
+                input_ids_length=self.multitoken_length,
+            )
+
+        assert "input_ids" in self.singletoken_engine.onnx_inputs
+        assert "attention_mask" in self.singletoken_engine.onnx_inputs
+        assert "causal_mask" in self.singletoken_engine.onnx_inputs
+        assert "positions" in self.singletoken_engine.onnx_inputs
+
+    def engine_inputs_for_prefill(
+        self,
+        tokens: List[int],
+    ):
+        assert len(tokens) < self.sequence_length
+
+        # split batch into N token_batches
+        num_batches = len(tokens) // self.multitoken_length
+        token_batches = [
+            tokens[i * self.multitoken_length : (i+1) * self.multitoken_length]
+            for i in range(0, num_batches)
+        ]
+
+        # format inputs for each of the N token_batches
+        for idx, token_batch in enumerate(token_batches):
+            num_processed_tokens = self.multitoken_length * idx
+
+            engine_inputs = {}
+            engine_inputs["input_ids"] = np.array([token_batch])
+
+            # make attention mask from the right
+            engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
+            engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
+
+            # make positions (building from the right)
+            # TODO: handle case when multitoken_length is 1
+            assert self.multitoken_length > 1
+            engine_inputs["positions"] = np.arange(
+                num_processed_tokens, num_processed_tokens + self.multitoken_length
+            ).reshape(1, -1).astype(np.int64)
+
+            # make causal mask (building from the right)
+            engine_inputs["causal_mask"] = create_causal_mask(
+                input_ids=engine_inputs["input_ids"],
+                attention_mask=engine_inputs["attention_mask"]
+            )
+            yield engine_inputs
+
+    def engine_inputs_for_decode(
+        self,
+        tokens: List[int],
+    ):
+        assert len(tokens) < self.sequence_length
+
+        engine_inputs = {}
engine_inputs["input_ids"] = np.array([[tokens[-1]]]) + engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64) + engine_inputs["attention_mask"][:, -len(tokens):] = 1 + + engine_inputs["causal_mask"] = create_causal_mask( + engine_inputs["input_ids"], + engine_inputs["attention_mask"] + ) + engine_inputs["positions"] = np.array([[len(tokens) - 1]], dtype=np.int64) + + return engine_inputs + + def decode( + self, + tokens: List[int], + past_key_values: DeepSparsePastKeyValues + ) -> (np.ndarray, DeepSparsePastKeyValues): + + engine_inputs = self.engine_inputs_for_decode(tokens) + logits, past_key_values = self.singletoken_engine( + engine_inputs, + past_key_values + ) + + return logits, past_key_values + + def prefill( + self, + tokens: List[int], + past_key_values: DeepSparsePastKeyValues + ) -> (np.ndarray, DeepSparsePastKeyValues): + + tokens_processed = 0 + + for engine_inputs in self.engine_inputs_for_prefill(tokens): + _, past_key_values = self.multitoken_engine( + engine_inputs, + past_key_values + ) + tokens_processed += self.multitoken_length + + while tokens_processed < len(tokens): + logits, past_key_values = self.decode( + tokens= tokens[:tokens_processed + 1], + past_key_values=past_key_values + ) + tokens_processed += 1 + + return logits, past_key_values \ No newline at end of file