Bill Nell 2023-10-06 22:00:30 -04:00
parent 1700d11905
commit a1c9cc422a
5 changed files with 141 additions and 58 deletions

View File

@@ -57,6 +57,8 @@ class DeepSparseCausalLMBatch:
                 )
             )
+            print(r.generation_parameters)
             # get next token chooser based on input
             next_token_chooser_list.append(
                 NextTokenChooser(
@@ -142,6 +144,8 @@ class DeepSparseCausalLMBatch:
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
             input_ids_list.extend(batch.input_ids_list)
+            print(f"pkv {past_key_values_list}")
+            print(f"bpkv {batch.past_key_values_list}")
             past_key_values_list.extend(batch.past_key_values_list)
             stopping_criteria_list.extend(batch.stopping_criteria_list)
             next_token_chooser_list.extend(batch.next_token_chooser_list)
@@ -202,24 +206,30 @@ class DeepSparseCausalLM:
         iterator = zip(
             batch.requests,
             batch.input_ids_list,
-            batch.past_key_values_list,
             batch.stopping_criteria_list,
             batch.next_token_chooser_list,
         )
+        #assert len(input_ids.shape) == 2
+        #assert input_ids.shape[0] == 1
+        print(batch.past_key_values_list)
+        print(len(batch.past_key_values_list))
+        print(batch.input_ids_list)
+        print(len(batch.input_ids_list))
+        # a) run inference
+        logits, batch.past_key_values_list = self.model(batch.input_ids_list, batch.past_key_values_list)
+        print(logits)
+        print(logits.shape)
         for i, (
             request,
             input_ids,
-            past_key_values,
             stopping_criteria,
             next_token_chooser
         ) in enumerate(iterator):
-            # assert input_ids is b=1
-            assert len(input_ids.shape) == 2
-            assert input_ids.shape[0] == 1
-            # a) run inference
-            logits, past_key_values = self.model(input_ids, past_key_values)
             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
             generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
@@ -241,13 +251,12 @@ class DeepSparseCausalLM:
             # d) update batch
             # TODO: this does not occur in place
             assert len(batch.input_ids_list[i].shape) == 2
-            assert batch.input_ids_list[i].shape[0] == 1
+            #assert batch.input_ids_list[i].shape[0] == 1
             batch.input_ids_list[i] = np.append(
                 batch.input_ids_list[i],
                 np.array([[generated_token_id]]),
                 axis=1
             )
-            batch.past_key_values_list[i] = past_key_values
         # if all elements of the batch are done, return null for batch
         if all_stopped:
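Note that with the inference call hoisted out of the per-request loop (previous hunk), logits now carry a leading batch dimension, yet each request is still handed the full logits[:,-1,:] slice and Greedy (last file) still reads row 0. A hedged sketch of the per-request indexing the batched path would presumably need, using the loop's own names (assumption, not part of this commit):

    scores_i = logits[i:i+1, -1, :]   # row of the batched logits for request i (assumption)
    generated_token_id = next_token_chooser(input_ids=input_ids, scores=scores_i)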

View File

@@ -7,10 +7,16 @@ from typing import Optional, List, Dict
 from deepsparse import Context
 from deepsparse.engine import LIB
 from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
-from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs_for_kv_cache_models, create_causal_mask
+from deepsparse.utils.onnx import overwrite_onnx_model_inputs_for_kv_cache_models
+from deepsparse.transformers.utils.helpers import create_causal_mask
 
 PAST_KEY_VALUES_NAME = "past_key_values"
 
+def chunkify(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i:i + n]
+
 class DeepSparsePastKeyValues:
     def __init__(self):
         prev_num_tokens = 0
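The new chunkify helper is what decode() later uses to split an arbitrarily sized batch into engine-sized groups. A minimal standalone sketch of its behavior:

    items = ["r0", "r1", "r2", "r3", "r4"]
    print(list(chunkify(items, 2)))   # [['r0', 'r1'], ['r2', 'r3'], ['r4']]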
@@ -23,13 +29,14 @@ class DeepSparseDecoderEngine:
         onnx_file_path: str,
         sequence_length: int = 1024,
         input_ids_length: int = 1,
+        batch_size: int = 1,
         engine_context: Optional[Context] = None,
     ):
 
-        # setup ONNX graph
+        # setup ONNX graph(s)
         onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs_for_kv_cache_models(
             onnx_file_path=onnx_file_path,
-            batch_size=1,
+            batch_size=batch_size,
             sequence_length=sequence_length,
             input_ids_length=input_ids_length,
         )
@@ -38,7 +45,7 @@ class DeepSparseDecoderEngine:
         self.engine = create_engine(
             onnx_file_path=onnx_file_path,
             engine_type=DEEPSPARSE_ENGINE,
-            engine_args={"cached_outputs": cached_outputs},
+            engine_args={"cached_outputs": cached_outputs, "batch_size": batch_size},
             context=engine_context,
         )
         print(self.engine)
@@ -91,19 +98,33 @@ class DeepSparseDecoderModel:
         onnx_file_path: str,
         sequence_length: int = 1024,
         multitoken_length: int = 16,
+        batch_size: int = 1, # 16
         engine_context: Optional[Context] = None,
     ):
         self.sequence_length = sequence_length
         self.multitoken_length = multitoken_length
+        self.batch_size = batch_size
 
-        # compile decode engine
+        # compile decode engines
         self.singletoken_engine = DeepSparseDecoderEngine(
             onnx_file_path=onnx_file_path,
             engine_context=engine_context,
             sequence_length=sequence_length,
             input_ids_length=1,
+            batch_size=1
         )
 
+        if batch_size > 1:
+            self.batched_singletoken_engine = DeepSparseDecoderEngine(
+                onnx_file_path=onnx_file_path,
+                engine_context=engine_context,
+                sequence_length=sequence_length,
+                input_ids_length=1,
+                batch_size=batch_size
+            )
+        else:
+            self.batched_singletoken_engine = None
+
         # compile prefill engine
         self.multitoken_engine = DeepSparseDecoderEngine(
             onnx_file_path=onnx_file_path,
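Net effect of this hunk: the model now compiles a batch_size=1 single-token engine, an optional batch_size=N single-token engine, and the multitoken prefill engine. A minimal sketch of the dispatch rule the new decode() applies per chunk (the name chunk is illustrative; the condition matches the decode hunk below):

    use_batched = len(chunk) == self.batch_size and self.batch_size != 1
    engine = self.batched_singletoken_engine if use_batched else self.singletoken_engine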
@@ -155,39 +176,87 @@ class DeepSparseDecoderModel:
     def engine_inputs_for_decode(
         self,
-        input_ids: np.ndarray,
+        input_ids: List[np.ndarray],
     ):
+        # TODO: assert input_ids all have same shape
+        assert type(input_ids) is list
+        assert type(input_ids[0]) is np.ndarray
+        assert len(input_ids) > 0
+        assert len(input_ids[0].shape) == 2
+        assert input_ids[0].shape[1] < self.sequence_length
+
+        batch_size = len(input_ids)
+
         engine_inputs = {}
-        engine_inputs["input_ids"] = input_ids[:,-1:]
-        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
-        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
+
+        print(batch_size)
+        print(input_ids)
+        print(len(input_ids))
+
+        last_input_ids = [x[:,-1:] for x in input_ids]
+        print(f"last_input_ids {last_input_ids}")
+        engine_inputs["input_ids"] = np.concatenate(last_input_ids, axis=0)
+        engine_inputs["attention_mask"] = np.zeros((batch_size, self.sequence_length), dtype=np.int64)
+        engine_inputs["attention_mask"][:, -input_ids[0].shape[1]:] = 1
         engine_inputs["causal_mask"] = create_causal_mask(
             engine_inputs["input_ids"],
             engine_inputs["attention_mask"]
         )
-        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
+        #engine_inputs["positions"] = np.ndarray([batch_size, input_ids[0].shape[1] - 1], dtype=np.int64)
+        engine_inputs["positions"] = np.array([[input_ids[0].shape[1] - 1]], dtype=np.int64)
+
+        print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
+        print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
+        print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
+        print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")
 
         return engine_inputs
 
     def decode(
         self,
-        input_ids: np.ndarray,
-        past_key_values: DeepSparsePastKeyValues
-    ) -> (np.ndarray, DeepSparsePastKeyValues):
-        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-        assert len(input_ids.shape) == 2
-        assert input_ids.shape[0] == 1
-        assert input_ids.shape[1] < self.sequence_length
+        batched_input_ids: List[np.ndarray],
+        batched_past_key_values: List[DeepSparsePastKeyValues]
+    ) -> (np.ndarray, List[DeepSparsePastKeyValues]):
+        assert len(batched_input_ids) == len(batched_past_key_values)
 
-        engine_inputs = self.engine_inputs_for_decode(input_ids)
-        logits, past_key_values = self.singletoken_engine(
-            engine_inputs,
-            past_key_values
+        batched_logits = []
+        batched_new_key_values = []
+
+        chunks = zip(
+            chunkify(batched_input_ids, self.batch_size),
+            chunkify(batched_past_key_values, self.batch_size)
         )
 
-        return logits, past_key_values
+        for input_ids, past_key_values in chunks:
+            # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
+            print(input_ids)
+            assert len(input_ids[0].shape) == 2
+            assert input_ids[0].shape[1] < self.sequence_length
+
+            if len(input_ids) == self.batch_size and self.batch_size != 1:
+                engine_inputs = self.engine_inputs_for_decode(input_ids)
+                logits, new_key_values = self.batched_singletoken_engine(
+                    engine_inputs,
+                    past_key_values
+                )
+                batched_logits.append(logits)
+                batched_new_key_values.append(new_key_values)
+            else:
+                for i in range(len(input_ids)):
+                    engine_inputs = self.engine_inputs_for_decode([input_ids[i]])
+                    logits, new_key_values = self.singletoken_engine(
+                        engine_inputs,
+                        past_key_values[i])
+                    batched_logits.append(logits)
+                    batched_new_key_values.append(new_key_values)
+
+        return np.concatenate(batched_logits, axis=0), batched_new_key_values
 
     def prefill(
         self,
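For reference, the shapes engine_inputs_for_decode now builds for a chunk of B requests: input_ids is the concatenation of each request's last token, shape (B, 1); attention_mask is (B, sequence_length) with the trailing positions set; and positions is still derived from input_ids[0] only (the commented-out line hints at a per-request version), so all requests in a chunk are implicitly assumed to share the same length. A small standalone sketch with toy sizes (sequence_length = 8 is illustrative):

    import numpy as np
    input_ids = [np.arange(5).reshape(1, 5), np.arange(5).reshape(1, 5)]   # 2 requests, 5 tokens each
    last = np.concatenate([x[:, -1:] for x in input_ids], axis=0)          # (2, 1) last-token ids
    attention_mask = np.zeros((2, 8), dtype=np.int64)
    attention_mask[:, -input_ids[0].shape[1]:] = 1                         # last 5 positions active
    positions = np.array([[input_ids[0].shape[1] - 1]], dtype=np.int64)    # [[4]], shape (1, 1)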
@@ -202,21 +271,23 @@ class DeepSparseDecoderModel:
         tokens_processed = 0
 
         # setup empty past key values
-        past_key_values = DeepSparsePastKeyValues()
+        past_key_values = [DeepSparsePastKeyValues()]
 
         # loop through chunks, run inference w/ multitoken engine
         for engine_inputs in self.engine_inputs_for_prefill(input_ids):
-            logits, past_key_values = self.multitoken_engine(
+            logits, past_key_values[0] = self.multitoken_engine(
                 engine_inputs,
-                past_key_values
+                past_key_values[0]
             )
             tokens_processed += self.multitoken_length
 
         # if anything left over, run inference w/ singletoken engine
         while tokens_processed < input_ids.shape[1]:
+            print(f"got here {input_ids[:,:tokens_processed+1]}")
+            assert len(input_ids.shape) == 2
             logits, past_key_values = self.decode(
-                input_ids=input_ids[:,:tokens_processed+1],
-                past_key_values=past_key_values
+                [input_ids[:,:tokens_processed+1]],
+                past_key_values
             )
             tokens_processed += 1
         # print(logits[:,-1:,:])
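Prefill still runs the multitoken engine over full multitoken_length chunks and then finishes the remainder through decode(), which now expects single-element lists. A sketch of the token accounting, assuming engine_inputs_for_prefill yields only full chunks, with an illustrative 20-token prompt and multitoken_length = 16:

    prompt_len, multitoken_length = 20, 16
    tokens_processed = (prompt_len // multitoken_length) * multitoken_length   # 16 after the multitoken pass
    singletoken_steps = prompt_len - tokens_processed                          # 4 decode() calls remain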
@@ -225,17 +296,20 @@ class DeepSparseDecoderModel:
     def forward(
         self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+        input_ids: List[np.ndarray],
+        past_key_values: List[Optional[DeepSparsePastKeyValues]],
     ):
-        if past_key_values is None:
-            return self.prefill(input_ids)
+        assert len(past_key_values) > 0
+        print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
+        if past_key_values[0] is None:
+            assert len(input_ids) == 1
+            return self.prefill(input_ids[0])
         else:
             return self.decode(input_ids, past_key_values)
 
     def __call__(
         self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+        input_ids: List[np.ndarray],
+        past_key_values: List[Optional[DeepSparsePastKeyValues]] = [],
     ):
         return self.forward(input_ids, past_key_values)
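forward/__call__ now take a list of per-request input_ids and a parallel list of key-value caches: a leading None cache selects prefill (which requires exactly one request), anything else goes through the batched decode path. A minimal usage sketch, where model is a constructed DeepSparseDecoderModel and prompt_ids is a (1, prompt_len) int64 array (names are illustrative):

    logits, pkvs = model([prompt_ids], [None])        # prefill: one request, no cache yet
    logits, pkvs = model(input_ids_list, pkvs_list)   # later: batched decode over all active requests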

View File

@@ -11,10 +11,10 @@ class Greedy:
     def __call__(self, logits: np.ndarray):
         # assert b=1 for now
         # shape == (batch, vocabulary_size)
-        assert(logits.shape[0] == 1)
+        #assert(logits.shape[0] == 1)
         assert(len(logits.shape) == 2)
 
-        return np.argmax(logits[0,:])
+        return np.argmax(logits[0,:]) # XXXXXXXXXXXXXXX fix
 
     # TODO: sample for b > 1 with vectorized code
     # https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
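The Greedy sampler still reads only row 0 (flagged by the "# XXXXXXXXXXXXXXX fix" marker), and the TODO above points at vectorizing over the batch. A hedged sketch of what a batch-aware greedy step could look like (not part of this commit):

    import numpy as np

    def greedy_batch(logits: np.ndarray) -> np.ndarray:
        # one argmax per row of a (batch, vocabulary_size) array
        assert len(logits.shape) == 2
        return np.argmax(logits, axis=-1)   # shape: (batch,)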