commit 5e02f40a83 (parent a1c9cc422a)
mirror of https://github.com/huggingface/text-generation-inference.git

    hacks to support native continuous batching
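Context for the commit message: "continuous batching" means requests can join and leave the decode batch between steps, instead of the server waiting for an entire batch to finish before admitting new work. A toy, self-contained sketch of the idea (illustrative names only, nothing below is from this diff):

    # toy sketch of continuous batching: requests join and leave the
    # batch between decode steps (illustrative, not from this commit)
    from dataclasses import dataclass, field

    @dataclass
    class Request:
        prompt: str
        max_new_tokens: int
        generated: list = field(default_factory=list)

    def decode_step(batch):
        # stand-in for one batched forward pass
        for r in batch:
            r.generated.append("tok")

    queue = [Request("a", 2), Request("b", 4), Request("c", 1)]
    batch = []
    while queue or batch:
        if queue:                       # admit new work mid-flight
            batch.append(queue.pop(0))
        decode_step(batch)
        batch = [r for r in batch if len(r.generated) < r.max_new_tokens]

The hacks below work toward this: the DeepSparse engine is compiled for a fixed batch size, so the code chunks and merges per-request state (input ids, KV caches, stopping criteria) around that constraint.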
@@ -57,8 +57,6 @@ class DeepSparseCausalLMBatch:
                 )
             )
 
-            print(r.generation_parameters)
-
             # get next token chooser based on input
             next_token_chooser_list.append(
                 NextTokenChooser(
@@ -123,6 +121,8 @@ class DeepSparseCausalLMBatch:
         self.stopping_criteria_list = stopping_criteria_list
         self.next_token_chooser_list = next_token_chooser_list
 
+        assert len(self.input_ids_list) == len(self.past_key_values_list)
+
         return self
 
     # combine two batches into one
@@ -144,8 +144,8 @@ class DeepSparseCausalLMBatch:
             # concatenate request, input_ids, and past_key_values lists
             requests.extend(batch.requests)
             input_ids_list.extend(batch.input_ids_list)
-            print(f"pkv {past_key_values_list}")
-            print(f"bpkv {batch.past_key_values_list}")
+            #print(f"pkv {past_key_values_list}")
+            #print(f"bpkv {batch.past_key_values_list}")
             past_key_values_list.extend(batch.past_key_values_list)
             stopping_criteria_list.extend(batch.stopping_criteria_list)
             next_token_chooser_list.extend(batch.next_token_chooser_list)
@@ -159,6 +159,8 @@ class DeepSparseCausalLMBatch:
 
             start_index += len(batch)
 
+        assert len(input_ids_list) == len(past_key_values_list)
+
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -187,6 +189,7 @@ class DeepSparseCausalLM:
             onnx_file_path = model_path,
             sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
+            batch_size=4
         )
 
     def generate_token(
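Note that the batch size is hardcoded to 4 here, which pairs with the `len(input_ids) == self.batch_size` check in `decode` further down: the batched engine path only fires when exactly four sequences are in flight. A follow-up (not part of this commit) could expose it the way the other DEEPSPARSE_* knobs presumably are:

    # hypothetical: read the compiled batch size from the environment,
    # mirroring how DEEPSPARSE_SEQUENCE_LENGTH appears to be configured
    # (not part of this commit)
    import os
    DEEPSPARSE_BATCH_SIZE = int(os.environ.get("DEEPSPARSE_BATCH_SIZE", "4"))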
@@ -213,16 +216,22 @@ class DeepSparseCausalLM:
         #assert len(input_ids.shape) == 2
         #assert input_ids.shape[0] == 1
 
-        print(batch.past_key_values_list)
-        print(len(batch.past_key_values_list))
-        print(batch.input_ids_list)
-        print(len(batch.input_ids_list))
+        #print(batch.past_key_values_list)
+        #print(len(batch.past_key_values_list))
+        #print(batch.input_ids_list)
+        #print(len(batch.input_ids_list))
+
+        #print(f"before {len(batch.input_ids_list)} {len(batch.past_key_values_list)}")
 
         # a) run inference
         logits, batch.past_key_values_list = self.model(batch.input_ids_list, batch.past_key_values_list)
 
-        print(logits)
-        print(logits.shape)
+        #print(f"after {len(batch.input_ids_list)} {len(batch.past_key_values_list)} {batch.past_key_values_list}")
+        assert len(batch.input_ids_list) == len(batch.past_key_values_list)
+
+        #print(logits)
+        #print(logits.shape)
 
         for i, (
             request,
(The remaining hunks are against a second file, covering DeepSparseDecoderEngine and DeepSparseDecoderModel; its path was not captured in this view.)

@@ -41,11 +41,20 @@ class DeepSparseDecoderEngine:
             input_ids_length=input_ids_length,
         )
 
+        self.engine_type = DEEPSPARSE_ENGINE
+        #self.engine_type = "onnxruntime"
+
+        if self.engine_type == DEEPSPARSE_ENGINE:
+            engine_args = {"cached_outputs": cached_outputs, "batch_size": batch_size}
+        else:
+            engine_args = {"batch_size": batch_size}
+
         # compile engine
+        print(f"compiling for batch size: {batch_size}")
         self.engine = create_engine(
             onnx_file_path=onnx_file_path,
-            engine_type=DEEPSPARSE_ENGINE,
-            engine_args={"cached_outputs": cached_outputs, "batch_size": batch_size},
+            engine_type=self.engine_type,
+            engine_args=engine_args,
             context=engine_context,
         )
         print(self.engine)
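The `engine_args` split exists because `cached_outputs` (presumably the engine-managed KV-cache outputs) only makes sense for the DeepSparse engine; the onnxruntime path gets plain `{"batch_size": batch_size}`. An equivalent, slightly tighter selection, if preferred:

    # equivalent to the if/else in the hunk above: start from the common
    # kwargs and add the DeepSparse-only one conditionally
    engine_args = {"batch_size": batch_size}
    if self.engine_type == DEEPSPARSE_ENGINE:
        engine_args["cached_outputs"] = cached_outputs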
@@ -59,7 +68,7 @@ class DeepSparseDecoderEngine:
     def __call__(
         self,
         engine_inputs: Dict[str, np.ndarray],
-        past_key_values: DeepSparsePastKeyValues,
+        past_key_values: DeepSparsePastKeyValues,    # XXXX this can be a list
         val_inputs: bool = True
     ):
         # format input into lists (we pass empty past key values)
@@ -72,11 +81,21 @@ class DeepSparseDecoderEngine:
         if val_inputs:
             self.engine._validate_inputs(inputs)
 
+        #print(f"here {past_key_values}")
+
+        if type(past_key_values) is list:
+            caches = [pkv.internal_past_key_values for pkv in past_key_values]
+        else:
+            caches = past_key_values.internal_past_key_values
+
         # run inference, updates past_key_values internally
-        output = self.engine._eng_net.execute_list_out(
-            inputs,
-            past_key_values.internal_past_key_values
-        )
+        if self.engine_type == DEEPSPARSE_ENGINE:
+            output = self.engine._eng_net.execute_list_out(
+                inputs,
+                caches
+            )
+        else:
+            output = self.engine.run(inputs)
         logits = output[0]
         return logits, past_key_values
 
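`__call__` now tolerates `past_key_values` arriving either as a single `DeepSparsePastKeyValues` or as a list of them (one per sequence in a compiled batch), as the `# XXXX this can be a list` comment warns. The normalization could also use `isinstance`, the more idiomatic check, which additionally handles subclasses; same behavior otherwise:

    # same normalization as the hunk above, using isinstance instead of
    # type(...) is list
    if isinstance(past_key_values, list):
        caches = [pkv.internal_past_key_values for pkv in past_key_values]
    else:
        caches = past_key_values.internal_past_key_values

Note the return value is still the `past_key_values` object(s) passed in; per the comment in the hunk, the engine updates the caches internally.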
@@ -189,13 +208,13 @@ class DeepSparseDecoderModel:
 
         engine_inputs = {}
 
-        print(batch_size)
-        print(input_ids)
-        print(len(input_ids))
+        #print(batch_size)
+        #print(input_ids)
+        #print(len(input_ids))
 
         last_input_ids = [x[:,-1:] for x in input_ids]
 
-        print(f"last_input_ids {last_input_ids}")
+        #print(f"last_input_ids {last_input_ids}")
 
         engine_inputs["input_ids"] = np.concatenate(last_input_ids, axis=0)
 
@@ -206,14 +225,14 @@ class DeepSparseDecoderModel:
             engine_inputs["input_ids"],
             engine_inputs["attention_mask"]
         )
-        #engine_inputs["positions"] = np.ndarray([batch_size, input_ids[0].shape[1] - 1], dtype=np.int64)
-        engine_inputs["positions"] = np.array([[input_ids[0].shape[1] - 1]], dtype=np.int64)
-        print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
-        print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
-        print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
-        print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")
+        poses = [pos.shape[1] - 1 for pos in input_ids]
+        #print(f"poses {poses}")
+        engine_inputs["positions"] = np.array(poses, dtype=np.int64)[:,None]
 
+        #print(f"inputs {engine_inputs['input_ids']} {engine_inputs['input_ids'].shape}")
+        #print(f"attn mask {engine_inputs['attention_mask']} {engine_inputs['attention_mask'].shape}")
+        #print(f"causal mask {engine_inputs['causal_mask']} {engine_inputs['causal_mask'].shape}")
+        #print(f"pos {engine_inputs['positions']} {engine_inputs['positions'].shape}")
 
         return engine_inputs
 
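This is the substantive fix in `engine_inputs_for_decode`: the old code built `positions` from `input_ids[0]` only, a `[1, 1]` array that is wrong for any batch bigger than one; the new code takes one position per sequence and stacks them into a `[batch_size, 1]` array. A quick check of the shape logic (sequence lengths made up for illustration):

    import numpy as np

    # two sequences, 5 and 7 tokens long so far
    input_ids = [np.zeros((1, 5), dtype=np.int64),
                 np.zeros((1, 7), dtype=np.int64)]

    poses = [pos.shape[1] - 1 for pos in input_ids]        # [4, 6]
    positions = np.array(poses, dtype=np.int64)[:, None]   # shape (2, 1)
    assert positions.tolist() == [[4], [6]]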
@@ -223,6 +242,7 @@ class DeepSparseDecoderModel:
         batched_past_key_values: List[DeepSparsePastKeyValues]
     ) -> (np.ndarray, List[DeepSparsePastKeyValues]):
 
+        #print(f"{len(batched_input_ids)} {len(batched_past_key_values)}")
         assert len(batched_input_ids) == len(batched_past_key_values)
 
         batched_logits = []
@@ -235,17 +255,19 @@ class DeepSparseDecoderModel:
 
         for input_ids, past_key_values in chunks:
             # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
-            print(input_ids)
+            #print(input_ids)
             assert len(input_ids[0].shape) == 2
             assert input_ids[0].shape[1] < self.sequence_length
 
             if len(input_ids) == self.batch_size and self.batch_size != 1:
                 engine_inputs = self.engine_inputs_for_decode(input_ids)
+                print(f"GOT HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! {past_key_values}")
                 logits, new_key_values = self.batched_singletoken_engine(
                     engine_inputs,
                     past_key_values
                 )
                 batched_logits.append(logits)
+                # XXXXX this is bogus
                 batched_new_key_values.append(new_key_values)
             else:
                 for i in range(len(input_ids)):
@@ -256,7 +278,11 @@ class DeepSparseDecoderModel:
                 batched_logits.append(logits)
                 batched_new_key_values.append(new_key_values)
 
-        return np.concatenate(batched_logits, axis=0), batched_new_key_values
+        #print(f"decode {len(batched_input_ids)} {len(batched_new_key_values)}")
+
+        # XXXXX this is bogus
+        return np.concatenate(batched_logits, axis=0), batched_past_key_values
+        #return np.concatenate(batched_logits, axis=0), np.concatenate(batched_new_key_values)
 
     def prefill(
         self,
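The two `# XXXXX this is bogus` markers flag the same bookkeeping problem: in the batched path, `batched_singletoken_engine` hands back the whole list of caches for the chunk, so `batched_new_key_values.append(new_key_values)` nests a list inside the list and its length stops matching `batched_input_ids`. The commit sidesteps this by returning the incoming `batched_past_key_values` instead, which only works because the engine mutates those cache objects in place. If the per-chunk result really is a list, flattening would restore the invariant; a sketch of that fix (not part of this commit):

    # hypothetical cleanup of the bookkeeping flagged as bogus above:
    # flatten per-chunk cache lists so the output lines up one-to-one
    # with batched_input_ids
    if isinstance(new_key_values, list):
        batched_new_key_values.extend(new_key_values)
    else:
        batched_new_key_values.append(new_key_values)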
@@ -283,7 +309,7 @@ class DeepSparseDecoderModel:
 
         # if anything left over, run inference w/ singletoken engine
         while tokens_processed < input_ids.shape[1]:
-            print(f"got here {input_ids[:,:tokens_processed+1]}")
+            #print(f"got here {input_ids[:,:tokens_processed+1]}")
             assert len(input_ids.shape) == 2
             logits, past_key_values = self.decode(
                 [input_ids[:,:tokens_processed+1]],
@@ -300,11 +326,13 @@ class DeepSparseDecoderModel:
         past_key_values: List[Optional[DeepSparsePastKeyValues]],
     ):
         assert len(past_key_values) > 0
-        print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
+        #print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
         if past_key_values[0] is None:
             assert len(input_ids) == 1
+            print("PREFILL!!!!!!!!!!!!!!!!!!!!!")
             return self.prefill(input_ids[0])
         else:
+            print("DECODE!!!!!!!!!!!!!!!!!!!!!")
             return self.decode(input_ids, past_key_values)
 
     def __call__(