diff --git a/deepsparse/service/causal_lm.py b/deepsparse/service/causal_lm.py index a1430c0b..b1ef2ff0 100644 --- a/deepsparse/service/causal_lm.py +++ b/deepsparse/service/causal_lm.py @@ -57,8 +57,6 @@ class DeepSparseCausalLMBatch: ) ) - print(r.generation_parameters) - # get next token chooser based on input next_token_chooser_list.append( NextTokenChooser( @@ -123,6 +121,8 @@ class DeepSparseCausalLMBatch: self.stopping_criteria_list = stopping_criteria_list self.next_token_chooser_list = next_token_chooser_list + assert len(self.input_ids_list) == len(self.past_key_values_list) + return self # combine two batches into one @@ -144,8 +144,8 @@ class DeepSparseCausalLMBatch: # concatenate request, input_ids, and past_key_values lists requests.extend(batch.requests) input_ids_list.extend(batch.input_ids_list) - print(f"pkv {past_key_values_list}") - print(f"bpkv {batch.past_key_values_list}") + #print(f"pkv {past_key_values_list}") + #print(f"bpkv {batch.past_key_values_list}") past_key_values_list.extend(batch.past_key_values_list) stopping_criteria_list.extend(batch.stopping_criteria_list) next_token_chooser_list.extend(batch.next_token_chooser_list) @@ -159,6 +159,8 @@ class DeepSparseCausalLMBatch: start_index += len(batch) + assert len(input_ids_list) == len(past_key_values_list) + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -187,6 +189,7 @@ class DeepSparseCausalLM: onnx_file_path = model_path, sequence_length = DEEPSPARSE_SEQUENCE_LENGTH, multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH, + batch_size=4 ) def generate_token( @@ -213,16 +216,22 @@ class DeepSparseCausalLM: #assert len(input_ids.shape) == 2 #assert input_ids.shape[0] == 1 - print(batch.past_key_values_list) - print(len(batch.past_key_values_list)) - print(batch.input_ids_list) - print(len(batch.input_ids_list)) + #print(batch.past_key_values_list) + #print(len(batch.past_key_values_list)) + #print(batch.input_ids_list) + #print(len(batch.input_ids_list)) + + #print(f"before {len(batch.input_ids_list)} {len(batch.past_key_values_list)}") # a) run inference logits, batch.past_key_values_list = self.model(batch.input_ids_list, batch.past_key_values_list) - print(logits) - print(logits.shape) + #print(f"after {len(batch.input_ids_list)} {len(batch.past_key_values_list)} {batch.past_key_values_list}") + + assert len(batch.input_ids_list) == len(batch.past_key_values_list) + + #print(logits) + #print(logits.shape) for i, ( request, diff --git a/deepsparse/service/model.py b/deepsparse/service/model.py index 77fd6f18..78a64a09 100644 --- a/deepsparse/service/model.py +++ b/deepsparse/service/model.py @@ -41,11 +41,20 @@ class DeepSparseDecoderEngine: input_ids_length=input_ids_length, ) + self.engine_type = DEEPSPARSE_ENGINE + #self.engine_type = "onnxruntime" + + if self.engine_type == DEEPSPARSE_ENGINE: + engine_args = {"cached_outputs": cached_outputs, "batch_size": batch_size} + else: + engine_args = {"batch_size": batch_size} + # compile engine + print(f"compiling for batch size: {batch_size}") self.engine = create_engine( onnx_file_path=onnx_file_path, - engine_type=DEEPSPARSE_ENGINE, - engine_args={"cached_outputs": cached_outputs, "batch_size": batch_size}, + engine_type=self.engine_type, + engine_args=engine_args, context=engine_context, ) print(self.engine) @@ -59,7 +68,7 @@ class DeepSparseDecoderEngine: def __call__( self, engine_inputs: Dict[str, np.ndarray], - past_key_values: DeepSparsePastKeyValues, + past_key_values: DeepSparsePastKeyValues, # XXXX this can be a list val_inputs: bool = True ): # format input into lists (we pass empty past key values) @@ -72,11 +81,21 @@ class DeepSparseDecoderEngine: if val_inputs: self.engine._validate_inputs(inputs) + #print(f"here {past_key_values}") + + if type(past_key_values) is list: + caches = [pkv.internal_past_key_values for pkv in past_key_values] + else: + caches = past_key_values.internal_past_key_values + # run inference, updates past_key_values internally - output = self.engine._eng_net.execute_list_out( - inputs, - past_key_values.internal_past_key_values - ) + if self.engine_type == DEEPSPARSE_ENGINE: + output = self.engine._eng_net.execute_list_out( + inputs, + caches + ) + else: + output = self.engine.run(inputs) logits = output[0] return logits, past_key_values @@ -189,13 +208,13 @@ class DeepSparseDecoderModel: engine_inputs = {} - print(batch_size) - print(input_ids) - print(len(input_ids)) + #print(batch_size) + #print(input_ids) + #print(len(input_ids)) last_input_ids = [x[:,-1:] for x in input_ids] - print(f"last_input_ids {last_input_ids}") + #print(f"last_input_ids {last_input_ids}") engine_inputs["input_ids"] = np.concatenate(last_input_ids, axis=0) @@ -206,14 +225,14 @@ class DeepSparseDecoderModel: engine_inputs["input_ids"], engine_inputs["attention_mask"] ) - #engine_inputs["positions"] = np.ndarray([batch_size, input_ids[0].shape[1] - 1], dtype=np.int64) - engine_inputs["positions"] = np.array([[input_ids[0].shape[1] - 1]], dtype=np.int64) - - print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}") - print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}") - print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}") - print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}") + poses = [pos.shape[1] - 1 for pos in input_ids] + #print(f"poses {poses}") + engine_inputs["positions"] = np.array(poses, dtype=np.int64)[:,None] + #print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}") + #print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}") + #print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}") + #print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}") return engine_inputs @@ -223,6 +242,7 @@ class DeepSparseDecoderModel: batched_past_key_values: List[DeepSparsePastKeyValues] ) -> (np.ndarray, List[DeepSparsePastKeyValues]): + #print(f"{len(batched_input_ids)} {len(batched_past_key_values)}") assert len(batched_input_ids) == len(batched_past_key_values) batched_logits = [] @@ -235,17 +255,19 @@ class DeepSparseDecoderModel: for input_ids, past_key_values in chunks: # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len - print(input_ids) + #print(input_ids) assert len(input_ids[0].shape) == 2 assert input_ids[0].shape[1] < self.sequence_length if len(input_ids) == self.batch_size and self.batch_size != 1: engine_inputs = self.engine_inputs_for_decode(input_ids) + print(f"GOT HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! {past_key_values}") logits, new_key_values = self.batched_singletoken_engine( engine_inputs, past_key_values ) batched_logits.append(logits) + # XXXXX this is bogus batched_new_key_values.append(new_key_values) else: for i in range(len(input_ids)): @@ -256,7 +278,11 @@ class DeepSparseDecoderModel: batched_logits.append(logits) batched_new_key_values.append(new_key_values) - return np.concatenate(batched_logits, axis=0), batched_new_key_values + #print(f"decode {len(batched_input_ids)} {len(batched_new_key_values)}") + + # XXXXX this is bogus + return np.concatenate(batched_logits, axis=0), batched_past_key_values + #return np.concatenate(batched_logits, axis=0), np.concatenate(batched_new_key_values) def prefill( self, @@ -283,7 +309,7 @@ class DeepSparseDecoderModel: # if anything left over, run inference w/ singletoken engine while tokens_processed < input_ids.shape[1]: - print(f"got here {input_ids[:,:tokens_processed+1]}") + #print(f"got here {input_ids[:,:tokens_processed+1]}") assert len(input_ids.shape) == 2 logits, past_key_values = self.decode( [input_ids[:,:tokens_processed+1]], @@ -300,11 +326,13 @@ class DeepSparseDecoderModel: past_key_values: List[Optional[DeepSparsePastKeyValues]], ): assert len(past_key_values) > 0 - print(f"forward pkv {past_key_values} {past_key_values[0] is None}") + #print(f"forward pkv {past_key_values} {past_key_values[0] is None}") if past_key_values[0] is None: assert len(input_ids) == 1 + print("PREFILL!!!!!!!!!!!!!!!!!!!!!") return self.prefill(input_ids[0]) else: + print("DECODE!!!!!!!!!!!!!!!!!!!!!") return self.decode(input_ids, past_key_values) def __call__(