feat(server): support RefinedWeb models

OlivierDehaene 2023-05-29 11:56:19 +02:00
parent 951930fbff
commit 63a18c1414
7 changed files with 891 additions and 24 deletions

server/text_generation_server/models/__init__.py

@@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.rw import RW
 from text_generation_server.models.opt import OPT, OPTSharded
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
@@ -30,6 +31,7 @@ try:
     )
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+    from text_generation_server.models.flash_rw import FlashRW
     from text_generation_server.models.flash_llama import (
         FlashLlama,
         FlashLlamaSharded,
@@ -68,6 +70,7 @@ __all__ = [
 if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
+    __all__.append(FlashRW)
     __all__.append(FlashSantacoder)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
@@ -194,6 +197,34 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
+    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+        if sharded:
+            if FLASH_ATTENTION:
+                if config.alibi:
+                    raise NotImplementedError("sharded is not supported for this model")
+                # return FlashRWSharded(
+                #     model_id,
+                #     revision,
+                #     quantize=quantize,
+                #     trust_remote_code=trust_remote_code,
+                # )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb"))
+        else:
+            if FLASH_ATTENTION and not config.alibi:
+                return FlashRW(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
+            else:
+                return RW(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
     if model_type == "llama":
         if sharded:
             if FLASH_ATTENTION:

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        qkv_rot = self.rotary_emb(qkv, cos, sin)
+        # Inplace rotary
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv_rot[:, 1:]
+            layer_past[...] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(qkv_rot[:, 0])
+            attn_output = torch.empty_like(qkv[:, 0])
 
             # flash attention
             flash_attn_cuda.fwd(
-                qkv_rot[:, 0],
-                qkv_rot[:, 1],
-                qkv_rot[:, 2],
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv_rot[:, 0]
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
+            layer_past[layer_past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
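
Why dropping the old qkv_rot return value is safe: qkv[:, 0] and qkv[:, 1] are views into the fused projection, so an in-place rotary kernel updates the same buffer that is later handed to flash attention, and no extra copy is materialized. A minimal sketch with toy shapes (a zeroing op stands in for the rotary kernel):

import torch

# Toy shapes, assumed for illustration only.
total_tokens, num_heads, head_size = 4, 2, 8
qkv = torch.randn(total_tokens, 3, num_heads, head_size)

q, k = qkv[:, 0], qkv[:, 1]   # views into the fused buffer, not copies
q.mul_(0.0)                   # stand-in for the in-place rotary kernel

# Because q is a view, the fused tensor sees the update: no qkv_rot copy is needed.
assert torch.equal(qkv[:, 0], torch.zeros(total_tokens, num_heads, head_size))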

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        qkv_rot = self.rotary_emb(qkv, cos, sin)
+        # Inplace rotary
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv_rot[:, 1:]
+            layer_past[...] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(qkv_rot[:, 0])
+            attn_output = torch.empty_like(qkv[:, 0])
 
             # flash attention
             flash_attn_cuda.fwd(
-                qkv_rot[:, 0],
-                qkv_rot[:, 1],
-                qkv_rot[:, 2],
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv_rot[:, 0]
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
+            layer_past[layer_past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py (new file)

@@ -0,0 +1,507 @@
import torch
import torch.distributed
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
FastLayerNorm,
PositionRotaryEmbedding,
)
class RWConfig(PretrainedConfig):
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
}
def __init__(
self,
model_type="RefinedWeb",
vocab_size=250880,
hidden_size=64,
n_layer=2,
n_head=8,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
hidden_dropout=0.0,
attention_dropout=0.0,
n_head_kv=None,
multi_query=False,
alibi=False,
bias=False,
parallel_attn=False,
**kwargs,
):
if alibi:
raise NotImplementedError("alibi is not supported by this version of the model")
self.model_type = model_type
self.alibi = False
self.rotary = True
self.vocab_size = vocab_size
# Backward compatibility with n_embed kwarg
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bias = bias
self.parallel_attn = parallel_attn
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
if n_head_kv is not None:
self.n_head_kv = n_head_kv
else:
self.n_head_kv = 1 if multi_query else n_head
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
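
The only non-obvious piece of RWConfig is how n_head_kv is resolved: an explicit value from the checkpoint wins, otherwise multi_query=True collapses the key/value heads to a single shared head. A small standalone restatement of that rule (hypothetical helper, not part of the commit):

def resolve_n_head_kv(n_head: int, n_head_kv=None, multi_query: bool = False) -> int:
    # An explicit value from the checkpoint config takes precedence.
    if n_head_kv is not None:
        return n_head_kv
    # Multi-query attention: a single K/V head shared by all query heads.
    return 1 if multi_query else n_head

assert resolve_n_head_kv(71, multi_query=True) == 1   # multi-query checkpoint
assert resolve_n_head_kv(64, n_head_kv=8) == 8        # grouped K/V heads
assert resolve_n_head_kv(32) == 32                    # classic multi-head attention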
class FlashRWAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
):
super().__init__()
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias)
self.dense = TensorParallelRowLinear(
hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce
)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
# Split query from key_value
query, kv = qkv.split(
[self.head_size * self.num_heads, 2 * self.head_size], dim=1
)
# Prepare query and key_value for indexing
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, 1, self.head_size)
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
kv = kv.expand(-1, 2, query.shape[1], self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
attn_output,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, query.shape[1], self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
1,
max_s,
0.0,
self.softmax_scale,
False,
False,
False,
0,
None,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
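
FlashRWAttention relies on the multi-query layout: the fused projection packs num_heads query heads plus one shared key and value (the 2 * self.head_size split assumes num_heads_kv == 1), and the shared KV pair is broadcast across the query heads with expand(), which allocates no extra memory before the flash attention call. A toy-shape sketch of that layout:

import torch

# Toy multi-query shapes, assumed for illustration.
tokens, num_heads, num_heads_kv, head_size = 3, 4, 1, 8

# Fused projection output: queries for every head, one shared K and one shared V.
qkv = torch.randn(tokens, head_size * (num_heads + 2 * num_heads_kv))
query, kv = qkv.split([head_size * num_heads, 2 * head_size], dim=1)

query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_heads_kv, head_size)   # (tokens, k/v, 1, head_size)

# expand() broadcasts the single KV head across the query heads without copying,
# which is what lets flash attention consume a per-head K/V layout cheaply.
kv_b = kv.expand(-1, 2, num_heads, head_size)
assert kv_b.shape == (tokens, 2, num_heads, head_size)
assert kv_b.stride()[2] == 0   # broadcast dimension: no extra memory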
class FlashMLP(nn.Module):
def __init__(
self, hidden_size, bias, process_group=None, reduce=True
):
super().__init__()
self.act = torch.nn.functional.gelu
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size, bias=bias,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
4 * hidden_size,
hidden_size, bias=bias,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class FlashRWLayer(nn.Module):
    def __init__(
        self,
        num_heads,
        num_heads_kv,
        hidden_size,
        bias,
        layer_norm_eps,
        parallel_attn,
        process_group=None,
    ):
        super().__init__()
        self.parallel_attn = parallel_attn
        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
        self.self_attention = FlashRWAttention(
            num_heads,
            num_heads_kv,
            hidden_size,
            bias,
            process_group=process_group,
            reduce=False,
        )
        self.post_attention_layernorm = (
            FastLayerNorm(hidden_size, eps=layer_norm_eps)
            if not parallel_attn
            else None
        )
        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
        self.process_group = process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
attn_output = self.self_attention(
ln_hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
mlp_output = self.mlp(ln_hidden_states)
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attention(
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
mlp_output = self.mlp(hidden_states)
return mlp_output, residual
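
In the parallel_attn path (the RefinedWeb default), a single input_layernorm feeds both the attention and the MLP, their outputs are summed onto the residual, and the all-reduce happens once per layer rather than once per sub-module. A self-contained sketch of that block structure, with plain torch.nn modules standing in for the fused/flash layers:

import torch
from torch import nn

class ParallelBlockSketch(nn.Module):
    def __init__(self, hidden: int = 32, heads: int = 4):
        super().__init__()
        self.ln = nn.LayerNorm(hidden)
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden)
        )

    def forward(self, x):
        h = self.ln(x)                    # single shared layernorm
        attn_out, _ = self.attn(h, h, h)  # attention and MLP both read h ...
        return x + attn_out + self.mlp(h) # ... and their outputs are summed onto the residual

x = torch.randn(1, 5, 32)
assert ParallelBlockSketch()(x).shape == x.shape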
class FlashRWPreTrainedModel(PreTrainedModel):
config_class = RWConfig
supports_gradient_checkpointing = False
_no_split_modules = None
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.h = nn.ModuleList(
            [
                FlashRWLayer(
                    config.n_head,
                    config.n_head_kv,
                    config.hidden_size,
                    config.bias,
                    config.layer_norm_epsilon,
                    config.parallel_attn,
                    process_group,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
self.ln_f = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_epsilon
)
self.gradient_checkpointing = False
self.head_size = self.h[0].self_attention.head_size
self.num_heads_kv = self.h[0].self_attention.num_heads_kv
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads_kv,
self.head_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
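
The prefill/decode split above hinges on layer_past_present_indices = cu_seqlens[1:] - 1: during decode each request contributes exactly one new token, and its KV pair is written into the last slot of that request's region of the flattened cache. A small sketch of the index computation with assumed sequence lengths:

import torch

# Assumed example: a batch of 3 sequences with lengths 2, 4 and 3.
lengths = torch.tensor([2, 4, 3])
cu_seqlens = torch.nn.functional.pad(lengths.cumsum(0), (1, 0))   # [0, 2, 6, 9]

# The newest token of each sequence sits just before each cumulative boundary:
# these are the per-layer write indices into the flattened KV cache.
layer_past_present_indices = cu_seqlens[1:] - 1
assert layer_past_present_indices.tolist() == [1, 5, 8]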
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
super().__init__(config)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values,
pre_allocate_past_size,
)
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present
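
When the embeddings are tensor-parallel, each rank's lm_head produces only vocab_size // world_size logits, so the forward gathers the shards and concatenates them along the vocabulary dimension. The gather can be emulated on a single device (toy sizes, no process group) to see why dim=1 is the right axis:

import torch

# Assumed toy sizes: vocabulary of 12 sharded across a world of 3 ranks.
vocab, world_size, tokens, hidden = 12, 3, 5, 8
full_head = torch.randn(vocab, hidden)
hidden_states = torch.randn(tokens, hidden)

# Each rank holds a vocab // world_size slice of the lm_head weight ...
shards = full_head.chunk(world_size, dim=0)
local_logits = [hidden_states @ w.t() for w in shards]   # what each rank computes

# ... and all_gather + cat(dim=1) reassembles the full vocabulary dimension.
gathered = torch.cat(local_logits, dim=1)
assert torch.allclose(gathered, hidden_states @ full_head.t(), atol=1e-5)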

server/text_generation_server/models/flash_rw.py (new file)

@@ -0,0 +1,246 @@
import torch
import torch.distributed
from pathlib import Path
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
)
tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
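
load_weights bypasses load_state_dict because the module layout was rewritten: it walks the checkpoint keys, resolves each submodule with get_submodule, and swaps the tensor in directly, falling back to buffers for non-parameter keys. A condensed, self-contained sketch of that pattern:

import torch
from torch import nn

# Minimal sketch of the direct-assignment pattern used by load_weights above.
model = nn.Sequential(nn.Linear(4, 4, bias=False))
state_dict = {"0.weight": torch.ones(4, 4, dtype=torch.float16)}

for key, value in state_dict.items():
    module_name, param_name = key.rsplit(".", 1)
    module = model.get_submodule(module_name)
    current = module._parameters.get(param_name)
    if current is not None:
        assert current.shape == value.shape, f"shape mismatch for {key}"
        module._parameters[param_name] = value   # raw tensor is enough for inference
    else:
        module._buffers[param_name] = value

assert model[0].weight.dtype == torch.float16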
#
# class FlashNeoXSharded(FlashNeoX):
# def __init__(
# self,
# model_id: str,
# revision: Optional[str] = None,
# quantize: Optional[str] = None,
# trust_remote_code: bool = False,
# ):
# self.process_group, rank, world_size = initialize_torch_distributed()
# if torch.cuda.is_available():
# device = torch.device(f"cuda:{rank}")
# dtype = torch.float16
# else:
# raise NotImplementedError("FlashNeoX is only available on GPU")
#
# tokenizer = AutoTokenizer.from_pretrained(
# model_id,
# revision=revision,
# padding_side="left",
# truncation_side="left",
# trust_remote_code=trust_remote_code,
# )
#
# config = AutoConfig.from_pretrained(
# model_id, revision=revision, trust_remote_code=trust_remote_code
# )
#
# torch.distributed.barrier(group=self.process_group)
# filenames = weight_files(model_id, revision=revision, extension=".safetensors")
#
# with init_empty_weights():
# model = FlashGPTNeoXForCausalLM(config, self.process_group)
#
# torch.distributed.barrier(group=self.process_group)
# self.load_weights(
# model,
# filenames,
# quantize=quantize,
# device=device,
# dtype=dtype,
# rank=rank,
# world_size=world_size,
# )
# torch.distributed.barrier(group=self.process_group)
# super(FlashCausalLM, self).__init__(
# model=model.to(device),
# tokenizer=tokenizer,
# requires_padding=False,
# dtype=dtype,
# device=device,
# rank=rank,
# world_size=world_size,
# )
#
# @staticmethod
# def load_weights(
# model,
# filenames: List[str],
# quantize: Optional[str],
# device: torch.device,
# dtype: torch.dtype,
# rank: int,
# world_size: int,
# ):
# parameters = dict(model.named_parameters())
# for file in filenames:
# with safe_open(
# file, framework="pt", device=str(device) if quantize is None else "cpu"
# ) as f:
# for name in f.keys():
# module_name, param_name = name.rsplit(".", 1)
# module = model.get_submodule(module_name)
#
# current_parameter_tensor = parameters.get(name, None)
#
# slice_ = f.get_slice(name)
#
# if isinstance(module, TensorParallelColumnLinear):
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# elif isinstance(module, TensorParallelRowLinear):
# if param_name == "weight":
# size = slice_.get_shape()[1]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[:, start:stop]
# else:
# tensor = slice_[:]
# # XXX: Hack for Rowlinear to add the bias only once.
# if rank != 0:
# tensor = torch.zeros_like(tensor)
# elif isinstance(module, TensorParallelEmbedding):
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
# size = slice_.get_shape()[0]
# block_size = size // world_size
# start = rank * block_size
# stop = (rank + 1) * block_size
# tensor = slice_[start:stop]
# else:
# try:
# tensor = slice_[:]
# except:
# tensor = f.get_tensor(name)
#
# if (
# current_parameter_tensor is not None
# and current_parameter_tensor.shape != tensor.shape
# ):
# raise ValueError(
# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
# )
#
# tensor = tensor.contiguous().to(dtype)
#
# if current_parameter_tensor is not None:
# module._parameters[param_name] = tensor
# else:
# module._buffers[param_name] = tensor
#
# model.post_load_weights(quantize)

server/text_generation_server/models/rw.py (new file)

@@ -0,0 +1,80 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
class RW(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
if past_key_values is not None:
reshaped_past_key_values = []
for layer in past_key_values:
past_keys, past_values = layer
reshaped_past_key_values.append(
(past_keys.view(-1, *past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:]))
)
past_key_values = reshaped_past_key_values
outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask,
past_key_values=past_key_values)
return outputs.logits, outputs.past_key_values
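
The only model-specific logic in RW is the cache reshape in forward: the generic CausalLM batching code keeps an explicit leading batch dimension in the past key/value tensors, while the remote-code RefinedWeb modeling expects those leading dimensions flattened into one. A toy-shape illustration of the view being applied (the shapes here are assumptions, not taken from the commit):

import torch

# Assumed toy cache shape: (batch, num_heads_kv, seq_len, head_size).
batch, kv_heads, seq_len, head_size = 2, 1, 5, 8
past_keys = torch.randn(batch, kv_heads, seq_len, head_size)

# forward() fuses every leading dimension, keeping only the last two.
fused = past_keys.view(-1, *past_keys.shape[-2:])
assert fused.shape == (batch * kv_heads, seq_len, head_size)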

server/text_generation_server/utils/layers.py

@@ -262,16 +262,13 @@ try:
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
-        def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_dim = cos.shape[-1]
-            q1 = qkv[:, 0, :, :rotary_dim]
-            q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
-            k1 = qkv[:, 1, :, :rotary_dim]
-            k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+            x1 = x[..., :rotary_dim]
+            x2 = x[..., rotary_dim : 2 * rotary_dim]
 
-            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-            return qkv
+            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
+            return x
 
 except ImportError:
     pass
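
This refactor makes the rotary layer operate on a single tensor (a query or a key) instead of the fused qkv, so the same module also serves the RefinedWeb multi-query path where K has a different head count than Q. For reference, an out-of-place PyTorch equivalent of the in-place kernel call, using toy shapes:

import torch

def apply_rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Plain-PyTorch reference for what the fused kernel does in place:
    rotate the first 2 * rotary_dim features of x (a query OR a key) pairwise."""
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim : 2 * rotary_dim]
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, x[..., 2 * rotary_dim :]), dim=-1)

# Toy shapes (assumed): 3 tokens, 2 heads, head_size 8, rotary over the full head.
x = torch.randn(3, 2, 8)
cos, sin = torch.randn(3, 1, 4), torch.randn(3, 1, 4)
assert apply_rotary_reference(x, cos, sin).shape == x.shape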