mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
baseline implementation complete, working on sample server for example
This commit is contained in:
parent
f8565cd915
commit
fec0b1dce5
@ -1,60 +1,36 @@
|
||||
import numpy, torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
PrefillTokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Type
|
||||
|
||||
from text_generation_server.models.deepsparse_model import (
|
||||
DeepSparsePastKeyValues,
|
||||
DeepSparseDecoderModel
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
||||
DEEPSPARSE_SEQUENCE_LENGTH = 128
|
||||
DEEPSPARSE_MULTITOKEN_LENGTH = 4
|
||||
|
||||
@dataclass
|
||||
class DeepSparseCausalLMBatch(Batch):
|
||||
class DeepSparseCausalLMBatch:
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
requests_idx_mapping: Dict[int,int]
|
||||
|
||||
# TODO: update to handle calculating max_tokens --- needed for CachedBatch
|
||||
|
||||
# Decoder values
|
||||
input_ids_list: List[numpy.ndarray]
|
||||
past_key_values_list: Optional[List[DeepSparsePastKeyValues]]
|
||||
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
|
||||
def to_pb(self) -> generate_pb2.CachedBatch:
|
||||
return generate_pb2.CachedBatch(
|
||||
id=self.batch_id,
|
||||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
input_ids_list: List[np.ndarray]
|
||||
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "DeepSparseCausalLMBatch":
|
||||
|
||||
# parse batch
|
||||
input_ids_list = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
requests_idx_mapping = {}
|
||||
input_ids_list = []
|
||||
|
||||
# setup tokenizer for deepsparse left padding
|
||||
tokenizer.padding_side = "left"
|
||||
@ -63,11 +39,10 @@ class DeepSparseCausalLMBatch(Batch):
|
||||
padding, truncation = "longest", False
|
||||
|
||||
# loop through items in the batch
|
||||
for i, r in enumerate(pb.requests):
|
||||
# get mapping
|
||||
requests_idx_mapping[r.id] = i
|
||||
for idx, r in enumerate(pb.requests):
|
||||
requests_idx_mapping[r.id] = idx
|
||||
|
||||
# setup inputs
|
||||
# setup inputs_ids, past_key_values
|
||||
tokenized_inputs = tokenizer(
|
||||
r.inputs,
|
||||
return_tensors="np",
|
||||
@ -78,50 +53,188 @@ class DeepSparseCausalLMBatch(Batch):
|
||||
)
|
||||
input_ids_list.append(tokenized_inputs["input_ids"])
|
||||
|
||||
# setup sequence generation helpers, capping at DEEPSPARSE_SEQUENCE_LENGTH
|
||||
# cap stopping parameters to DeepSparse sequence length
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
assert DEEPSPARSE_SEQUENCE_LENGTH - input_len > 0
|
||||
r.stopping_parameters.max_new_tokens = min(
|
||||
r.stopping_parameters.max_new_tokens,
|
||||
DEEPSPARSE_SEQUENCE_LENGTH - input_len
|
||||
)
|
||||
stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
input_ids_list=input_ids_list,
|
||||
past_key_values_list=None
|
||||
past_key_values_list=[None] * len(pb.requests),
|
||||
)
|
||||
|
||||
# length of the batch
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
# pass list of request ids, returns batch with only those request ids
|
||||
def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
|
||||
pass
|
||||
assert(len(request_ids) > 0)
|
||||
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_ids_list = []
|
||||
past_key_values_list = []
|
||||
|
||||
# loop through requests, keep ones that should remain
|
||||
for new_idx, request_id in enumerate(request_ids):
|
||||
assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch"
|
||||
|
||||
requests_idx_mapping[request_id] = new_idx
|
||||
|
||||
old_idx = self.requests_idx_mapping[request_id]
|
||||
requests.append(self.requests[old_idx])
|
||||
input_ids_list.append(self.input_ids_list[old_idx])
|
||||
past_key_values_list.append(self.past_key_values[old_idx])
|
||||
|
||||
# update batch state
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids_list = input_ids_list
|
||||
self.past_key_values_list = past_key_values_list
|
||||
|
||||
return self
|
||||
|
||||
# combine two batches into one
|
||||
@classmethod
|
||||
def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
|
||||
pass
|
||||
assert len(batches) > 1, "must have more than 1 batch to concatenate"
|
||||
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_ids_list = []
|
||||
past_key_values_list = []
|
||||
|
||||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
assert batch.past_key_values_list is None, "only concatenate prefilled batches"
|
||||
|
||||
# concatenate request, input_ids, and past_key_values lists
|
||||
requests.extend(batch.requests)
|
||||
input_ids_list.extend(batch.input_ids_list)
|
||||
past_key_values_list.extend(batch.past_key_values_list)
|
||||
|
||||
# merge the request_id to index mapping
|
||||
if i == 0:
|
||||
requests_idx_mapping = batch.requests_idx_mapping
|
||||
else:
|
||||
for k, v in batch.requests_idx_mapping.items():
|
||||
requests_idx_mapping[k] = v + start_index
|
||||
|
||||
start_index += len(batch)
|
||||
|
||||
return cls(
|
||||
batch_id= batches[0].id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids_list=input_ids_list,
|
||||
past_key_values_list=past_key_values_list
|
||||
)
|
||||
|
||||
class DeepSparseCausalLM:
|
||||
def __init__(
|
||||
self,
|
||||
deployment_path: str
|
||||
model_path: str,
|
||||
tokenizer_path: str,
|
||||
):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(deployment_path)
|
||||
# setup tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer.padding_side = "left"
|
||||
if not self.tokenizer.pad_token:
|
||||
assert self.tokenizer.eos_token
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
# setup model
|
||||
self.model = DeepSparseDecoderModel(
|
||||
onnx_file_path = model_path,
|
||||
sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
|
||||
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
|
||||
)
|
||||
|
||||
# TODO (@rsnm2): switch to NextTokenChooser
|
||||
def sample_token(
|
||||
self,
|
||||
logits: np.ndarray
|
||||
):
|
||||
assert(logits.shape[0] == 1) # assert b=1 for now
|
||||
return np.argmax(logits[0,-1,:]) # grab logits for the last item in the sequence
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[DeepSparseCausalLMBatch]:
|
||||
return DeepSparseCausalLMBatch
|
||||
# TODO (@rsnm2): switch to StoppingCriteria
|
||||
def should_stop(
|
||||
self,
|
||||
num_tokens_processed: int,
|
||||
generated_token_id: int
|
||||
):
|
||||
if num_tokens_processed >= self.model.sequence_length:
|
||||
return True
|
||||
if generated_token_id == self.tokenizer.eos_token_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def generate_token(
|
||||
self,
|
||||
batch: DeepSparseCausalLMBatch,
|
||||
) -> (Dict[int,str], Optional[DeepSparseCausalLMBatch]):
|
||||
|
||||
generations: Dict[int, str] = {}
|
||||
all_stopped = True
|
||||
|
||||
# if we supported continuous batching, we would do batched inference here
|
||||
# logits, past_key_values = self.model(batch)
|
||||
|
||||
# for each member of the batch:
|
||||
# a) run inference
|
||||
# b) sample and check stopping criteria
|
||||
# c) create generation + update batch
|
||||
for i, (
|
||||
request,
|
||||
input_ids,
|
||||
past_key_values,
|
||||
) in enumerate(zip(
|
||||
batch.requests,
|
||||
batch.input_ids_list,
|
||||
batch.past_key_values_list
|
||||
)):
|
||||
|
||||
# run inference
|
||||
logits, past_key_values = self.model(input_ids, past_key_values)
|
||||
|
||||
# sample token
|
||||
# simple for now --- should use NextTokenChooser
|
||||
generated_token_id = self.sample_token(logits)
|
||||
|
||||
# check stopping criteria
|
||||
# simple for now --- should use StoppingCriteria
|
||||
stop = self.should_stop(
|
||||
num_tokens_processed=len(input_ids) + 1,
|
||||
generated_token_id = generated_token_id
|
||||
)
|
||||
|
||||
# if not stopped, convert token id to text
|
||||
generated_text = None
|
||||
if not stop:
|
||||
all_stopped = False
|
||||
generated_text = self.tokenizer.decode(
|
||||
generated_token_id,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False
|
||||
)
|
||||
generations[request.id] = generated_text
|
||||
|
||||
# update values in the batch
|
||||
assert len(batch.input_ids_list[i].shape) == 2
|
||||
assert batch.input_ids_list[i].shape[0] == 1
|
||||
|
||||
# bad --- this does not occur in place
|
||||
# print(batch.input_ids_list[i])
|
||||
batch.input_ids_list[i] = np.append(
|
||||
batch.input_ids_list[i],
|
||||
np.array([[generated_token_id]]),
|
||||
axis=1
|
||||
)
|
||||
batch.past_key_values_list[i] = past_key_values
|
||||
|
||||
# if all elements of the batch are done, return generation + null for batch
|
||||
if all_stopped:
|
||||
return generations, None
|
||||
|
||||
# return generation + updated batch
|
||||
return generations, batch
|
@ -1,3 +1,7 @@
|
||||
import os
|
||||
|
||||
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
@ -53,8 +57,10 @@ class DeepSparseDecoderEngine:
|
||||
val_inputs: bool = True
|
||||
):
|
||||
# format input into lists (we pass empty past key values)
|
||||
inputs = [self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
|
||||
else engine_inputs[name] for name in self.engine.input_names]
|
||||
inputs = [
|
||||
self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
|
||||
else engine_inputs[name] for name in self.engine.input_names
|
||||
]
|
||||
|
||||
# validate inputs formatted correctly
|
||||
if val_inputs:
|
||||
@ -87,16 +93,11 @@ class DeepSparseDecoderModel:
|
||||
sequence_length: int = 1024,
|
||||
multitoken_length: int = 16,
|
||||
engine_context: Optional[Context] = None,
|
||||
singletoken_engine = None,
|
||||
multitoken_engine = None,
|
||||
):
|
||||
self.sequence_length = sequence_length
|
||||
self.multitoken_length = multitoken_length
|
||||
|
||||
if singletoken_engine is not None and multitoken_engine is not None:
|
||||
self.singletoken_engine = singletoken_engine
|
||||
self.multitoken_engine = multitoken_engine
|
||||
else:
|
||||
# compile decode engine
|
||||
self.singletoken_engine = DeepSparseDecoderEngine(
|
||||
onnx_file_path=onnx_file_path,
|
||||
engine_context=engine_context,
|
||||
@ -104,6 +105,7 @@ class DeepSparseDecoderModel:
|
||||
input_ids_length=1,
|
||||
)
|
||||
|
||||
# compile prefill engine
|
||||
self.multitoken_engine = DeepSparseDecoderEngine(
|
||||
onnx_file_path=onnx_file_path,
|
||||
engine_context=engine_context,
|
||||
@ -118,14 +120,12 @@ class DeepSparseDecoderModel:
|
||||
|
||||
def engine_inputs_for_prefill(
|
||||
self,
|
||||
tokens: List[int],
|
||||
input_ids: np.ndarray,
|
||||
):
|
||||
assert len(tokens) < self.sequence_length
|
||||
|
||||
# split batch into N token_batches
|
||||
num_batches = len(tokens) // self.multitoken_length
|
||||
num_batches = input_ids.shape[1] // self.multitoken_length
|
||||
token_batches = [
|
||||
tokens[i * self.multitoken_length : (i+1) * self.multitoken_length]
|
||||
input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
|
||||
for i in range(0, num_batches)
|
||||
]
|
||||
|
||||
@ -134,7 +134,7 @@ class DeepSparseDecoderModel:
|
||||
num_processed_tokens = self.multitoken_length * idx
|
||||
|
||||
engine_inputs = {}
|
||||
engine_inputs["input_ids"] = np.array([token_batch])
|
||||
engine_inputs["input_ids"] = token_batch
|
||||
|
||||
# make attention mask from the right
|
||||
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
|
||||
@ -156,30 +156,33 @@ class DeepSparseDecoderModel:
|
||||
|
||||
def engine_inputs_for_decode(
|
||||
self,
|
||||
tokens: List[int],
|
||||
input_ids: np.ndarray,
|
||||
):
|
||||
assert len(tokens) < self.sequence_length
|
||||
|
||||
engine_inputs = {}
|
||||
engine_inputs["input_ids"] = np.array([[tokens[-1]]])
|
||||
engine_inputs["input_ids"] = input_ids[:,-1:]
|
||||
engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
|
||||
engine_inputs["attention_mask"][:, -len(tokens):] = 1
|
||||
engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
|
||||
|
||||
engine_inputs["causal_mask"] = create_causal_mask(
|
||||
engine_inputs["input_ids"],
|
||||
engine_inputs["attention_mask"]
|
||||
)
|
||||
engine_inputs["positions"] = np.array([[len(tokens) - 1]], dtype=np.int64)
|
||||
engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
|
||||
|
||||
return engine_inputs
|
||||
|
||||
def decode(
|
||||
self,
|
||||
tokens: List[int],
|
||||
input_ids: np.ndarray,
|
||||
past_key_values: DeepSparsePastKeyValues
|
||||
) -> (np.ndarray, DeepSparsePastKeyValues):
|
||||
|
||||
engine_inputs = self.engine_inputs_for_decode(tokens)
|
||||
# assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
|
||||
assert len(input_ids.shape) == 2
|
||||
assert input_ids.shape[0] == 1
|
||||
assert input_ids.shape[1] < self.sequence_length
|
||||
|
||||
engine_inputs = self.engine_inputs_for_decode(input_ids)
|
||||
logits, past_key_values = self.singletoken_engine(
|
||||
engine_inputs,
|
||||
past_key_values
|
||||
@ -189,24 +192,51 @@ class DeepSparseDecoderModel:
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
tokens: List[int],
|
||||
past_key_values: DeepSparsePastKeyValues
|
||||
input_ids: np.ndarray,
|
||||
) -> (np.ndarray, DeepSparsePastKeyValues):
|
||||
|
||||
# assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
|
||||
assert len(input_ids.shape) == 2
|
||||
assert input_ids.shape[0] == 1
|
||||
assert input_ids.shape[1] < self.sequence_length
|
||||
|
||||
tokens_processed = 0
|
||||
|
||||
for engine_inputs in self.engine_inputs_for_prefill(tokens):
|
||||
_, past_key_values = self.multitoken_engine(
|
||||
# setup empty past key values
|
||||
past_key_values = DeepSparsePastKeyValues()
|
||||
|
||||
# loop through chunks, run inference w/ multitoken engine
|
||||
for engine_inputs in self.engine_inputs_for_prefill(input_ids):
|
||||
logits, past_key_values = self.multitoken_engine(
|
||||
engine_inputs,
|
||||
past_key_values
|
||||
)
|
||||
tokens_processed += self.multitoken_length
|
||||
|
||||
while tokens_processed < len(tokens):
|
||||
# if anything left over, run inference w/ singletoken engine
|
||||
while tokens_processed < input_ids.shape[1]:
|
||||
logits, past_key_values = self.decode(
|
||||
tokens= tokens[:tokens_processed + 1],
|
||||
input_ids=input_ids[:,:tokens_processed+1],
|
||||
past_key_values=past_key_values
|
||||
)
|
||||
tokens_processed += 1
|
||||
# print(logits[:,-1:,:])
|
||||
|
||||
return logits, past_key_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: np.ndarray,
|
||||
past_key_values: Optional[DeepSparsePastKeyValues] = None,
|
||||
):
|
||||
if past_key_values is None:
|
||||
return self.prefill(input_ids)
|
||||
else:
|
||||
return self.decode(input_ids, past_key_values)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: np.ndarray,
|
||||
past_key_values: Optional[DeepSparsePastKeyValues] = None,
|
||||
):
|
||||
return self.forward(input_ids, past_key_values)
|
Loading…
Reference in New Issue
Block a user