hacks to support native continuous batching

Bill Nell 2023-10-12 11:44:50 -04:00
parent a1c9cc422a
commit 5e02f40a83
2 changed files with 69 additions and 32 deletions

View File

@@ -57,8 +57,6 @@ class DeepSparseCausalLMBatch:
                )
            )

-            print(r.generation_parameters)
-
            # get next token chooser based on input
            next_token_chooser_list.append(
                NextTokenChooser(
@@ -123,6 +121,8 @@ class DeepSparseCausalLMBatch:
        self.stopping_criteria_list = stopping_criteria_list
        self.next_token_chooser_list = next_token_chooser_list

+        assert len(self.input_ids_list) == len(self.past_key_values_list)
+
        return self

    # combine two batches into one
@@ -144,8 +144,8 @@ class DeepSparseCausalLMBatch:
            # concatenate request, input_ids, and past_key_values lists
            requests.extend(batch.requests)
            input_ids_list.extend(batch.input_ids_list)
-            print(f"pkv {past_key_values_list}")
-            print(f"bpkv {batch.past_key_values_list}")
+            #print(f"pkv {past_key_values_list}")
+            #print(f"bpkv {batch.past_key_values_list}")
            past_key_values_list.extend(batch.past_key_values_list)
            stopping_criteria_list.extend(batch.stopping_criteria_list)
            next_token_chooser_list.extend(batch.next_token_chooser_list)
@@ -159,6 +159,8 @@ class DeepSparseCausalLMBatch:
            start_index += len(batch)

+        assert len(input_ids_list) == len(past_key_values_list)
+
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
@@ -187,6 +189,7 @@ class DeepSparseCausalLM:
            onnx_file_path = model_path,
            sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
            multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
+            batch_size=4
        )

    def generate_token(
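Note: the engine batch size is hard-coded to 4 here for experimentation. A minimal sketch of lifting it into configuration instead, assuming a hypothetical DEEPSPARSE_BATCH_SIZE environment variable (not part of this commit):

    import os

    # Hypothetical knob, mirroring the DEEPSPARSE_SEQUENCE_LENGTH and
    # DEEPSPARSE_MULTITOKEN_LENGTH constants used above; defaults to the
    # value hard-coded in this commit.
    DEEPSPARSE_BATCH_SIZE = int(os.environ.get("DEEPSPARSE_BATCH_SIZE", "4"))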
@@ -213,16 +216,22 @@ class DeepSparseCausalLM:
        #assert len(input_ids.shape) == 2
        #assert input_ids.shape[0] == 1

-        print(batch.past_key_values_list)
-        print(len(batch.past_key_values_list))
-        print(batch.input_ids_list)
-        print(len(batch.input_ids_list))
+        #print(batch.past_key_values_list)
+        #print(len(batch.past_key_values_list))
+        #print(batch.input_ids_list)
+        #print(len(batch.input_ids_list))
+        #print(f"before {len(batch.input_ids_list)} {len(batch.past_key_values_list)}")

        # a) run inference
        logits, batch.past_key_values_list = self.model(batch.input_ids_list, batch.past_key_values_list)
-        print(logits)
-        print(logits.shape)
+        #print(f"after {len(batch.input_ids_list)} {len(batch.past_key_values_list)} {batch.past_key_values_list}")
+
+        assert len(batch.input_ids_list) == len(batch.past_key_values_list)
+
+        #print(logits)
+        #print(logits.shape)

        for i, (
            request,
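Note: generate_token now asserts the core batching invariant: the model must hand back exactly one KV cache per input sequence, with logits stacked along the batch axis (as decode's np.concatenate further down suggests). A self-contained illustration of the expected shapes, with placeholder sizes:

    import numpy as np

    # n sequences in flight; the vocab size is a placeholder.
    n, vocab = 3, 32000
    logits = np.zeros((n, 1, vocab))        # one next-token distribution per sequence
    past_key_values_list = [object()] * n   # stand-in: one KV cache per sequence

    # the invariant the new assert enforces
    assert logits.shape[0] == n == len(past_key_values_list)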

View File

@@ -41,11 +41,20 @@ class DeepSparseDecoderEngine:
            input_ids_length=input_ids_length,
        )

+        self.engine_type = DEEPSPARSE_ENGINE
+        #self.engine_type = "onnxruntime"
+
+        if self.engine_type == DEEPSPARSE_ENGINE:
+            engine_args = {"cached_outputs": cached_outputs, "batch_size": batch_size}
+        else:
+            engine_args = {"batch_size": batch_size}
+
        # compile engine
+        print(f"compiling for batch size: {batch_size}")
        self.engine = create_engine(
            onnx_file_path=onnx_file_path,
-            engine_type=DEEPSPARSE_ENGINE,
-            engine_args={"cached_outputs": cached_outputs, "batch_size": batch_size},
+            engine_type=self.engine_type,
+            engine_args=engine_args,
            context=engine_context,
        )
        print(self.engine)
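Note: the engine_args split exists because cached_outputs is only understood by the deepsparse engine; the onnxruntime fallback gets just the batch size. A sketch of compiling the same model for either backend, reusing only the names visible in this diff:

    # Sketch: pick a backend, then build the matching engine_args.
    for engine_type in (DEEPSPARSE_ENGINE, "onnxruntime"):
        engine_args = (
            {"cached_outputs": cached_outputs, "batch_size": batch_size}
            if engine_type == DEEPSPARSE_ENGINE
            else {"batch_size": batch_size}
        )
        engine = create_engine(
            onnx_file_path=onnx_file_path,
            engine_type=engine_type,
            engine_args=engine_args,
            context=engine_context,
        )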
@@ -59,7 +68,7 @@ class DeepSparseDecoderEngine:
    def __call__(
        self,
        engine_inputs: Dict[str, np.ndarray],
-        past_key_values: DeepSparsePastKeyValues,
+        past_key_values: DeepSparsePastKeyValues,   # XXXX this can be a list
        val_inputs: bool = True
    ):
        # format input into lists (we pass empty past key values)
@@ -72,11 +81,21 @@ class DeepSparseDecoderEngine:
        if val_inputs:
            self.engine._validate_inputs(inputs)

+        #print(f"here {past_key_values}")
+        if type(past_key_values) is list:
+            caches = [pkv.internal_past_key_values for pkv in past_key_values]
+        else:
+            caches = past_key_values.internal_past_key_values
+
        # run inference, updates past_key_values internally
-        output = self.engine._eng_net.execute_list_out(
-            inputs,
-            past_key_values.internal_past_key_values
-        )
+        if self.engine_type == DEEPSPARSE_ENGINE:
+            output = self.engine._eng_net.execute_list_out(
+                inputs,
+                caches
+            )
+        else:
+            output = self.engine.run(inputs)

        logits = output[0]

        return logits, past_key_values
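Note: __call__ now accepts either a single DeepSparsePastKeyValues or a list of them (one per sequence in a batched decode). The unwrapping could be factored into a helper; a minimal sketch with a hypothetical name:

    from typing import List, Union

    def unwrap_caches(
        past_key_values: Union[DeepSparsePastKeyValues, List[DeepSparsePastKeyValues]]
    ):
        # One internal cache per sequence when batching, a single cache otherwise.
        if isinstance(past_key_values, list):
            return [pkv.internal_past_key_values for pkv in past_key_values]
        return past_key_values.internal_past_key_values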
@@ -189,13 +208,13 @@ class DeepSparseDecoderModel:
        engine_inputs = {}

-        print(batch_size)
-        print(input_ids)
-        print(len(input_ids))
+        #print(batch_size)
+        #print(input_ids)
+        #print(len(input_ids))

        last_input_ids = [x[:,-1:] for x in input_ids]
-        print(f"last_input_ids {last_input_ids}")
+        #print(f"last_input_ids {last_input_ids}")

        engine_inputs["input_ids"] = np.concatenate(last_input_ids, axis=0)
@@ -206,14 +225,14 @@ class DeepSparseDecoderModel:
            engine_inputs["input_ids"],
            engine_inputs["attention_mask"]
        )
-        #engine_inputs["positions"] = np.ndarray([batch_size, input_ids[0].shape[1] - 1], dtype=np.int64)
-        engine_inputs["positions"] = np.array([[input_ids[0].shape[1] - 1]], dtype=np.int64)
+        poses = [pos.shape[1] - 1 for pos in input_ids]
+        #print(f"poses {poses}")
+        engine_inputs["positions"] = np.array(poses, dtype=np.int64)[:,None]

-        print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
-        print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
-        print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
-        print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")
+        #print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
+        #print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
+        #print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
+        #print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")

        return engine_inputs
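Note: the replaced line built a single (1, 1) position from input_ids[0] alone, so every row in the batch got the first sequence's position; the new code computes one position per sequence and reshapes to a column. A runnable numpy sketch with the same lengths as above:

    import numpy as np

    input_ids = [np.zeros((1, n), dtype=np.int64) for n in (3, 5, 2)]
    poses = [x.shape[1] - 1 for x in input_ids]               # [2, 4, 1]
    positions = np.array(poses, dtype=np.int64)[:, None]      # shape (3, 1)
    assert positions.tolist() == [[2], [4], [1]]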
@@ -223,6 +242,7 @@ class DeepSparseDecoderModel:
        batched_past_key_values: List[DeepSparsePastKeyValues]
    ) -> (np.ndarray, List[DeepSparsePastKeyValues]):
+        #print(f"{len(batched_input_ids)} {len(batched_past_key_values)}")
        assert len(batched_input_ids) == len(batched_past_key_values)

        batched_logits = []
@@ -235,17 +255,19 @@ class DeepSparseDecoderModel:
        for input_ids, past_key_values in chunks:
            # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-            print(input_ids)
+            #print(input_ids)
            assert len(input_ids[0].shape) == 2
            assert input_ids[0].shape[1] < self.sequence_length

            if len(input_ids) == self.batch_size and self.batch_size != 1:
                engine_inputs = self.engine_inputs_for_decode(input_ids)
+                print(f"GOT HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! {past_key_values}")
                logits, new_key_values = self.batched_singletoken_engine(
                    engine_inputs,
                    past_key_values
                )
                batched_logits.append(logits)
+                # XXXXX this is bogus
                batched_new_key_values.append(new_key_values)
            else:
                for i in range(len(input_ids)):
@@ -256,7 +278,11 @@ class DeepSparseDecoderModel:
                batched_logits.append(logits)
                batched_new_key_values.append(new_key_values)

-        return np.concatenate(batched_logits, axis=0), batched_new_key_values
+        #print(f"decode {len(batched_input_ids)} {len(batched_new_key_values)}")
+        # XXXXX this is bogus
+        return np.concatenate(batched_logits, axis=0), batched_past_key_values
+        #return np.concatenate(batched_logits, axis=0), np.concatenate(batched_new_key_values)

    def prefill(
        self,
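Note: the XXXXX markers flag a known hole: the batched engine returns one combined new_key_values per chunk rather than one cache per sequence, so decode falls back to returning the caller's batched_past_key_values and relies on the engine updating the caches in place (per the "updates past_key_values internally" comment above). For reference, a hypothetical sketch (not in this commit) of how the (input_ids, past_key_values) pairs iterated as chunks might be grouped:

    def make_chunks(input_ids_list, past_key_values_list, batch_size):
        # pair each sequence with its cache, then group into engine-sized chunks
        assert len(input_ids_list) == len(past_key_values_list)
        return [
            (input_ids_list[i:i + batch_size], past_key_values_list[i:i + batch_size])
            for i in range(0, len(input_ids_list), batch_size)
        ]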
@@ -283,7 +309,7 @@ class DeepSparseDecoderModel:
        # if anything left over, run inference w/ singletoken engine
        while tokens_processed < input_ids.shape[1]:
-            print(f"got here {input_ids[:,:tokens_processed+1]}")
+            #print(f"got here {input_ids[:,:tokens_processed+1]}")
            assert len(input_ids.shape) == 2
            logits, past_key_values = self.decode(
                [input_ids[:,:tokens_processed+1]],
@@ -300,11 +326,13 @@ class DeepSparseDecoderModel:
        past_key_values: List[Optional[DeepSparsePastKeyValues]],
    ):
        assert len(past_key_values) > 0
-        print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
+        #print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
        if past_key_values[0] is None:
            assert len(input_ids) == 1
+            print("PREFILL!!!!!!!!!!!!!!!!!!!!!")
            return self.prefill(input_ids[0])
        else:
+            print("DECODE!!!!!!!!!!!!!!!!!!!!!")
            return self.decode(input_ids, past_key_values)

    def __call__(
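Note: forward dispatches on whether the first cache is None: a fresh sequence (exactly one, per the assert) takes the prefill path, anything with live caches takes decode. A usage sketch with placeholder names (model, prompt_ids), assuming forward returns (logits, caches) the way prefill and decode do:

    # first call: no cache yet, takes the PREFILL path
    logits, caches = model.forward([prompt_ids], [None])
    # later calls: live caches, takes the DECODE path
    logits, caches = model.forward(input_ids_list, caches)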