made service for deepsparse

This commit is contained in:
rsnm2 2023-08-22 02:59:18 +00:00
parent 7c394a9214
commit 6a739e5142
3 changed files with 34 additions and 482 deletions

View File

@@ -1,240 +0,0 @@
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Type

from text_generation_server.models.deepsparse_model import (
    DeepSparsePastKeyValues,
    DeepSparseDecoderModel
)
from text_generation_server.pb import generate_pb2

DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@dataclass
class DeepSparseCausalLMBatch:
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]
    input_ids_list: List[np.ndarray]
    past_key_values_list: List[Optional[DeepSparsePastKeyValues]]

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "DeepSparseCausalLMBatch":
        # parse batch
        requests_idx_mapping = {}
        input_ids_list = []

        # set up tokenizer for deepsparse left padding
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        padding, truncation = "longest", False

        # loop through items in the batch
        for idx, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = idx

            # set up input_ids, past_key_values
            tokenized_inputs = tokenizer(
                r.inputs,
                return_tensors="np",
                padding=padding,
                truncation=truncation,
                return_token_type_ids=False,
                max_length=DEEPSPARSE_SEQUENCE_LENGTH
            )
            input_ids_list.append(tokenized_inputs["input_ids"])

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
            past_key_values_list=[None] * len(pb.requests),
        )
    # length of the batch
    def __len__(self):
        return len(self.requests)

    # pass a list of request ids, returns a batch with only those request ids
    def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
        assert len(request_ids) > 0

        requests_idx_mapping = {}
        requests = []
        input_ids_list = []
        past_key_values_list = []

        # loop through requests, keep the ones that should remain
        for new_idx, request_id in enumerate(request_ids):
            assert request_id in self.requests_idx_mapping, "all request ids must be in the batch"
            requests_idx_mapping[request_id] = new_idx

            old_idx = self.requests_idx_mapping[request_id]
            requests.append(self.requests[old_idx])
            input_ids_list.append(self.input_ids_list[old_idx])
            past_key_values_list.append(self.past_key_values_list[old_idx])

        # update batch state
        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids_list = input_ids_list
        self.past_key_values_list = past_key_values_list
        return self
    # combine two batches into one
    @classmethod
    def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
        assert len(batches) > 1, "must have more than 1 batch to concatenate"

        requests_idx_mapping = {}
        requests = []
        input_ids_list = []
        past_key_values_list = []

        start_index = 0
        for i, batch in enumerate(batches):
            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"

            # concatenate request, input_ids, and past_key_values lists
            requests.extend(batch.requests)
            input_ids_list.extend(batch.input_ids_list)
            past_key_values_list.extend(batch.past_key_values_list)

            # merge the request_id to index mapping
            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            start_index += len(batch)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
            past_key_values_list=past_key_values_list
        )
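
# The batch lifecycle above (from_pb -> generate -> filter/concatenate) can be
# exercised directly. A minimal sketch, assuming a local tokenizer path; the
# path, prompt, and ids are illustrative only:
#
#   tokenizer = AutoTokenizer.from_pretrained("/path/to/tokenizer")
#   pb_batch = generate_pb2.Batch(
#       id=0,
#       requests=[generate_pb2.Request(id=0, inputs="Hello, my name is")],
#   )
#   batch = DeepSparseCausalLMBatch.from_pb(pb_batch, tokenizer)
#   assert len(batch) == 1
#   # after some requests finish, keep only the live ones by id:
#   batch = batch.filter(request_ids=[0])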
class DeepSparseCausalLM:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
    ):
        # set up tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer.padding_side = "left"
        if not self.tokenizer.pad_token:
            assert self.tokenizer.eos_token
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # set up model
        self.model = DeepSparseDecoderModel(
            onnx_file_path=model_path,
            sequence_length=DEEPSPARSE_SEQUENCE_LENGTH,
            multitoken_length=DEEPSPARSE_MULTITOKEN_LENGTH,
        )
    # TODO (@rsnm2): switch to NextTokenChooser
    def sample_token(
        self,
        logits: np.ndarray
    ):
        assert logits.shape[0] == 1  # assert b=1 for now
        return np.argmax(logits[0, -1, :])  # grab logits for the last item in the sequence

    # TODO (@rsnm2): switch to StoppingCriteria
    def should_stop(
        self,
        num_tokens_processed: int,
        generated_token_id: int
    ):
        if num_tokens_processed >= self.model.sequence_length:
            return True
        if generated_token_id == self.tokenizer.eos_token_id:
            return True
        return False
    def generate_token(
        self,
        batch: DeepSparseCausalLMBatch,
    ) -> Tuple[Dict[int, str], Optional[DeepSparseCausalLMBatch]]:
        generations: Dict[int, str] = {}
        all_stopped = True

        # if we supported continuous batching, we would do batched inference here
        # logits, past_key_values = self.model(batch)

        # for each member of the batch:
        #   a) run inference
        #   b) sample and check stopping criteria
        #   c) create generation + update batch
        for i, (
            request,
            input_ids,
            past_key_values,
        ) in enumerate(zip(
            batch.requests,
            batch.input_ids_list,
            batch.past_key_values_list
        )):
            # run inference
            logits, past_key_values = self.model(input_ids, past_key_values)

            # sample token
            # simple for now --- should use NextTokenChooser
            generated_token_id = self.sample_token(logits)

            # check stopping criteria
            # simple for now --- should use StoppingCriteria
            stop = self.should_stop(
                num_tokens_processed=input_ids.shape[1] + 1,
                generated_token_id=generated_token_id
            )

            # if not stopped, convert token id to text
            generated_text = None
            if not stop:
                all_stopped = False
                generated_text = self.tokenizer.decode(
                    generated_token_id,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )
            generations[request.id] = generated_text

            # update values in the batch
            assert len(batch.input_ids_list[i].shape) == 2
            assert batch.input_ids_list[i].shape[0] == 1

            # np.append does not operate in place, so reassign the updated array
            batch.input_ids_list[i] = np.append(
                batch.input_ids_list[i],
                np.array([[generated_token_id]]),
                axis=1
            )
            batch.past_key_values_list[i] = past_key_values

        # if all elements of the batch are done, return generations + None for the batch
        if all_stopped:
            return generations, None

        # return generations + the updated batch
        return generations, batch
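
# End-to-end usage is roughly: build a batch, then call generate_token until it
# returns None for the batch. A minimal sketch, assuming local model/tokenizer
# paths (paths, prompt, and pb_batch are illustrative, not defined here):
#
#   model = DeepSparseCausalLM(
#       model_path="/path/to/model.onnx",
#       tokenizer_path="/path/to/tokenizer",
#   )
#   batch = DeepSparseCausalLMBatch.from_pb(pb_batch, model.tokenizer)
#   while batch is not None:
#       generations, batch = model.generate_token(batch)
#       for request_id, text in generations.items():
#           if text is not None:
#               print(text, end="", flush=True)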

View File

@@ -1,242 +0,0 @@
import os
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

import numpy as np
from typing import Optional, List, Dict, Tuple

from deepsparse import Context
from deepsparse.engine import LIB
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask

PAST_KEY_VALUES_NAME = "past_key_values"
class DeepSparsePastKeyValues:
    def __init__(self):
        prev_num_tokens = 0
        num_frozen_tokens = 1
        self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
class DeepSparseDecoderEngine:
    def __init__(
        self,
        onnx_file_path: str,
        sequence_length: int = 1024,
        input_ids_length: int = 1,
        engine_context: Optional[Context] = None,
    ):
        # update ONNX graph
        onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
            onnx_file_path=onnx_file_path,
            batch_size=1,
            sequence_length=sequence_length,
            input_ids_length=input_ids_length,
        )

        # compile engine
        self.engine = create_engine(
            onnx_file_path=onnx_file_path,
            engine_type=DEEPSPARSE_ENGINE,
            engine_args={"cached_outputs": cached_outputs},
            context=engine_context,
        )
        print(self.engine)

        # save utilities
        self.past_key_value_dtype = data_type
        self.onnx_inputs = self.engine.input_names
        self.empty_past_key_values = self.make_empty_past_key_values()
    # forward function
    def __call__(
        self,
        engine_inputs: Dict[str, np.ndarray],
        past_key_values: DeepSparsePastKeyValues,
        val_inputs: bool = True
    ):
        # format inputs into a list (we pass empty past key values)
        inputs = [
            self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
            else engine_inputs[name] for name in self.engine.input_names
        ]

        # validate inputs are formatted correctly
        if val_inputs:
            self.engine._validate_inputs(inputs)

        # run inference, updates past_key_values internally
        output = self.engine._eng_net.execute_list_out(
            inputs,
            past_key_values.internal_past_key_values
        )
        logits = output[0]
        return logits, past_key_values
    # empty past kvs (dummy values to be passed around)
    def make_empty_past_key_values(self):
        past_key_values = {}
        for idx, name in enumerate(self.onnx_inputs):
            if name.startswith(PAST_KEY_VALUES_NAME):
                past_key_values[name] = np.zeros(
                    self.engine.input_shapes[idx],
                    dtype=self.past_key_value_dtype
                )
        return past_key_values
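
# The engine call above always feeds zero-filled placeholder arrays for the
# past_key_values inputs; the real cache lives in
# DeepSparsePastKeyValues.internal_past_key_values and is updated in place by
# execute_list_out. A minimal sketch (the ONNX path and engine_inputs dicts
# are illustrative):
#
#   engine = DeepSparseDecoderEngine(
#       onnx_file_path="/path/to/model.onnx",
#       sequence_length=128,
#       input_ids_length=1,
#   )
#   kv = DeepSparsePastKeyValues()
#   logits, kv = engine(engine_inputs, kv)   # engine_inputs: dict of np.ndarrays
#   logits, kv = engine(next_inputs, kv)     # same kv object carries the cache forward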
class DeepSparseDecoderModel:
    def __init__(
        self,
        onnx_file_path: str,
        sequence_length: int = 1024,
        multitoken_length: int = 16,
        engine_context: Optional[Context] = None,
    ):
        self.sequence_length = sequence_length
        self.multitoken_length = multitoken_length

        # compile decode engine
        self.singletoken_engine = DeepSparseDecoderEngine(
            onnx_file_path=onnx_file_path,
            engine_context=engine_context,
            sequence_length=sequence_length,
            input_ids_length=1,
        )

        # compile prefill engine
        self.multitoken_engine = DeepSparseDecoderEngine(
            onnx_file_path=onnx_file_path,
            engine_context=engine_context,
            sequence_length=sequence_length,
            input_ids_length=self.multitoken_length,
        )

        assert "input_ids" in self.singletoken_engine.onnx_inputs
        assert "attention_mask" in self.singletoken_engine.onnx_inputs
        assert "causal_mask" in self.singletoken_engine.onnx_inputs
        assert "positions" in self.singletoken_engine.onnx_inputs
    def engine_inputs_for_prefill(
        self,
        input_ids: np.ndarray,
    ):
        # split the sequence into N token_batches
        num_batches = input_ids.shape[1] // self.multitoken_length
        token_batches = [
            input_ids[:, i * self.multitoken_length:(i + 1) * self.multitoken_length]
            for i in range(num_batches)
        ]

        # format inputs for each of the N token_batches
        for idx, token_batch in enumerate(token_batches):
            num_processed_tokens = self.multitoken_length * idx

            engine_inputs = {}
            engine_inputs["input_ids"] = token_batch

            # build attention mask from the right
            engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
            engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1

            # build positions (from the right)
            # TODO: handle the case when multitoken_length is 1
            assert self.multitoken_length > 1
            engine_inputs["positions"] = np.arange(
                num_processed_tokens, num_processed_tokens + self.multitoken_length
            ).reshape(1, -1).astype(np.int64)

            # build causal mask (from the right)
            engine_inputs["causal_mask"] = create_causal_mask(
                input_ids=engine_inputs["input_ids"],
                attention_mask=engine_inputs["attention_mask"]
            )
            yield engine_inputs
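
    # Worked example of the chunk layout above (illustrative values): with
    # multitoken_length=4 and sequence_length=8, an 8-token prompt yields two
    # chunks. Chunk 0 carries input_ids for tokens 0-3 with
    # attention_mask = [[0,0,0,0,1,1,1,1]] and positions = [[0,1,2,3]];
    # chunk 1 carries tokens 4-7 with attention_mask = [[1,1,1,1,1,1,1,1]] and
    # positions = [[4,5,6,7]]. The masks are right-aligned, matching the
    # left-padded inputs produced by the tokenizer.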
    def engine_inputs_for_decode(
        self,
        input_ids: np.ndarray,
    ):
        engine_inputs = {}
        engine_inputs["input_ids"] = input_ids[:, -1:]
        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
        engine_inputs["causal_mask"] = create_causal_mask(
            engine_inputs["input_ids"],
            engine_inputs["attention_mask"]
        )
        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
        return engine_inputs
    def decode(
        self,
        input_ids: np.ndarray,
        past_key_values: DeepSparsePastKeyValues
    ) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
        # assert input is of shape [1, seq_len] w/ seq_len < self.sequence_length
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == 1
        assert input_ids.shape[1] < self.sequence_length

        engine_inputs = self.engine_inputs_for_decode(input_ids)
        logits, past_key_values = self.singletoken_engine(
            engine_inputs,
            past_key_values
        )
        return logits, past_key_values
    def prefill(
        self,
        input_ids: np.ndarray,
    ) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
        # assert input is of shape [1, seq_len] w/ seq_len < self.sequence_length
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == 1
        assert input_ids.shape[1] < self.sequence_length

        tokens_processed = 0

        # set up empty past key values
        past_key_values = DeepSparsePastKeyValues()

        # loop through chunks, run inference w/ multitoken engine
        for engine_inputs in self.engine_inputs_for_prefill(input_ids):
            logits, past_key_values = self.multitoken_engine(
                engine_inputs,
                past_key_values
            )
            tokens_processed += self.multitoken_length

        # if anything is left over, run inference w/ singletoken engine
        while tokens_processed < input_ids.shape[1]:
            logits, past_key_values = self.decode(
                input_ids=input_ids[:, :tokens_processed + 1],
                past_key_values=past_key_values
            )
            tokens_processed += 1

        return logits, past_key_values
    def forward(
        self,
        input_ids: np.ndarray,
        past_key_values: Optional[DeepSparsePastKeyValues] = None,
    ):
        if past_key_values is None:
            return self.prefill(input_ids)
        else:
            return self.decode(input_ids, past_key_values)

    def __call__(
        self,
        input_ids: np.ndarray,
        past_key_values: Optional[DeepSparsePastKeyValues] = None,
    ):
        return self.forward(input_ids, past_key_values)
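
# Prefill processes the prompt in multitoken_length-sized blocks and finishes
# any remainder token by token (e.g. with multitoken_length=4, a 10-token
# prompt runs as two 4-token chunks through the multitoken engine, then 2
# single-token decode steps). A minimal usage sketch, assuming a local ONNX
# path and an existing input_ids array (both illustrative):
#
#   model = DeepSparseDecoderModel(
#       onnx_file_path="/path/to/model.onnx",
#       sequence_length=128,
#       multitoken_length=4,
#   )
#   logits, kv = model(input_ids)              # prefill (past_key_values=None)
#   next_id = int(np.argmax(logits[0, -1, :]))
#   input_ids = np.append(input_ids, [[next_id]], axis=1)
#   logits, kv = model(input_ids, kv)          # single-token decode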

View File

@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:
    id: int
    prompt: str


@dataclass
class Batch:
    id: int
    requests: List[Request]


@dataclass
class CachedBatch:
    batch_id: int
    request_ids: List[int]


@dataclass
class Generation:
    request_id: int
    generated_text: Optional[str]


@dataclass
class PrefillRequest:
    batch: Batch


@dataclass
class DecodeRequest:
    batches: List[CachedBatch]


@dataclass
class FilterBatchRequest:
    batch_id: int
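
# These dataclasses mirror the request/response shapes the service passes
# around: a Batch of Requests goes in with a PrefillRequest, subsequent steps
# reference CachedBatch by id, and each step yields Generations. A minimal
# sketch of how they might be wired together (all values illustrative):
#
#   batch = Batch(id=0, requests=[Request(id=0, prompt="Hello, my name is")])
#   prefill = PrefillRequest(batch=batch)
#   cached = CachedBatch(batch_id=batch.id, request_ids=[r.id for r in batch.requests])
#   decode = DecodeRequest(batches=[cached])
#   done = FilterBatchRequest(batch_id=cached.batch_id)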