From a1c9cc422a8d8d1b79af9b2ca0320ec4491efa15 Mon Sep 17 00:00:00 2001
From: Bill Nell
Date: Fri, 6 Oct 2023 22:00:30 -0400
Subject: [PATCH] WIP: batch single-token decodes in the generation service

---
 deepsparse/router.py            |   2 +-
 deepsparse/service/causal_lm.py |  33 ++++---
 deepsparse/service/model.py     | 156 +++++++++++++++++++++++---------
 deepsparse/service/service.py   |   2 +-
 deepsparse/utils.py             |   6 +-
 5 files changed, 141 insertions(+), 58 deletions(-)

diff --git a/deepsparse/router.py b/deepsparse/router.py
index aaa973fa..4784ba3a 100644
--- a/deepsparse/router.py
+++ b/deepsparse/router.py
@@ -186,4 +186,4 @@ class DeepSparseQueue:
         self.next_batch_id += 1

         # return batch, generate_requests
-        return (batch, generate_requests)
\ No newline at end of file
+        return (batch, generate_requests)
diff --git a/deepsparse/service/causal_lm.py b/deepsparse/service/causal_lm.py
index 23d0021f..a1430c0b 100644
--- a/deepsparse/service/causal_lm.py
+++ b/deepsparse/service/causal_lm.py
@@ -57,6 +57,8 @@ class DeepSparseCausalLMBatch:
             )
         )

+        print(r.generation_parameters)
+
         # get next token chooser based on input
         next_token_chooser_list.append(
             NextTokenChooser(
@@ -142,6 +144,8 @@ class DeepSparseCausalLMBatch:
         # concatenate request, input_ids, and past_key_values lists
         requests.extend(batch.requests)
         input_ids_list.extend(batch.input_ids_list)
+        print(f"pkv {past_key_values_list}")
+        print(f"bpkv {batch.past_key_values_list}")
         past_key_values_list.extend(batch.past_key_values_list)
         stopping_criteria_list.extend(batch.stopping_criteria_list)
         next_token_chooser_list.extend(batch.next_token_chooser_list)
@@ -202,24 +206,30 @@ class DeepSparseCausalLM:
         iterator = zip(
             batch.requests,
             batch.input_ids_list,
-            batch.past_key_values_list,
             batch.stopping_criteria_list,
             batch.next_token_chooser_list,
         )
+
+        #assert len(input_ids.shape) == 2
+        #assert input_ids.shape[0] == 1
+
+        print(batch.past_key_values_list)
+        print(len(batch.past_key_values_list))
+        print(batch.input_ids_list)
+        print(len(batch.input_ids_list))
+
+        # a) run inference once over the whole batch
+        logits, batch.past_key_values_list = self.model(batch.input_ids_list, batch.past_key_values_list)
+
+        print(logits)
+        print(logits.shape)
+
         for i, (
             request,
             input_ids,
-            past_key_values,
             stopping_criteria,
             next_token_chooser
         ) in enumerate(iterator):
-            # assert input_ids is b=1
-            assert len(input_ids.shape) == 2
-            assert input_ids.shape[0] == 1
-
-            # a) run inference
-            logits, past_key_values = self.model(input_ids, past_key_values)
-
             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
             generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
@@ -241,17 +251,16 @@ class DeepSparseCausalLM:
             # d) update batch
             # TODO: this does not occur in place
             assert len(batch.input_ids_list[i].shape) == 2
-            assert batch.input_ids_list[i].shape[0] == 1
+            #assert batch.input_ids_list[i].shape[0] == 1
            batch.input_ids_list[i] = np.append(
                 batch.input_ids_list[i],
                 np.array([[generated_token_id]]),
                 axis=1
             )
-            batch.past_key_values_list[i] = past_key_values

         # if all elements of the batch are done, return null for batch
         if all_stopped:
             return generations, None

         # return generation + updated batch
-        return generations, batch
\ No newline at end of file
+        return generations, batch
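Note on the reworked sampling loop above: `generate_token` now makes a single model call for the whole batch, so the `logits` it loops over stack one row per request in batch order. A minimal sketch of per-request sampling under that assumption (illustrative shapes and values; a plain argmax stands in for `NextTokenChooser`):

    import numpy as np

    # a sketch, assuming logits stacks one row per request in batch order,
    # shape (num_requests, seq_len, vocab_size); values are illustrative
    num_requests, vocab_size = 3, 32000
    logits = np.random.rand(num_requests, 1, vocab_size).astype(np.float32)

    next_tokens = []
    for i in range(num_requests):
        # request i must read its own row; a bare logits[:, -1, :] slice
        # hands every request the logits of the entire batch
        next_tokens.append(int(np.argmax(logits[i, -1, :])))

    print(next_tokens)  # one new token id per request

The unchanged `scores=logits[:,-1,:]` context line still hands every chooser all rows, which `Greedy` then collapses to row 0; indexing by `i` as above (or vectorizing the chooser) looks like the fix the TODO in `deepsparse/utils.py` is pointing at.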
diff --git a/deepsparse/service/model.py b/deepsparse/service/model.py
index e0d8c9d0..77fd6f18 100644
--- a/deepsparse/service/model.py
+++ b/deepsparse/service/model.py
@@ -7,10 +7,16 @@ from typing import Optional, List, Dict
 from deepsparse import Context
 from deepsparse.engine import LIB
 from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
-from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs_for_kv_cache_models, create_causal_mask
+from deepsparse.utils.onnx import overwrite_onnx_model_inputs_for_kv_cache_models
+from deepsparse.transformers.utils.helpers import create_causal_mask

 PAST_KEY_VALUES_NAME = "past_key_values"

+def chunkify(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i:i + n]
+
 class DeepSparsePastKeyValues:
     def __init__(self):
         prev_num_tokens = 0
@@ -23,13 +29,14 @@ class DeepSparseDecoderEngine:
         onnx_file_path: str,
         sequence_length: int = 1024,
         input_ids_length: int = 1,
+        batch_size: int = 1,
         engine_context: Optional[Context] = None,
     ):
-
-        # setup ONNX graph
+
+        # setup ONNX graph(s)
         onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs_for_kv_cache_models(
             onnx_file_path=onnx_file_path,
-            batch_size=1,
+            batch_size=batch_size,
             sequence_length=sequence_length,
             input_ids_length=input_ids_length,
         )
@@ -38,7 +45,7 @@ class DeepSparseDecoderEngine:
         self.engine = create_engine(
             onnx_file_path=onnx_file_path,
             engine_type=DEEPSPARSE_ENGINE,
-            engine_args={"cached_outputs": cached_outputs},
+            engine_args={"cached_outputs": cached_outputs, "batch_size": batch_size},
             context=engine_context,
         )
         print(self.engine)
@@ -91,19 +98,33 @@ class DeepSparseDecoderModel:
         onnx_file_path: str,
         sequence_length: int = 1024,
         multitoken_length: int = 16,
+        batch_size: int = 1,
         engine_context: Optional[Context] = None,
     ):
         self.sequence_length = sequence_length
         self.multitoken_length = multitoken_length
+        self.batch_size = batch_size

-        # compile decode engine
+        # compile decode engines (b=1 plus an optional batched engine)
         self.singletoken_engine = DeepSparseDecoderEngine(
             onnx_file_path=onnx_file_path,
             engine_context=engine_context,
             sequence_length=sequence_length,
             input_ids_length=1,
+            batch_size=1
         )
+
+        if batch_size > 1:
+            self.batched_singletoken_engine = DeepSparseDecoderEngine(
+                onnx_file_path=onnx_file_path,
+                engine_context=engine_context,
+                sequence_length=sequence_length,
+                input_ids_length=1,
+                batch_size=batch_size
+            )
+        else:
+            self.batched_singletoken_engine = None
+
         # compile prefill engine
         self.multitoken_engine = DeepSparseDecoderEngine(
             onnx_file_path=onnx_file_path,
@@ -155,40 +176,88 @@ class DeepSparseDecoderModel:

     def engine_inputs_for_decode(
         self,
-        input_ids: np.ndarray,
+        input_ids: List[np.ndarray],
     ):
+        assert type(input_ids) is list
+        assert len(input_ids) > 0
+        assert type(input_ids[0]) is np.ndarray
+        assert len(input_ids[0].shape) == 2
+        assert input_ids[0].shape[1] < self.sequence_length
+        # all sequences in a chunk must have the same shape
+        assert all(x.shape == input_ids[0].shape for x in input_ids)
+
+        batch_size = len(input_ids)
+
         engine_inputs = {}
-        engine_inputs["input_ids"] = input_ids[:,-1:]
-        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
-        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
+
+        print(batch_size)
+        print(input_ids)
+        print(len(input_ids))
+
+        last_input_ids = [x[:,-1:] for x in input_ids]
+
+        print(f"last_input_ids {last_input_ids}")
+
+        engine_inputs["input_ids"] = np.concatenate(last_input_ids, axis=0)
+
+        engine_inputs["attention_mask"] = np.zeros((batch_size, self.sequence_length), dtype=np.int64)
+        engine_inputs["attention_mask"][:, -input_ids[0].shape[1]:] = 1

         engine_inputs["causal_mask"] = create_causal_mask(
             engine_inputs["input_ids"],
             engine_inputs["attention_mask"]
         )
-        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
-
+        # positions needs one row per sequence in the chunk
+        engine_inputs["positions"] = np.full((batch_size, 1), input_ids[0].shape[1] - 1, dtype=np.int64)
+
+        print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
+        print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
+        print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
+        print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")
+
         return engine_inputs

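The `chunkify` helper added above is the heart of the batching scheme in this file: `decode` splits however many requests are in flight into engine-sized groups, sending full groups to the batched engine and the ragged tail through the b=1 engine one request at a time. A standalone sketch of that partitioning (hypothetical request ids, batch_size=2):

    def chunkify(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    # hypothetical: 5 in-flight requests, engines compiled with batch_size=2
    request_ids = ["r0", "r1", "r2", "r3", "r4"]
    for chunk in chunkify(request_ids, 2):
        if len(chunk) == 2:
            print(f"{chunk} -> batched_singletoken_engine")
        else:
            print(f"{chunk} -> singletoken_engine (leftover)")
    # ['r0', 'r1'] and ['r2', 'r3'] hit the batched engine; ['r4'] falls back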
     def decode(
         self,
-        input_ids: np.ndarray,
-        past_key_values: DeepSparsePastKeyValues
-    ) -> (np.ndarray, DeepSparsePastKeyValues):
-
-        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-        assert len(input_ids.shape) == 2
-        assert input_ids.shape[0] == 1
-        assert input_ids.shape[1] < self.sequence_length
-
-        engine_inputs = self.engine_inputs_for_decode(input_ids)
-        logits, past_key_values = self.singletoken_engine(
-            engine_inputs,
-            past_key_values
+        batched_input_ids: List[np.ndarray],
+        batched_past_key_values: List[DeepSparsePastKeyValues]
+    ) -> (np.ndarray, List[DeepSparsePastKeyValues]):
+
+        assert len(batched_input_ids) == len(batched_past_key_values)
+
+        batched_logits = []
+        batched_new_key_values = []
+
+        chunks = zip(
+            chunkify(batched_input_ids, self.batch_size),
+            chunkify(batched_past_key_values, self.batch_size)
         )
-
-        return logits, past_key_values
-
+
+        for input_ids, past_key_values in chunks:
+            # each sequence is of shape [1, seq_len] w/ seq_len < self.sequence_length
+            print(input_ids)
+            assert len(input_ids[0].shape) == 2
+            assert input_ids[0].shape[1] < self.sequence_length
+
+            if len(input_ids) == self.batch_size and self.batch_size != 1:
+                engine_inputs = self.engine_inputs_for_decode(input_ids)
+                logits, new_key_values = self.batched_singletoken_engine(
+                    engine_inputs,
+                    past_key_values
+                )
+                batched_logits.append(logits)
+                batched_new_key_values.append(new_key_values)
+            else:
+                # ragged tail (or batch_size == 1): run requests one at a time
+                for i in range(len(input_ids)):
+                    engine_inputs = self.engine_inputs_for_decode([input_ids[i]])
+                    logits, new_key_values = self.singletoken_engine(
+                        engine_inputs,
+                        past_key_values[i])
+                    batched_logits.append(logits)
+                    batched_new_key_values.append(new_key_values)
+
+        return np.concatenate(batched_logits, axis=0), batched_new_key_values
+
     def prefill(
         self,
         input_ids: np.ndarray,
@@ -202,40 +271,45 @@ class DeepSparseDecoderModel:
         tokens_processed = 0

         # setup empty past key values
-        past_key_values = DeepSparsePastKeyValues()
+        past_key_values = [DeepSparsePastKeyValues()]

         # loop through chunks, run inference w/ multitoken engine
         for engine_inputs in self.engine_inputs_for_prefill(input_ids):
-            logits, past_key_values = self.multitoken_engine(
+            logits, past_key_values[0] = self.multitoken_engine(
                 engine_inputs,
-                past_key_values
+                past_key_values[0]
             )
             tokens_processed += self.multitoken_length

         # if anything left over, run inference w/ singletoken engine
         while tokens_processed < input_ids.shape[1]:
+            print(f"got here {input_ids[:,:tokens_processed+1]}")
+            assert len(input_ids.shape) == 2
             logits, past_key_values = self.decode(
-                input_ids=input_ids[:,:tokens_processed+1],
-                past_key_values=past_key_values
+                [input_ids[:,:tokens_processed+1]],
+                past_key_values
             )
             tokens_processed += 1
             # print(logits[:,-1:,:])
-
+
         return logits, past_key_values

     def forward(
         self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+        input_ids: List[np.ndarray],
+        past_key_values: List[Optional[DeepSparsePastKeyValues]],
     ):
-        if past_key_values is None:
-            return self.prefill(input_ids)
+        assert len(past_key_values) > 0
+        print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
+        if past_key_values[0] is None:
+            # prefill handles a single fresh sequence at a time
+            assert len(input_ids) == 1
+            return self.prefill(input_ids[0])
         else:
             return self.decode(input_ids, past_key_values)

     def __call__(
         self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+        input_ids: List[np.ndarray],
+        past_key_values: List[Optional[DeepSparsePastKeyValues]],
     ):
-        return self.forward(input_ids, past_key_values)
\ No newline at end of file
+        return self.forward(input_ids, past_key_values)
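With the pieces above in place, `forward` dispatches on the first cache entry: a `[None]` list marks a fresh request and takes the prefill path, anything else goes through the chunked decode. A sketch of the intended request lifecycle, under the assumption that a compiled kv-cache ONNX model exists at the placeholder path (token ids are illustrative):

    import numpy as np
    from deepsparse.service.model import DeepSparseDecoderModel

    # hypothetical setup; "model.onnx" is a placeholder for a real kv-cache model
    model = DeepSparseDecoderModel(onnx_file_path="model.onnx", batch_size=2)

    # fresh request: no cache yet, so forward() routes to prefill()
    prompt = np.array([[101, 2023, 2003, 1037, 102]])
    logits, past_key_values = model([prompt], [None])

    # steady state: append the sampled token and step through the decode path
    prompt = np.append(prompt, [[int(np.argmax(logits[0, -1, :]))]], axis=1)
    logits, past_key_values = model([prompt], past_key_values)

One loose end worth flagging while this is WIP: the batched branch of `decode` appends one `new_key_values` per chunk while the fallback branch appends one per request, so the length of the returned cache list currently depends on which path ran.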
diff --git a/deepsparse/service/service.py b/deepsparse/service/service.py
index c49c8f6b..2543cff8 100644
--- a/deepsparse/service/service.py
+++ b/deepsparse/service/service.py
@@ -73,4 +73,4 @@ class DeepSparseService:
         generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)

-        return generations, (next_ds_batch.to_cached_batch() if next_ds_batch else None)
\ No newline at end of file
+        return generations, (next_ds_batch.to_cached_batch() if next_ds_batch else None)
diff --git a/deepsparse/utils.py b/deepsparse/utils.py
index fda83570..d41df358 100644
--- a/deepsparse/utils.py
+++ b/deepsparse/utils.py
@@ -11,10 +11,10 @@ class Greedy:
     def __call__(self, logits: np.ndarray):
         # assert b=1 for now
         # shape == (batch, vocabulary_size)
-        assert(logits.shape[0] == 1)
+        #assert(logits.shape[0] == 1)
         assert(len(logits.shape) == 2)

-        return np.argmax(logits[0,:])
+        return np.argmax(logits[0,:])  # TODO: only returns row 0's argmax; fix for batch > 1

 # TODO: sample for b > 1 with vectorized code
 # https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
@@ -149,4 +149,4 @@ class GenerateRequest:
         inputs=gr_inputs.inputs,
         generation_parameters=gr_inputs.generation_parameters,
         response_stream=Queue()
-    )
\ No newline at end of file
+    )
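On the TODO above: both it and the marker on `Greedy.__call__` point at the same gap, namely that sampling ignores every row but the first. A possible vectorized form for b > 1, in the spirit of the linked Stack Overflow approach (a sketch, not the committed implementation; assumes float probabilities):

    import numpy as np

    def greedy_batch(logits: np.ndarray) -> np.ndarray:
        # shape == (batch, vocabulary_size) -> one argmax per row
        assert len(logits.shape) == 2
        return np.argmax(logits, axis=1)

    def sample_batch(probs: np.ndarray, rng=np.random) -> np.ndarray:
        # vectorized np.random.choice across rows: one uniform draw per row,
        # then invert each row's CDF
        assert len(probs.shape) == 2
        cdf = probs.astype(np.float64).cumsum(axis=1)
        cdf /= cdf[:, -1:]  # normalize in case rows don't sum exactly to 1
        draws = rng.random(size=(probs.shape[0], 1))
        return (draws < cdf).argmax(axis=1)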