# text-generation-inference/server/bloom_inference/model.py
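"""Inference wrappers for BLOOM text generation.

This module defines the protobuf-backed batching primitives (`Batch`,
`FinishedGeneration`) and two model wrappers: `BLOOM` for single-device
inference and `BLOOMSharded` for tensor-parallel inference across several
processes.
"""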
import torch
import torch.distributed

from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights

from bloom_inference.cache import CacheEntry
from bloom_inference.pb import generate_pb2
from bloom_inference.shard_model import shard_model, match_suffix
from bloom_inference.utils import (
    StoppingCriteria,
    NextTokenChooser,
    initialize_torch_distributed,
    set_default_dtype,
)

torch.manual_seed(0)

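# A `Batch` groups several generation requests so they can run through the model in
# a single forward pass. `input_ids` holds the tensors fed to the model (input ids,
# attention mask and, after the first step, past key/values), while `all_input_ids`
# keeps the full token history of every request for decoding.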
@dataclass
class Batch:
    batch_id: int
    request_ids: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
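
    # Build a batch from a protobuf `Batch` message: gather each request's sampling
    # parameters and stopping criteria, then tokenize all inputs together with left
    # padding so the sequences are aligned on their last token.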
    @classmethod
    def from_batch_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Batch":
        request_ids = []
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        # Parse batch
        for r in pb.requests:
            request_ids.append(r.id)
            inputs.append(r.inputs)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            pb.id,
            request_ids,
            input_ids,
            all_input_ids,
            next_token_choosers,
            stopping_criterias,
        )
    @classmethod
    def from_cache_entry(cls, cache_entry: CacheEntry) -> "Batch":
        return cls(
            cache_entry.batch_id,
            cache_entry.request_ids,
            cache_entry.input_ids,
            cache_entry.all_input_ids,
            cache_entry.next_token_choosers,
            cache_entry.stopping_criterias,
        )
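
    # Merge several cached batches into a single batch. Tensors from each cached
    # batch are copied into freshly allocated tensors sized for the combined batch
    # (`total_batch_size`) and the longest sequence (`max_sequence_length`), padded
    # on the left so ongoing requests can keep generating together.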
    @classmethod
    def from_batch_cached_pb(cls, pb: generate_pb2.BatchCached, cache) -> "Batch":
        if len(pb.batch_cached_ids) == 1:
            cache_entry = cache.pop(pb.batch_cached_ids[0])
            if cache_entry is None:
                raise ValueError(f"Batch ID {pb.batch_id} not found in cache")
            return cls.from_cache_entry(cache_entry)

        total_batch_size = pb.total_batch_size
        max_sequence_length = pb.max_sequence_length

        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        request_ids = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        start_index = 0
        for i, batch_id in enumerate(pb.batch_cached_ids):
            cache_entry = cache.pop(batch_id)
            if cache_entry is None:
                raise ValueError(f"Batch ID {batch_id} not found in cache")

            request_ids.extend(cache_entry.request_ids)
            all_input_ids.extend(cache_entry.all_input_ids)
            next_token_choosers.extend(cache_entry.next_token_choosers)
            stopping_criterias.extend(cache_entry.stopping_criterias)

            batch_size = len(cache_entry.request_ids)
            end_index = start_index + batch_size
            sequence_length = max(len(entry) for entry in cache_entry.all_input_ids)

            if input_ids["input_ids"] is None:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=cache_entry.input_ids["input_ids"].dtype,
                    device=cache_entry.input_ids["input_ids"].device,
                )

            input_ids["input_ids"][start_index:end_index] = cache_entry.input_ids[
                "input_ids"
            ]

            if input_ids["attention_mask"] is None:
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=cache_entry.input_ids["attention_mask"].dtype,
                    device=cache_entry.input_ids["attention_mask"].device,
                )

            input_ids["attention_mask"][
                start_index:end_index, -sequence_length:
            ] = cache_entry.input_ids["attention_mask"][:, -sequence_length:]
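
            # BLOOM stores past keys as [batch * num_heads, head_dim, seq_length] and
            # past values as [batch * num_heads, seq_length, head_dim]. Expose the batch
            # dimension with a view so each cached batch can be copied into its slice of
            # the padded tensors, then flatten back once the last batch has been copied.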
            for j, past in enumerate(cache_entry.input_ids["past_key_values"]):
                # TODO: this could be done without the views by using indices
                past_keys = past[0]
                past_values = past[1]

                _, head_dim, padded_sequence_length = past_keys.shape

                past_keys = past_keys.view(
                    batch_size, -1, head_dim, padded_sequence_length
                )
                past_values = past_values.view(
                    batch_size, -1, padded_sequence_length, head_dim
                )

                num_heads = past_keys.shape[1]

                if j == len(input_ids["past_key_values"]):
                    padded_past_keys = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        ),
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        ),
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    input_ids["past_key_values"].append(
                        [padded_past_keys, padded_past_values]
                    )

                input_ids["past_key_values"][j][0][
                    start_index:end_index, :, :, -(sequence_length - 1):
                ] = past_keys[:, :, :, -(sequence_length - 1):]
                input_ids["past_key_values"][j][1][
                    start_index:end_index, :, -(sequence_length - 1):, :
                ] = past_values[:, :, -(sequence_length - 1):, :]

                if (i + 1) == len(pb.batch_cached_ids):
                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
                        j
                    ][0].view(total_batch_size * num_heads, head_dim, -1)
                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
                        j
                    ][1].view(total_batch_size * num_heads, -1, head_dim)

            start_index += batch_size

        assert pb.request_ids == request_ids

        return cls(
            pb.id,
            request_ids,
            input_ids,
            all_input_ids,
            next_token_choosers,
            stopping_criterias,
        )

@dataclass
class FinishedGeneration:
    request_id: str
    output: str

    def to_pb(self) -> generate_pb2.FinishedGeneration:
        return generate_pb2.FinishedGeneration(id=self.request_id, output=self.output)

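# Single-device model: loads the full checkpoint on one GPU (or the CPU) and exposes
# a `generate_token` step that advances every request of a `Batch` by one token.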
class BLOOM:
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
        )
        self.num_heads = self.model.base_model.num_heads

    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
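
    # One decoding step: run a single forward pass over the whole batch, sample one
    # new token per request, return the requests that met their stopping criteria
    # as `FinishedGeneration`s, and pack everything still generating (inputs,
    # attention mask, past key/values) into a `CacheEntry` for the next step.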
    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[FinishedGeneration], Optional[CacheEntry]]:
        with torch.no_grad():
            outputs = self.forward(**batch.input_ids)

        # List of indices to cache
        cache_indices = []
        cache_past_indices = []

        # New input_ids for next forward; keep in cache
        cache_next_input_ids = []
        cache_all_input_ids = []

        # Finished requests
        finished_generations: List[FinishedGeneration] = []

        # Zipped iterator
        iterator = zip(
            batch.request_ids,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request_id,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request id
                finished_generations.append(FinishedGeneration(request_id, output))
            # must be added to the cache
            else:
                cache_indices.append(i)
                cache_past_indices.extend(
                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
                )
                cache_next_input_ids.append(next_token)
                cache_all_input_ids.append(all_tokens)

        # No cache is needed, we finished all generations in the batch
        if not cache_indices:
            return finished_generations, None

        # If we finished at least one generation
        cache_input_ids = {"input_ids": torch.cat(cache_next_input_ids, dim=0)}
        if finished_generations:
            # Apply indices to attention mask, past key values and other items that need to be cached
            cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                cache_indices
            ]
            cache_input_ids["past_key_values"] = [
                (keys[cache_past_indices], values[cache_past_indices])
                for keys, values in outputs["past_key_values"]
            ]
            cache_request_ids = [batch.request_ids[i] for i in cache_indices]
            cache_next_token_choosers = [
                batch.next_token_choosers[i] for i in cache_indices
            ]
            cache_stopping_criterias = [
                batch.stopping_criterias[i] for i in cache_indices
            ]
        else:
            cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            cache_input_ids["past_key_values"] = outputs["past_key_values"]
            cache_request_ids = batch.request_ids
            cache_next_token_choosers = batch.next_token_choosers
            cache_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        cache_input_ids["attention_mask"] = torch.cat(
            [
                cache_input_ids["attention_mask"],
                torch.ones((cache_input_ids["attention_mask"].shape[0], 1)).to(
                    cache_input_ids["attention_mask"].device
                ),
            ],
            dim=1,
        )

        cache_entry = CacheEntry(
            batch.batch_id,
            cache_request_ids,
            cache_input_ids,
            cache_all_input_ids,
            cache_next_token_choosers,
            cache_stopping_criterias,
        )
        return finished_generations, cache_entry

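# Tensor-parallel model: every process loads its own shard of the weights, runs the
# same forward pass, and the vocabulary-sharded logits are all-gathered so each rank
# can sample from the full distribution.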
class BLOOMSharded(BLOOM):
    def __init__(self, model_name: str, shard_directory: Path):
        super(BLOOM, self).__init__()
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

        # shard state_dict
        if self.master:
            # TODO @thomasw21 do some caching
            shard_state_dict_paths = shard_model(
                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
            )
            shard_state_dict_paths = [
                str(path.absolute()) for path in shard_state_dict_paths
            ]
        else:
            shard_state_dict_paths = [None] * self.world_size

        torch.distributed.broadcast_object_list(
            shard_state_dict_paths, src=0, group=self.process_group
        )
        shard_state_dict_path = shard_state_dict_paths[self.rank]

        config = AutoConfig.from_pretrained(
            model_name, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3

        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
        # in PyTorch 1.12 and later.
        torch.backends.cuda.matmul.allow_tf32 = True

        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

        with set_default_dtype(dtype):
            with no_init_weights():
                # we can probably set the device to `meta` here?
                model = AutoModelForCausalLM.from_config(config).to(dtype)

        torch.distributed.barrier(group=self.process_group)
        # print_rank_0(f"Initialized model")
        state_dict = torch.load(shard_state_dict_path)
        # TODO @thomasw21: HACK in order to transpose all weights prior to loading
        for key in state_dict.keys():
            do_transpose = False
            if not match_suffix(key, "weight"):
                continue

            for potential_suffix in [
                "self_attention.query_key_value.weight",
                "self_attention.dense.weight",
                "dense_h_to_4h.weight",
                "dense_4h_to_h.weight",
            ]:
                if match_suffix(key, potential_suffix):
                    do_transpose = True

            if do_transpose:
                state_dict[key] = state_dict[key].transpose(1, 0).contiguous()

        model.load_state_dict(state_dict)
        self.model = model.to(self.device).eval()
        self.num_heads = config.n_head // self.process_group.size()
        torch.distributed.barrier(group=self.process_group)
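
    # Forward pass on this rank's shard. The language-model head is split along the
    # vocabulary dimension, so the logits of the last position are all-gathered
    # across ranks and concatenated back into the full vocabulary before returning.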
    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        logits_shard = outputs.logits[:, -1, :].contiguous()

        batch_size, vocab_shard_size = logits_shard.shape
        vocab_size = self.world_size * vocab_shard_size
        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)

        outputs.logits = logits
        return outputs
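
# Illustrative usage sketch (not part of the original module): a minimal decode loop
# driving `generate_token` until every request has finished. It assumes `batch` was
# built with `Batch.from_batch_pb(...)` from an incoming protobuf request; the model
# name below is only an example.
#
#     model = BLOOM("bigscience/bloom-560m")
#     finished = []
#     while batch is not None:
#         done, cache_entry = model.generate_token(batch)
#         finished.extend(done)
#         batch = Batch.from_cache_entry(cache_entry) if cache_entry else None
#     for generation in finished:
#         print(generation.request_id, generation.output)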