Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

commit a1c9cc422a (parent 1700d11905)

    wip
@@ -57,6 +57,8 @@ class DeepSparseCausalLMBatch:
                 )
             )
 
+            print(r.generation_parameters)
+
             # get next token chooser based on input
             next_token_chooser_list.append(
                 NextTokenChooser(
@@ -142,6 +144,8 @@ class DeepSparseCausalLMBatch:
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
             input_ids_list.extend(batch.input_ids_list)
+            print(f"pkv {past_key_values_list}")
+            print(f"bpkv {batch.past_key_values_list}")
             past_key_values_list.extend(batch.past_key_values_list)
             stopping_criteria_list.extend(batch.stopping_criteria_list)
             next_token_chooser_list.extend(batch.next_token_chooser_list)
@@ -202,24 +206,30 @@ class DeepSparseCausalLM:
         iterator = zip(
             batch.requests,
             batch.input_ids_list,
             batch.past_key_values_list,
             batch.stopping_criteria_list,
             batch.next_token_chooser_list,
         )
 
+        #assert len(input_ids.shape) == 2
+        #assert input_ids.shape[0] == 1
+
+        print(batch.past_key_values_list)
+        print(len(batch.past_key_values_list))
+        print(batch.input_ids_list)
+        print(len(batch.input_ids_list))
+
+        # a) run inference
+        logits, batch.past_key_values_list = self.model(batch.input_ids_list, batch.past_key_values_list)
+
+        print(logits)
+        print(logits.shape)
 
         for i, (
             request,
             input_ids,
             past_key_values,
             stopping_criteria,
             next_token_chooser
         ) in enumerate(iterator):
-            # assert input_ids is b=1
-            assert len(input_ids.shape) == 2
-            assert input_ids.shape[0] == 1
-
-            # a) run inference
-            logits, past_key_values = self.model(input_ids, past_key_values)
-
             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
             generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
@@ -241,13 +251,12 @@ class DeepSparseCausalLM:
             # d) update batch
             # TODO: this does not occur in place
-            assert len(batch.input_ids_list[i].shape) == 2
-            assert batch.input_ids_list[i].shape[0] == 1
+            #assert batch.input_ids_list[i].shape[0] == 1
             batch.input_ids_list[i] = np.append(
                 batch.input_ids_list[i],
                 np.array([[generated_token_id]]),
                 axis=1
             )
             batch.past_key_values_list[i] = past_key_values
 
         # if all elements of the batch are done, return null for batch
         if all_stopped:
@@ -7,10 +7,16 @@ from typing import Optional, List, Dict
 from deepsparse import Context
 from deepsparse.engine import LIB
 from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
-from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs_for_kv_cache_models, create_causal_mask
+from deepsparse.utils.onnx import overwrite_onnx_model_inputs_for_kv_cache_models
+from deepsparse.transformers.utils.helpers import create_causal_mask
 
 PAST_KEY_VALUES_NAME = "past_key_values"
 
+def chunkify(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i:i + n]
+
 class DeepSparsePastKeyValues:
     def __init__(self):
         prev_num_tokens = 0
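
Note: `chunkify` is the plain generator used further down to split the request batch into engine-sized groups. A minimal illustration of its behavior (the values here are made up):

    list(chunkify([1, 2, 3, 4, 5], 2))   # -> [[1, 2], [3, 4], [5]]

The last chunk can be shorter than n, which is why the decode path below falls back to the batch_size=1 engine for leftover requests.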
@@ -23,13 +29,14 @@ class DeepSparseDecoderEngine:
         onnx_file_path: str,
         sequence_length: int = 1024,
         input_ids_length: int = 1,
+        batch_size: int = 1,
         engine_context: Optional[Context] = None,
     ):
 
-        # setup ONNX graph
+        # setup ONNX graph(s)
         onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs_for_kv_cache_models(
             onnx_file_path=onnx_file_path,
-            batch_size=1,
+            batch_size=batch_size,
             sequence_length=sequence_length,
             input_ids_length=input_ids_length,
         )
@@ -38,7 +45,7 @@ class DeepSparseDecoderEngine:
         self.engine = create_engine(
             onnx_file_path=onnx_file_path,
             engine_type=DEEPSPARSE_ENGINE,
-            engine_args={"cached_outputs": cached_outputs},
+            engine_args={"cached_outputs": cached_outputs, "batch_size": batch_size},
             context=engine_context,
         )
         print(self.engine)
@@ -91,19 +98,33 @@ class DeepSparseDecoderModel:
         onnx_file_path: str,
         sequence_length: int = 1024,
         multitoken_length: int = 16,
+        batch_size: int = 1, # 16
         engine_context: Optional[Context] = None,
     ):
         self.sequence_length = sequence_length
         self.multitoken_length = multitoken_length
+        self.batch_size = batch_size
 
-        # compile decode engine
+        # compile decode engines
         self.singletoken_engine = DeepSparseDecoderEngine(
             onnx_file_path=onnx_file_path,
             engine_context=engine_context,
             sequence_length=sequence_length,
             input_ids_length=1,
+            batch_size=1
         )
 
+        if batch_size > 1:
+            self.batched_singletoken_engine = DeepSparseDecoderEngine(
+                onnx_file_path=onnx_file_path,
+                engine_context=engine_context,
+                sequence_length=sequence_length,
+                input_ids_length=1,
+                batch_size=batch_size
+            )
+        else:
+            self.batched_singletoken_engine = None
+
         # compile prefill engine
         self.multitoken_engine = DeepSparseDecoderEngine(
             onnx_file_path=onnx_file_path,
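
For context, a minimal sketch of how this constructor might be invoked (the ONNX path and sizes below are placeholders, not values from this commit):

    # hypothetical usage; path and batch size are illustrative
    model = DeepSparseDecoderModel(
        onnx_file_path="/path/to/model.onnx",
        sequence_length=1024,
        multitoken_length=16,
        batch_size=4,   # also compiles a single-token engine with batch_size=4
    )

With batch_size > 1 the model holds three compiled engines: a single-token engine at batch size 1, a single-token engine at batch_size, and the multitoken prefill engine.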
@@ -155,39 +176,87 @@ class DeepSparseDecoderModel:
 
     def engine_inputs_for_decode(
         self,
-        input_ids: np.ndarray,
+        input_ids: List[np.ndarray],
     ):
+        # TODO: assert input_ids all have same shape
+        assert type(input_ids) is list
+        assert type(input_ids[0]) is np.ndarray
+        assert len(input_ids) > 0
+        assert len(input_ids[0].shape) == 2
+        assert input_ids[0].shape[1] < self.sequence_length
+
+        batch_size = len(input_ids)
 
         engine_inputs = {}
-        engine_inputs["input_ids"] = input_ids[:,-1:]
-        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
-        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
+
+        print(batch_size)
+        print(input_ids)
+        print(len(input_ids))
+
+        last_input_ids = [x[:,-1:] for x in input_ids]
+
+        print(f"last_input_ids {last_input_ids}")
+
+        engine_inputs["input_ids"] = np.concatenate(last_input_ids, axis=0)
+
+        engine_inputs["attention_mask"] = np.zeros((batch_size, self.sequence_length), dtype=np.int64)
+        engine_inputs["attention_mask"][:, -input_ids[0].shape[1]:] = 1
 
         engine_inputs["causal_mask"] = create_causal_mask(
             engine_inputs["input_ids"],
             engine_inputs["attention_mask"]
         )
-        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
+        #engine_inputs["positions"] = np.ndarray([batch_size, input_ids[0].shape[1] - 1], dtype=np.int64)
+        engine_inputs["positions"] = np.array([[input_ids[0].shape[1] - 1]], dtype=np.int64)
+
+        print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
+        print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
+        print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
+        print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")
+
 
         return engine_inputs
 
     def decode(
         self,
-        input_ids: np.ndarray,
-        past_key_values: DeepSparsePastKeyValues
-    ) -> (np.ndarray, DeepSparsePastKeyValues):
+        batched_input_ids: List[np.ndarray],
+        batched_past_key_values: List[DeepSparsePastKeyValues]
+    ) -> (np.ndarray, List[DeepSparsePastKeyValues]):
 
         # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-        assert len(input_ids.shape) == 2
-        assert input_ids.shape[0] == 1
-        assert input_ids.shape[1] < self.sequence_length
+        assert len(batched_input_ids) == len(batched_past_key_values)
 
-        engine_inputs = self.engine_inputs_for_decode(input_ids)
-        logits, past_key_values = self.singletoken_engine(
-            engine_inputs,
-            past_key_values
+        batched_logits = []
+        batched_new_key_values = []
+
+        chunks = zip(
+            chunkify(batched_input_ids, self.batch_size),
+            chunkify(batched_past_key_values, self.batch_size)
         )
 
-        return logits, past_key_values
+        for input_ids, past_key_values in chunks:
+            # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
+            print(input_ids)
+            assert len(input_ids[0].shape) == 2
+            assert input_ids[0].shape[1] < self.sequence_length
+
+            if len(input_ids) == self.batch_size and self.batch_size != 1:
+                engine_inputs = self.engine_inputs_for_decode(input_ids)
+                logits, new_key_values = self.batched_singletoken_engine(
+                    engine_inputs,
+                    past_key_values
+                )
+                batched_logits.append(logits)
+                batched_new_key_values.append(new_key_values)
+            else:
+                for i in range(len(input_ids)):
+                    engine_inputs = self.engine_inputs_for_decode([input_ids[i]])
+                    logits, new_key_values = self.singletoken_engine(
+                        engine_inputs,
+                        past_key_values[i])
+                    batched_logits.append(logits)
+                    batched_new_key_values.append(new_key_values)
+
+        return np.concatenate(batched_logits, axis=0), batched_new_key_values
 
     def prefill(
         self,
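
The chunked dispatch above follows a simple pattern: full chunks of exactly batch_size requests go through the batched engine in one call, and anything left over is processed one request at a time. A standalone sketch of that pattern, with the engines stubbed out (the names here are illustrative, not the real engine API):

    def dispatch(batched_inputs, batch_size, run_batched, run_single):
        # run_batched / run_single are stand-ins for the compiled engines
        outputs = []
        for chunk in chunkify(batched_inputs, batch_size):
            if len(chunk) == batch_size and batch_size != 1:
                outputs.append(run_batched(chunk))                # one engine call per full chunk
            else:
                outputs.extend(run_single(x) for x in chunk)      # per-request fallback
        return outputs

Note that the KV caches travel with their requests: each chunk's past_key_values are passed alongside its input_ids, mirroring the zip(chunkify(...), chunkify(...)) above.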
@@ -202,21 +271,23 @@ class DeepSparseDecoderModel:
         tokens_processed = 0
 
         # setup empty past key values
-        past_key_values = DeepSparsePastKeyValues()
+        past_key_values = [DeepSparsePastKeyValues()]
 
         # loop through chunks, run inference w/ multitoken engine
         for engine_inputs in self.engine_inputs_for_prefill(input_ids):
-            logits, past_key_values = self.multitoken_engine(
+            logits, past_key_values[0] = self.multitoken_engine(
                 engine_inputs,
-                past_key_values
+                past_key_values[0]
             )
             tokens_processed += self.multitoken_length
 
         # if anything left over, run inference w/ singletoken engine
         while tokens_processed < input_ids.shape[1]:
+            print(f"got here {input_ids[:,:tokens_processed+1]}")
+            assert len(input_ids.shape) == 2
             logits, past_key_values = self.decode(
-                input_ids=input_ids[:,:tokens_processed+1],
-                past_key_values=past_key_values
+                [input_ids[:,:tokens_processed+1]],
+                past_key_values
             )
             tokens_processed += 1
             # print(logits[:,-1:,:])
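
A quick worked example of the split above, assuming multitoken_length=16 and that engine_inputs_for_prefill only yields full 16-token chunks (its body is not part of this diff): for a 20-token prompt, the multitoken engine consumes one 16-token chunk, then the while loop runs the single-token decode path 4 more times, advancing tokens_processed from 16 to 20.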
@@ -225,17 +296,20 @@ class DeepSparseDecoderModel:
 
     def forward(
         self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+        input_ids: List[np.ndarray],
+        past_key_values: List[Optional[DeepSparsePastKeyValues]],
     ):
-        if past_key_values is None:
-            return self.prefill(input_ids)
+        assert len(past_key_values) > 0
+        print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
+        if past_key_values[0] is None:
+            assert len(input_ids) == 1
+            return self.prefill(input_ids[0])
         else:
             return self.decode(input_ids, past_key_values)
 
     def __call__(
         self,
-        input_ids: np.ndarray,
-        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+        input_ids: List[np.ndarray],
+        past_key_values: List[Optional[DeepSparsePastKeyValues]] = [],
     ):
         return self.forward(input_ids, past_key_values)
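
A hedged sketch of the calling convention implied by the new signatures (variable names are illustrative, not from this commit): the first call passes a single prompt with a [None] placeholder so forward() routes to prefill; later calls pass the returned key/value list so it routes to the chunked decode path.

    # illustrative driver loop, assuming model is a DeepSparseDecoderModel
    logits, past_key_values = model([prompt_ids], [None])              # past_key_values[0] is None -> prefill
    logits, past_key_values = model(input_ids_list, past_key_values)   # populated caches -> decode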
@@ -11,10 +11,10 @@ class Greedy:
     def __call__(self, logits: np.ndarray):
         # assert b=1 for now
         # shape == (batch, vocabulary_size)
-        assert(logits.shape[0] == 1)
+        #assert(logits.shape[0] == 1)
         assert(len(logits.shape) == 2)
 
-        return np.argmax(logits[0,:])
+        return np.argmax(logits[0,:]) # XXXXXXXXXXXXXXX fix
 
 # TODO: sample for b > 1 with vectorized code
 # https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
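
Regarding the TODO above, a minimal sketch of the greedy b > 1 case (this is just a per-row argmax, not the vectorized random sampling the linked question covers):

    # assuming logits has shape (batch, vocabulary_size)
    token_ids = np.argmax(logits, axis=-1)   # shape (batch,)

The linked StackOverflow thread is about vectorizing np.random.choice over a 2-D probability array, which is what a batched Sample implementation would need.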