baseline implementation complete, working on sample server for example

rsnm2 2023-08-21 19:27:07 +00:00
parent f8565cd915
commit fec0b1dce5
2 changed files with 255 additions and 112 deletions

View File

@@ -1,61 +1,37 @@
import numpy as np

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from dataclasses import dataclass
from typing import List, Dict, Optional, Type

from text_generation_server.models.deepsparse_model import (
    DeepSparsePastKeyValues,
    DeepSparseDecoderModel
)
from text_generation_server.pb import generate_pb2

DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4

@dataclass
class DeepSparseCausalLMBatch:
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]
    input_ids_list: List[np.ndarray]
    past_key_values_list: List[Optional[DeepSparsePastKeyValues]]

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "DeepSparseCausalLMBatch":
        # parse batch
        requests_idx_mapping = {}
        input_ids_list = []

        # setup tokenizer for deepsparse left padding
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
@@ -63,11 +39,10 @@ class DeepSparseCausalLMBatch(Batch):
            padding, truncation = "longest", False

        # loop through items in the batch
        for idx, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = idx

            # setup input_ids, past_key_values
            tokenized_inputs = tokenizer(
                r.inputs,
                return_tensors="np",
@@ -77,51 +52,189 @@ class DeepSparseCausalLMBatch(Batch):
                max_length=DEEPSPARSE_SEQUENCE_LENGTH
            )
            input_ids_list.append(tokenized_inputs["input_ids"])

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
            past_key_values_list=[None] * len(pb.requests),
        )
    # length of the batch
    def __len__(self):
        return len(self.requests)

    # pass list of request ids, returns batch with only those request ids
    def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
        assert len(request_ids) > 0

        requests_idx_mapping = {}
        requests = []
        input_ids_list = []
        past_key_values_list = []

        # loop through requests, keep ones that should remain
        for new_idx, request_id in enumerate(request_ids):
            assert request_id in self.requests_idx_mapping, "all request ids must be in the batch"
            requests_idx_mapping[request_id] = new_idx

            old_idx = self.requests_idx_mapping[request_id]
            requests.append(self.requests[old_idx])
            input_ids_list.append(self.input_ids_list[old_idx])
            past_key_values_list.append(self.past_key_values_list[old_idx])

        # update batch state
        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids_list = input_ids_list
        self.past_key_values_list = past_key_values_list

        return self

    # combine two batches into one
    @classmethod
    def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
        assert len(batches) > 1, "must have more than 1 batch to concatenate"

        requests_idx_mapping = {}
        requests = []
        input_ids_list = []
        past_key_values_list = []

        start_index = 0
        for i, batch in enumerate(batches):
            assert all(pkv is not None for pkv in batch.past_key_values_list), "only concatenate prefilled batches"

            # concatenate request, input_ids, and past_key_values lists
            requests.extend(batch.requests)
            input_ids_list.extend(batch.input_ids_list)
            past_key_values_list.extend(batch.past_key_values_list)

            # merge the request_id to index mapping
            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index
            start_index += len(batch)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
            past_key_values_list=past_key_values_list
        )
class DeepSparseCausalLM:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
    ):
        # setup tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer.padding_side = "left"
        if not self.tokenizer.pad_token:
            assert self.tokenizer.eos_token
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # setup model
        self.model = DeepSparseDecoderModel(
            onnx_file_path=model_path,
            sequence_length=DEEPSPARSE_SEQUENCE_LENGTH,
            multitoken_length=DEEPSPARSE_MULTITOKEN_LENGTH,
        )

    # TODO (@rsnm2): switch to NextTokenChooser
    def sample_token(
        self,
        logits: np.ndarray
    ):
        assert logits.shape[0] == 1             # assume batch size 1 for now
        return np.argmax(logits[0, -1, :])      # greedy: argmax over logits for the last position in the sequence

    # TODO (@rsnm2): switch to StoppingCriteria
    def should_stop(
        self,
        num_tokens_processed: int,
        generated_token_id: int
    ):
        if num_tokens_processed >= self.model.sequence_length:
            return True
        if generated_token_id == self.tokenizer.eos_token_id:
            return True
        return False

    def generate_token(
        self,
        batch: DeepSparseCausalLMBatch,
    ) -> (Dict[int, str], Optional[DeepSparseCausalLMBatch]):
        generations: Dict[int, str] = {}
        all_stopped = True

        # if we supported continuous batching, we would do batched inference here:
        #   logits, past_key_values = self.model(batch)

        # for each member of the batch:
        #   a) run inference
        #   b) sample and check stopping criteria
        #   c) create generation + update batch
        for i, (
            request,
            input_ids,
            past_key_values,
        ) in enumerate(zip(
            batch.requests,
            batch.input_ids_list,
            batch.past_key_values_list
        )):
            # run inference
            logits, past_key_values = self.model(input_ids, past_key_values)

            # sample token
            #   simple for now --- should use NextTokenChooser
            generated_token_id = self.sample_token(logits)

            # check stopping criteria
            #   simple for now --- should use StoppingCriteria
            stop = self.should_stop(
                num_tokens_processed=input_ids.shape[1] + 1,
                generated_token_id=generated_token_id
            )

            # if not stopped, convert token id to text
            generated_text = None
            if not stop:
                all_stopped = False
                generated_text = self.tokenizer.decode(
                    generated_token_id,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )
            generations[request.id] = generated_text

            # update values in the batch
            #   note: np.append does not operate in place, so assign the result back
            assert len(batch.input_ids_list[i].shape) == 2
            assert batch.input_ids_list[i].shape[0] == 1
            batch.input_ids_list[i] = np.append(
                batch.input_ids_list[i],
                np.array([[generated_token_id]]),
                axis=1
            )
            batch.past_key_values_list[i] = past_key_values

        # if all elements of the batch are done, return generations + null for the batch
        if all_stopped:
            return generations, None

        # return generations + updated batch
        return generations, batch
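
The commit message notes that the sample server is still in progress. As a rough sketch of how a caller could drive DeepSparseCausalLM end to end under the API above (the module path, file paths, and prompt below are illustrative assumptions, not part of this commit):

# illustrative only: module path, file paths, and prompt are assumptions
from text_generation_server.pb import generate_pb2
from text_generation_server.models.deepsparse_causal_lm import (  # assumed module path
    DeepSparseCausalLM,
    DeepSparseCausalLMBatch,
)

model = DeepSparseCausalLM(
    model_path="/path/to/model.onnx",        # placeholder
    tokenizer_path="/path/to/tokenizer",     # placeholder
)

# build a single-request batch using the same protobuf types from_pb expects
request = generate_pb2.Request(id=0, inputs="The capital of France is")
pb_batch = generate_pb2.Batch(id=0, requests=[request], size=1)
batch = DeepSparseCausalLMBatch.from_pb(pb_batch, model.tokenizer)

# call generate_token until every request in the batch has stopped
output = ""
while batch is not None:
    generations, batch = model.generate_token(batch)
    if generations.get(request.id):
        output += generations[request.id]

print(output)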

View File

@@ -1,3 +1,7 @@
import os
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

import numpy as np
from typing import Optional, List, Dict
@@ -53,8 +57,10 @@ class DeepSparseDecoderEngine:
        val_inputs: bool = True
    ):
        # format input into lists (we pass empty past key values)
        inputs = [
            self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
            else engine_inputs[name] for name in self.engine.input_names
        ]

        # validate inputs formatted correctly
        if val_inputs:
@@ -87,29 +93,25 @@ class DeepSparseDecoderModel:
        sequence_length: int = 1024,
        multitoken_length: int = 16,
        engine_context: Optional[Context] = None,
    ):
        self.sequence_length = sequence_length
        self.multitoken_length = multitoken_length

        # compile decode engine
        self.singletoken_engine = DeepSparseDecoderEngine(
            onnx_file_path=onnx_file_path,
            engine_context=engine_context,
            sequence_length=sequence_length,
            input_ids_length=1,
        )

        # compile prefill engine
        self.multitoken_engine = DeepSparseDecoderEngine(
            onnx_file_path=onnx_file_path,
            engine_context=engine_context,
            sequence_length=sequence_length,
            input_ids_length=self.multitoken_length,
        )

        assert "input_ids" in self.singletoken_engine.onnx_inputs
        assert "attention_mask" in self.singletoken_engine.onnx_inputs
@@ -118,14 +120,12 @@ class DeepSparseDecoderModel:
    def engine_inputs_for_prefill(
        self,
        input_ids: np.ndarray,
    ):
        # split batch into N token_batches
        num_batches = input_ids.shape[1] // self.multitoken_length
        token_batches = [
            input_ids[:, i * self.multitoken_length : (i + 1) * self.multitoken_length]
            for i in range(0, num_batches)
        ]
@@ -133,9 +133,9 @@ class DeepSparseDecoderModel:
        for idx, token_batch in enumerate(token_batches):
            num_processed_tokens = self.multitoken_length * idx

            engine_inputs = {}
            engine_inputs["input_ids"] = token_batch

            # make attention mask from the right
            engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
            engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
@@ -156,30 +156,33 @@ class DeepSparseDecoderModel:
    def engine_inputs_for_decode(
        self,
        input_ids: np.ndarray,
    ):
        engine_inputs = {}
        engine_inputs["input_ids"] = input_ids[:, -1:]
        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1

        engine_inputs["causal_mask"] = create_causal_mask(
            engine_inputs["input_ids"],
            engine_inputs["attention_mask"]
        )
        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)

        return engine_inputs

    def decode(
        self,
        input_ids: np.ndarray,
        past_key_values: DeepSparsePastKeyValues
    ) -> (np.ndarray, DeepSparsePastKeyValues):
        # assert input is of shape [1, seq_len] with seq_len < self.sequence_length
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == 1
        assert input_ids.shape[1] < self.sequence_length

        engine_inputs = self.engine_inputs_for_decode(input_ids)
        logits, past_key_values = self.singletoken_engine(
            engine_inputs,
            past_key_values
        )
@@ -189,24 +192,51 @@ class DeepSparseDecoderModel:
    def prefill(
        self,
        input_ids: np.ndarray,
    ) -> (np.ndarray, DeepSparsePastKeyValues):
        # assert input is of shape [1, seq_len] with seq_len < self.sequence_length
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == 1
        assert input_ids.shape[1] < self.sequence_length

        tokens_processed = 0

        # setup empty past key values
        past_key_values = DeepSparsePastKeyValues()

        # loop through chunks, run inference w/ multitoken engine
        for engine_inputs in self.engine_inputs_for_prefill(input_ids):
            logits, past_key_values = self.multitoken_engine(
                engine_inputs,
                past_key_values
            )
            tokens_processed += self.multitoken_length

        # if anything is left over, run inference w/ singletoken engine
        while tokens_processed < input_ids.shape[1]:
            logits, past_key_values = self.decode(
                input_ids=input_ids[:, :tokens_processed + 1],
                past_key_values=past_key_values
            )
            tokens_processed += 1

        return logits, past_key_values

    def forward(
        self,
        input_ids: np.ndarray,
        past_key_values: Optional[DeepSparsePastKeyValues] = None,
    ):
        if past_key_values is None:
            return self.prefill(input_ids)
        else:
            return self.decode(input_ids, past_key_values)

    def __call__(
        self,
        input_ids: np.ndarray,
        past_key_values: Optional[DeepSparsePastKeyValues] = None,
    ):
        return self.forward(input_ids, past_key_values)
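
For reference, a minimal sketch of exercising DeepSparseDecoderModel directly, assuming a tokenizer and ONNX model on disk (the paths and the greedy loop below are illustrative, not part of this commit): the first call with no past_key_values routes to prefill, and each later call routes to single-token decode.

# illustrative only: paths and the greedy loop are assumptions
import numpy as np
from transformers import AutoTokenizer
from text_generation_server.models.deepsparse_model import DeepSparseDecoderModel  # assumed module path

tokenizer = AutoTokenizer.from_pretrained("/path/to/tokenizer")   # placeholder
model = DeepSparseDecoderModel(
    onnx_file_path="/path/to/model.onnx",                         # placeholder
    sequence_length=128,
    multitoken_length=4,
)

# shape [1, seq_len], as prefill/decode expect
input_ids = tokenizer("The capital of France is", return_tensors="np")["input_ids"]

# first call: no past_key_values -> prefill (multitoken engine, leftovers via decode)
logits, past_key_values = model(input_ids)

# later calls: past_key_values present -> single-token decode
for _ in range(8):
    next_id = int(np.argmax(logits[0, -1, :]))
    input_ids = np.append(input_ids, [[next_id]], axis=1)
    logits, past_key_values = model(input_ids, past_key_values)

print(tokenizer.decode(input_ids[0]))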