baseline implementation complete, working on sample server for example

This commit is contained in:
rsnm2 2023-08-21 19:27:07 +00:00
parent f8565cd915
commit fec0b1dce5
2 changed files with 255 additions and 112 deletions

View File

@@ -1,61 +1,37 @@
import numpy, torch
from dataclasses import dataclass
from typing import Optional, Tuple, List, Type, Dict
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from dataclasses import dataclass
from typing import List, Dict, Optional, Type
from text_generation_server.models.deepsparse_model import (
DeepSparsePastKeyValues,
DeepSparseDecoderModel
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@dataclass
class DeepSparseCausalLMBatch(Batch):
class DeepSparseCausalLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# TODO: update to handle calculating max_tokens --- needed for CachedBatch
# Decoder values
input_ids_list: List[numpy.ndarray]
past_key_values_list: Optional[List[DeepSparsePastKeyValues]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
requests_idx_mapping: Dict[int,int]
input_ids_list: List[np.ndarray]
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "DeepSparseCausalLMBatch":
# parse batch
input_ids_list = []
next_token_choosers = []
stopping_criterias = []
requests_idx_mapping = {}
input_ids_list = []
# setup tokenizer for deepsparse left padding
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
@@ -63,11 +39,10 @@ class DeepSparseCausalLMBatch(Batch):
padding, truncation = "longest", False
# loop through items in the batch
for i, r in enumerate(pb.requests):
# get mapping
requests_idx_mapping[r.id] = i
for idx, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = idx
# setup inputs
# setup input_ids, past_key_values
tokenized_inputs = tokenizer(
r.inputs,
return_tensors="np",
@@ -77,51 +52,189 @@ class DeepSparseCausalLMBatch(Batch):
max_length=DEEPSPARSE_SEQUENCE_LENGTH
)
input_ids_list.append(tokenized_inputs["input_ids"])
# setup sequence generation helpers, capping at DEEPSPARSE_SEQUENCE_LENGTH
# cap stopping parameters to DeepSparse sequence length
input_len = tokenized_inputs["input_ids"].shape[1]
assert DEEPSPARSE_SEQUENCE_LENGTH - input_len > 0
r.stopping_parameters.max_new_tokens = min(
r.stopping_parameters.max_new_tokens,
DEEPSPARSE_SEQUENCE_LENGTH - input_len
)
stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
input_ids_list=input_ids_list,
past_key_values_list=None
past_key_values_list=[None] * len(pb.requests),
)
# length of the batch
def __len__(self):
return len(self.requests)
# pass list of request ids, returns batch with only those request ids
def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
pass
assert(len(request_ids) > 0)
requests_idx_mapping = {}
requests = []
input_ids_list = []
past_key_values_list = []
# loop through requests, keep ones that should remain
for new_idx, request_id in enumerate(request_ids):
assert request_id in self.requests_idx_mapping, "all request ids must be in the batch"
requests_idx_mapping[request_id] = new_idx
old_idx = self.requests_idx_mapping[request_id]
requests.append(self.requests[old_idx])
input_ids_list.append(self.input_ids_list[old_idx])
past_key_values_list.append(self.past_key_values_list[old_idx])
# update batch state
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids_list = input_ids_list
self.past_key_values_list = past_key_values_list
return self
# combine two batches into one
@classmethod
def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
pass
assert len(batches) > 1, "must have more than 1 batch to concatenate"
requests_idx_mapping = {}
requests = []
input_ids_list = []
past_key_values_list = []
start_index = 0
for i, batch in enumerate(batches):
assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
# concatenate request, input_ids, and past_key_values lists
requests.extend(batch.requests)
input_ids_list.extend(batch.input_ids_list)
past_key_values_list.extend(batch.past_key_values_list)
# merge the request_id to index mapping
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
start_index += len(batch)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=past_key_values_list
)
class DeepSparseCausalLM:
def __init__(
self,
deployment_path: str
):
self.tokenizer = AutoTokenizer.from_pretrained(deployment_path)
model_path: str,
tokenizer_path: str,
):
# setup tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.tokenizer.padding_side = "left"
if not self.tokenizer.pad_token:
assert self.tokenizer.eos_token
self.tokenizer.pad_token = self.tokenizer.eos_token
# setup model
self.model = DeepSparseDecoderModel(
onnx_file_path = model_path,
sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
)
@property
def batch_type(self) -> Type[DeepSparseCausalLMBatch]:
return DeepSparseCausalLMBatch
# TODO (@rsnm2): switch to NextTokenChooser
def sample_token(
self,
logits: np.ndarray
):
assert(logits.shape[0] == 1) # assert batch size is 1 for now
return np.argmax(logits[0,-1,:]) # greedy: argmax over the logits of the last position
# TODO (@rsnm2): switch to StoppingCriteria
def should_stop(
self,
num_tokens_processed: int,
generated_token_id: int
):
if num_tokens_processed >= self.model.sequence_length:
return True
if generated_token_id == self.tokenizer.eos_token_id:
return True
return False
def generate_token(
self,
batch: DeepSparseCausalLMBatch,
) -> (Dict[int,str], Optional[DeepSparseCausalLMBatch]):
generations: Dict[int, str] = {}
all_stopped = True
# if we supported continuous batching, we would do batched inference here
# logits, past_key_values = self.model(batch)
# for each member of the batch:
# a) run inference
# b) sample and check stopping criteria
# c) create generation + update batch
for i, (
request,
input_ids,
past_key_values,
) in enumerate(zip(
batch.requests,
batch.input_ids_list,
batch.past_key_values_list
)):
# run inference
logits, past_key_values = self.model(input_ids, past_key_values)
# sample token
# simple for now --- should use NextTokenChooser
generated_token_id = self.sample_token(logits)
# check stopping criteria
# simple for now --- should use StoppingCriteria
stop = self.should_stop(
num_tokens_processed=input_ids.shape[1] + 1,
generated_token_id=generated_token_id
)
# if not stopped, convert token id to text
generated_text = None
if not stop:
all_stopped = False
generated_text = self.tokenizer.decode(
generated_token_id,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
generations[request.id] = generated_text
# update values in the batch
assert len(batch.input_ids_list[i].shape) == 2
assert batch.input_ids_list[i].shape[0] == 1
# note: np.append does not modify the array in place, so reassign the result
batch.input_ids_list[i] = np.append(
batch.input_ids_list[i],
np.array([[generated_token_id]]),
axis=1
)
batch.past_key_values_list[i] = past_key_values
# if all elements of the batch are done, return generation + null for batch
if all_stopped:
return generations, None
# return generation + updated batch
return generations, batch
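
For context, a minimal sketch of how a serving loop might drive these classes. The module path, deployment paths, and loop structure below are illustrative assumptions, not part of this commit; a real router would also call filter() between steps to drop finished requests.

import torch
# hypothetical module path -- the file names are not shown on this page
from text_generation_server.models.deepsparse_causal_lm import (
    DeepSparseCausalLM,
    DeepSparseCausalLMBatch,
)

def run_batch(pb_batch, model_path: str, tokenizer_path: str):
    model = DeepSparseCausalLM(model_path=model_path, tokenizer_path=tokenizer_path)
    batch = DeepSparseCausalLMBatch.from_pb(
        pb_batch,
        tokenizer=model.tokenizer,
        dtype=torch.float32,             # unused by DeepSparse, kept for interface parity
        device=torch.device("cpu"),
    )
    outputs = {r.id: "" for r in batch.requests}
    # generate_token returns (generations, batch); batch is None once every request has stopped
    while batch is not None:
        generations, batch = model.generate_token(batch)
        for request_id, text in generations.items():
            if text is not None:         # text is None for requests that just stopped
                outputs[request_id] += text
    return outputs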

View File

@@ -1,3 +1,7 @@
import os
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
import numpy as np
from typing import Optional, List, Dict
@@ -53,8 +57,10 @@ class DeepSparseDecoderEngine:
val_inputs: bool = True
):
# format input into lists (we pass empty past key values)
inputs = [self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
else engine_inputs[name] for name in self.engine.input_names]
inputs = [
self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
else engine_inputs[name] for name in self.engine.input_names
]
# validate inputs formatted correctly
if val_inputs:
@@ -87,29 +93,25 @@ class DeepSparseDecoderModel:
sequence_length: int = 1024,
multitoken_length: int = 16,
engine_context: Optional[Context] = None,
singletoken_engine = None,
multitoken_engine = None,
):
self.sequence_length = sequence_length
self.multitoken_length = multitoken_length
if singletoken_engine is not None and multitoken_engine is not None:
self.singletoken_engine = singletoken_engine
self.multitoken_engine = multitoken_engine
else:
self.singletoken_engine = DeepSparseDecoderEngine(
onnx_file_path=onnx_file_path,
engine_context=engine_context,
sequence_length=sequence_length,
input_ids_length=1,
)
self.multitoken_engine = DeepSparseDecoderEngine(
onnx_file_path=onnx_file_path,
engine_context=engine_context,
sequence_length=sequence_length,
input_ids_length=self.multitoken_length,
)
# compile decode engine
self.singletoken_engine = DeepSparseDecoderEngine(
onnx_file_path=onnx_file_path,
engine_context=engine_context,
sequence_length=sequence_length,
input_ids_length=1,
)
# compile prefill engine
self.multitoken_engine = DeepSparseDecoderEngine(
onnx_file_path=onnx_file_path,
engine_context=engine_context,
sequence_length=sequence_length,
input_ids_length=self.multitoken_length,
)
assert "input_ids" in self.singletoken_engine.onnx_inputs
assert "attention_mask" in self.singletoken_engine.onnx_inputs
@@ -118,14 +120,12 @@ class DeepSparseDecoderModel:
def engine_inputs_for_prefill(
self,
tokens: List[int],
input_ids: np.ndarray,
):
assert len(tokens) < self.sequence_length
# split batch into N token_batches
num_batches = len(tokens) // self.multitoken_length
num_batches = input_ids.shape[1] // self.multitoken_length
token_batches = [
tokens[i * self.multitoken_length : (i+1) * self.multitoken_length]
input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
for i in range(0, num_batches)
]
@@ -133,9 +133,9 @@ class DeepSparseDecoderModel:
for idx, token_batch in enumerate(token_batches):
num_processed_tokens = self.multitoken_length * idx
engine_inputs = {}
engine_inputs["input_ids"] = np.array([token_batch])
engine_inputs = {}
engine_inputs["input_ids"] = token_batch
# make attention mask from the right
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
@@ -156,30 +156,33 @@ class DeepSparseDecoderModel:
def engine_inputs_for_decode(
self,
tokens: List[int],
input_ids: np.ndarray,
):
assert len(tokens) < self.sequence_length
engine_inputs = {}
engine_inputs["input_ids"] = np.array([[tokens[-1]]])
engine_inputs["input_ids"] = input_ids[:,-1:]
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
engine_inputs["attention_mask"][:, -len(tokens):] = 1
engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
engine_inputs["causal_mask"] = create_causal_mask(
engine_inputs["input_ids"],
engine_inputs["attention_mask"]
)
engine_inputs["positions"] = np.array([[len(tokens) - 1]], dtype=np.int64)
engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
return engine_inputs
def decode(
self,
tokens: List[int],
input_ids: np.ndarray,
past_key_values: DeepSparsePastKeyValues
) -> (np.ndarray, DeepSparsePastKeyValues):
engine_inputs = self.engine_inputs_for_decode(tokens)
# assert input is of shape [1, seq_len] with seq_len < self.sequence_length
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
assert input_ids.shape[1] < self.sequence_length
engine_inputs = self.engine_inputs_for_decode(input_ids)
logits, past_key_values = self.singletoken_engine(
engine_inputs,
past_key_values
@@ -189,24 +192,51 @@ class DeepSparseDecoderModel:
def prefill(
self,
tokens: List[int],
past_key_values: DeepSparsePastKeyValues
input_ids: np.ndarray,
) -> (np.ndarray, DeepSparsePastKeyValues):
# assert input is of shape [1, seq_len] with seq_len < self.sequence_length
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
assert input_ids.shape[1] < self.sequence_length
tokens_processed = 0
# setup empty past key values
past_key_values = DeepSparsePastKeyValues()
for engine_inputs in self.engine_inputs_for_prefill(tokens):
_, past_key_values = self.multitoken_engine(
# loop through chunks, run inference w/ multitoken engine
for engine_inputs in self.engine_inputs_for_prefill(input_ids):
logits, past_key_values = self.multitoken_engine(
engine_inputs,
past_key_values
)
tokens_processed += self.multitoken_length
while tokens_processed < len(tokens):
# if anything left over, run inference w/ singletoken engine
while tokens_processed < input_ids.shape[1]:
logits, past_key_values = self.decode(
tokens= tokens[:tokens_processed + 1],
input_ids=input_ids[:,:tokens_processed+1],
past_key_values=past_key_values
)
tokens_processed += 1
return logits, past_key_values
def forward(
self,
input_ids: np.ndarray,
past_key_values: Optional[DeepSparsePastKeyValues] = None,
):
if past_key_values is None:
return self.prefill(input_ids)
else:
return self.decode(input_ids, past_key_values)
def __call__(
self,
input_ids: np.ndarray,
past_key_values: Optional[DeepSparsePastKeyValues] = None,
):
return self.forward(input_ids, past_key_values)
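
For reference, a minimal sketch of the prefill/decode contract exposed by forward/__call__: the first call (no past_key_values) runs prefill over the whole prompt, and each later call appends one token and runs decode. The module path, deployment paths, and prompt below are placeholders, not part of this commit.

import numpy as np
from transformers import AutoTokenizer
# hypothetical module path and deployment layout -- adjust to the real ones
from text_generation_server.models.deepsparse_model import DeepSparseDecoderModel

tokenizer = AutoTokenizer.from_pretrained("./deployment")
model = DeepSparseDecoderModel(
    onnx_file_path="./deployment/model.onnx",
    sequence_length=128,                 # matches DEEPSPARSE_SEQUENCE_LENGTH above
    multitoken_length=4,                 # matches DEEPSPARSE_MULTITOKEN_LENGTH above
)

input_ids = tokenizer("The chef cooked", return_tensors="np")["input_ids"]  # shape [1, prompt_len]

# first call: past_key_values is None, so forward() dispatches to prefill()
logits, past_key_values = model(input_ids)

# later calls: append the sampled token and let forward() dispatch to decode()
for _ in range(8):
    next_token = int(np.argmax(logits[0, -1, :]))   # greedy, mirroring sample_token above
    input_ids = np.append(input_ids, [[next_token]], axis=1)
    logits, past_key_values = model(input_ids, past_key_values)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))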