From 6a739e51426785aa9c35b8b429986ae067ad305c Mon Sep 17 00:00:00 2001
From: rsnm2
Date: Tue, 22 Aug 2023 02:59:18 +0000
Subject: [PATCH] made service for deepsparse

---
 deepsparse-server/deepsparse_causal_lm.py | 240 ---------------------
 deepsparse-server/deepsparse_model.py     | 242 ----------------------
 server/deepsparse/deepsparse_requests.py  |  34 +++
 3 files changed, 34 insertions(+), 482 deletions(-)
 delete mode 100644 deepsparse-server/deepsparse_causal_lm.py
 delete mode 100644 deepsparse-server/deepsparse_model.py
 create mode 100644 server/deepsparse/deepsparse_requests.py

diff --git a/deepsparse-server/deepsparse_causal_lm.py b/deepsparse-server/deepsparse_causal_lm.py
deleted file mode 100644
index 1e88b2fe..00000000
--- a/deepsparse-server/deepsparse_causal_lm.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import numpy as np
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
-
-from dataclasses import dataclass
-from typing import List, Dict, Optional, Type
-
-from text_generation_server.models.deepsparse_model import (
-    DeepSparsePastKeyValues,
-    DeepSparseDecoderModel
-)
-from text_generation_server.pb import generate_pb2
-
-DEEPSPARSE_SEQUENCE_LENGTH = 128
-DEEPSPARSE_MULTITOKEN_LENGTH = 4
-
-@dataclass
-class DeepSparseCausalLMBatch:
-    batch_id: int
-    requests: List[generate_pb2.Request]
-    requests_idx_mapping: Dict[int,int]
-    input_ids_list: List[np.ndarray]
-    past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-    ) -> "DeepSparseCausalLMBatch":
-
-        # parse batch
-        requests_idx_mapping = {}
-        input_ids_list = []
-
-        # setup tokenizer for deepsparse left padding
-        tokenizer.padding_side = "left"
-        if not tokenizer.pad_token:
-            tokenizer.pad_token = tokenizer.eos_token
-        padding, truncation = "longest", False
-
-        # loop through items in the batch
-        for idx, r in enumerate(pb.requests):
-            requests_idx_mapping[r.id] = idx
-
-            # setup inputs_ids, past_key_values
-            tokenized_inputs = tokenizer(
-                r.inputs,
-                return_tensors="np",
-                padding=padding,
-                truncation=truncation,
-                return_token_type_ids=False,
-                max_length=DEEPSPARSE_SEQUENCE_LENGTH
-            )
-            input_ids_list.append(tokenized_inputs["input_ids"])
-
-        return cls(
-            batch_id=pb.id,
-            requests=pb.requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids_list=input_ids_list,
-            past_key_values_list=[None] * len(pb.requests),
-        )
-
-    # length of the batch
-    def __len__(self):
-        return len(self.requests)
-
-    # pass list of request ids, returns batch with only those request ids
-    def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
-        assert(len(request_ids) > 0)
-
-        requests_idx_mapping = {}
-        requests = []
-        input_ids_list = []
-        past_key_values_list = []
-
-        # loop through requests, keep ones that should remain
-        for new_idx, request_id in enumerate(request_ids):
-            assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch"
-
-            requests_idx_mapping[request_id] = new_idx
-
-            old_idx = self.requests_idx_mapping[request_id]
-            requests.append(self.requests[old_idx])
-            input_ids_list.append(self.input_ids_list[old_idx])
-            past_key_values_list.append(self.past_key_values[old_idx])
-
-        # update batch state
-        self.requests = requests
-        self.requests_idx_mapping = requests_idx_mapping
-        self.input_ids_list = input_ids_list
-        self.past_key_values_list = past_key_values_list
-
-        return self
-
-    # combine two batches into one
-    @classmethod
-    def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
-        assert len(batches) > 1, "must have more than 1 batch to concatenate"
-
-        requests_idx_mapping = {}
-        requests = []
-        input_ids_list = []
-        past_key_values_list = []
-
-        start_index = 0
-        for i, batch in enumerate(batches):
-            assert batch.past_key_values_list is None, "only concatenate prefilled batches"
-
-            # concatenate request, input_ids, and past_key_values lists
-            requests.extend(batch.requests)
-            input_ids_list.extend(batch.input_ids_list)
-            past_key_values_list.extend(batch.past_key_values_list)
-
-            # merge the request_id to index mapping
-            if i == 0:
-                requests_idx_mapping = batch.requests_idx_mapping
-            else:
-                for k, v in batch.requests_idx_mapping.items():
-                    requests_idx_mapping[k] = v + start_index
-
-            start_index += len(batch)
-
-        return cls(
-            batch_id= batches[0].id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids_list=input_ids_list,
-            past_key_values_list=past_key_values_list
-        )
-
-class DeepSparseCausalLM:
-    def __init__(
-        self,
-        model_path: str,
-        tokenizer_path: str,
-    ):
-        # setup tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-        self.tokenizer.padding_side = "left"
-        if not self.tokenizer.pad_token:
-            assert self.tokenizer.eos_token
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        # setup model
-        self.model = DeepSparseDecoderModel(
-            onnx_file_path = model_path,
-            sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
-            multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
-        )
-
-    # TODO (@rsnm2): switch to NextTokenChooser
-    def sample_token(
-        self,
-        logits: np.ndarray
-    ):
-        assert(logits.shape[0] == 1)        # assert b=1 for now
-        return np.argmax(logits[0,-1,:])    # grab logits for the last item in the sequence
-
-    # TODO (@rsnm2): switch to StoppingCriteria
-    def should_stop(
-        self,
-        num_tokens_processed: int,
-        generated_token_id: int
-    ):
-        if num_tokens_processed >= self.model.sequence_length:
-            return True
-        if generated_token_id == self.tokenizer.eos_token_id:
-            return True
-        return False
-
-    def generate_token(
-        self,
-        batch: DeepSparseCausalLMBatch,
-    ) -> (Dict[int,str], Optional[DeepSparseCausalLMBatch]):
-
-        generations: Dict[int, str] = {}
-        all_stopped = True
-
-        # if we supported continuous batching, we would do batched inference here
-        # logits, past_key_values = self.model(batch)
-
-        # for each member of the batch:
-        #   a) run inference
-        #   b) sample and check stopping criteria
-        #   c) create generation + update batch
-        for i, (
-            request,
-            input_ids,
-            past_key_values,
-        ) in enumerate(zip(
-            batch.requests,
-            batch.input_ids_list,
-            batch.past_key_values_list
-        )):
-
-            # run inference
-            logits, past_key_values = self.model(input_ids, past_key_values)
-
-            # sample token
-            # simple for now --- should use NextTokenChooser
-            generated_token_id = self.sample_token(logits)
-
-            # check stopping criteria
-            # simple for now --- should use StoppingCriteria
-            stop = self.should_stop(
-                num_tokens_processed=len(input_ids) + 1,
-                generated_token_id = generated_token_id
-            )
-
-            # if not stopped, convert token id to text
-            generated_text = None
-            if not stop:
-                all_stopped = False
-                generated_text = self.tokenizer.decode(
-                    generated_token_id,
-                    skip_special_tokens=True,
-                    clean_up_tokenization_spaces=False
-                )
-            generations[request.id] = generated_text
-
-            # update values in the batch
-            assert len(batch.input_ids_list[i].shape) == 2
-            assert batch.input_ids_list[i].shape[0] == 1
-
-            # bad --- this does not occur in place
-            # print(batch.input_ids_list[i])
-            batch.input_ids_list[i] = np.append(
-                batch.input_ids_list[i],
-                np.array([[generated_token_id]]),
-                axis=1
-            )
-            batch.past_key_values_list[i] = past_key_values
-
-        # if all elements of the batch are done, return generation + null for batch
-        if all_stopped:
-            return generations, None
-
-        # return generation + updated batch
-        return generations, batch
\ No newline at end of file
diff --git a/deepsparse-server/deepsparse_model.py b/deepsparse-server/deepsparse_model.py
deleted file mode 100644
index cdd0fb52..00000000
--- a/deepsparse-server/deepsparse_model.py
+++ /dev/null
@@ -1,242 +0,0 @@
-import os
-
-os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
-
-import numpy as np
-from typing import Optional, List, Dict
-
-from deepsparse import Context
-from deepsparse.engine import LIB
-from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
-from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask
-
-PAST_KEY_VALUES_NAME = "past_key_values"
-
-class DeepSparsePastKeyValues:
-    def __init__(self):
-        prev_num_tokens = 0
-        num_frozen_tokens = 1
-        self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
-
-class DeepSparseDecoderEngine:
-    def __init__ (
-        self,
-        onnx_file_path: str,
-        sequence_length: int = 1024,
-        input_ids_length: int = 1,
-        engine_context: Optional[Context] = None,
-    ):
-
-        # update ONNX graph
-        onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
-            onnx_file_path=onnx_file_path,
-            batch_size=1,
-            sequence_length=sequence_length,
-            input_ids_length=input_ids_length,
-        )
-
-        # compile engine
-        self.engine = create_engine(
-            onnx_file_path=onnx_file_path,
-            engine_type=DEEPSPARSE_ENGINE,
-            engine_args={"cached_outputs": cached_outputs},
-            context=engine_context,
-        )
-        print(self.engine)
-
-        # save utilties
-        self.past_key_value_dtype = data_type
-        self.onnx_inputs = self.engine.input_names
-        self.empty_past_key_values = self.make_empty_past_key_values()
-
-    # forward function
-    def __call__(
-        self,
-        engine_inputs: Dict[str, np.ndarray],
-        past_key_values: DeepSparsePastKeyValues,
-        val_inputs: bool = True
-    ):
-        # format input into lists (we pass empty past key values)
-        inputs = [
-            self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
-            else engine_inputs[name] for name in self.engine.input_names
-        ]
-
-        # validate inputs formatted correctly
-        if val_inputs:
-            self.engine._validate_inputs(inputs)
-
-        # run inference, updates past_key_values internally
-        output = self.engine._eng_net.execute_list_out(
-            inputs,
-            past_key_values.internal_past_key_values
-        )
-        logits = output[0]
-        return logits, past_key_values
-
-    # empty past kvs (dummy values to be passed around)
-    def make_empty_past_key_values(self):
-        past_key_values = {}
-        for idx, name in enumerate(self.onnx_inputs):
-            if name.startswith(PAST_KEY_VALUES_NAME):
-                past_key_values[name] = np.zeros(
-                    self.engine.input_shapes[idx],
-                    dtype=self.past_key_value_dtype
-                )
-
-        return past_key_values
-
-class DeepSparseDecoderModel:
-    def __init__(
-        self,
-        onnx_file_path: str,
-        sequence_length: int = 1024,
-        multitoken_length: int = 16,
-        engine_context: Optional[Context] = None,
-    ):
-        self.sequence_length = sequence_length
-        self.multitoken_length = multitoken_length
-
-        # compile decode engine
-        self.singletoken_engine = DeepSparseDecoderEngine(
-            onnx_file_path=onnx_file_path,
-            engine_context=engine_context,
-            sequence_length=sequence_length,
-            input_ids_length=1,
-        )
-
-        # compile prefill engine
-        self.multitoken_engine = DeepSparseDecoderEngine(
-            onnx_file_path=onnx_file_path,
-            engine_context=engine_context,
-            sequence_length=sequence_length,
-            input_ids_length=self.multitoken_length,
-        )
-
-        assert "input_ids" in self.singletoken_engine.onnx_inputs
-        assert "attention_mask" in self.singletoken_engine.onnx_inputs
-        assert "causal_mask" in self.singletoken_engine.onnx_inputs
-        assert "positions" in self.singletoken_engine.onnx_inputs
-
-    def engine_inputs_for_prefill(
-        self,
-        input_ids: np.ndarray,
-    ):
-        # split batch into N token_batches
-        num_batches = input_ids.shape[1] // self.multitoken_length
-        token_batches = [
-            input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
-            for i in range(0, num_batches)
-        ]
-
-        # format inputs for each of the N token_batches
-        for idx, token_batch in enumerate(token_batches):
-            num_processed_tokens = self.multitoken_length * idx
-
-            engine_inputs = {}
-            engine_inputs["input_ids"] = token_batch
-
-            # make attention mask from the right
-            engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
-            engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
-
-            # make positions (building from the right)
-            # TODO: handle case when multitoken engine is 1
-            assert self.multitoken_length > 1
-            engine_inputs["positions"] = np.arange(
-                num_processed_tokens, num_processed_tokens + self.multitoken_length
-            ).reshape(1, -1).astype(np.int64)
-
-            # make causal mask (building from the right)
-            engine_inputs["causal_mask"] = create_causal_mask(
-                input_ids=engine_inputs["input_ids"],
-                attention_mask=engine_inputs["attention_mask"]
-            )
-            yield engine_inputs
-
-    def engine_inputs_for_decode(
-        self,
-        input_ids: np.ndarray,
-    ):
-        engine_inputs = {}
-        engine_inputs["input_ids"] = input_ids[:,-1:]
-        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
-        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
-
-        engine_inputs["causal_mask"] = create_causal_mask(
-            engine_inputs["input_ids"],
-            engine_inputs["attention_mask"]
-        )
-        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
-
-        return engine_inputs
-
-    def decode(
-        self,
-        input_ids: np.ndarray,
-        past_key_values: DeepSparsePastKeyValues
-    ) -> (np.ndarray, DeepSparsePastKeyValues):
-
-        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-        assert len(input_ids.shape) == 2
-        assert input_ids.shape[0] == 1
-        assert input_ids.shape[1] < self.sequence_length
-
-        engine_inputs = self.engine_inputs_for_decode(input_ids)
-        logits, past_key_values = self.singletoken_engine(
-            engine_inputs,
-            past_key_values
-        )
-
-        return logits, past_key_values
-
-    def prefill(
-        self,
-        input_ids: np.ndarray,
-    ) -> (np.ndarray, DeepSparsePastKeyValues):
-
-        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-        assert len(input_ids.shape) == 2
-        assert input_ids.shape[0] == 1
-        assert input_ids.shape[1] < self.sequence_length
-
-        tokens_processed = 0
-
-        # setup empty past key values
-        past_key_values = DeepSparsePastKeyValues()
-
-        # loop through chunks, run inference w/ multitoken engine
-        for engine_inputs in self.engine_inputs_for_prefill(input_ids):
-            logits, past_key_values = self.multitoken_engine(
-                engine_inputs,
-                past_key_values
-            )
-            tokens_processed += self.multitoken_length
-
-        # if anything left over, run inference w/ singletoken engine
-        while tokens_processed < input_ids.shape[1]:
-            logits, past_key_values = self.decode(
-                input_ids=input_ids[:,:tokens_processed+1],
-                past_key_values=past_key_values
-            )
-            tokens_processed += 1
-            # print(logits[:,-1:,:])
-
-        return logits, past_key_values
-
-    def forward(
-        self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
-    ):
-        if past_key_values is None:
-            return self.prefill(input_ids)
-        else:
-            return self.decode(input_ids, past_key_values)
-
-    def __call__(
-        self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
-    ):
-        return self.forward(input_ids, past_key_values)
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_requests.py b/server/deepsparse/deepsparse_requests.py
new file mode 100644
index 00000000..106b87b6
--- /dev/null
+++ b/server/deepsparse/deepsparse_requests.py
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+@dataclass
+class Request:
+    id: int
+    prompt: str
+
+@dataclass
+class Batch:
+    id: int
+    requests: List[Request]
+
+@dataclass
+class CachedBatch:
+    batch_id: int
+    request_ids: List[int]
+
+@dataclass
+class Generation:
+    request_id: int
+    generated_text: Optional[str]
+
+@dataclass
+class PrefillRequest:
+    batch: Batch
+
+@dataclass
+class DecodeRequest:
+    batches: List[CachedBatch]
+
+@dataclass
+class FilterBatchRequest:
+    batch_id: int
\ No newline at end of file