OlivierDehaene 2023-05-30 10:41:10 +02:00
parent 12ab24ae64
commit cbffddcc06
7 changed files with 509 additions and 255 deletions

View File

@@ -286,7 +286,9 @@ def test_batch_concatenate(
        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
    )

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

    for _ in range(
        default_bloom_batch.stopping_criterias[0].max_new_tokens

View File

@@ -285,7 +285,9 @@ def test_batch_concatenate(
        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
    )

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

    for _ in range(
        default_causal_lm_batch.stopping_criterias[0].max_new_tokens

View File

@@ -323,7 +323,9 @@ def test_batch_concatenate(
    )
    assert generations[2].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

View File

@@ -31,7 +31,7 @@ try:
    )
    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_rw import FlashRW
+    from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
        FlashLlamaSharded,

@@ -71,6 +71,7 @@ if FLASH_ATTENTION:
    __all__.append(FlashNeoX)
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRW)
+    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoder)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)

@@ -202,13 +203,15 @@ def get_model(
        if FLASH_ATTENTION:
            if config.alibi:
                raise NotImplementedError("sharded is not supported for this model")
-            # return FlashRWSharded(
-            #     model_id,
-            #     revision,
-            #     quantize=quantize,
-            #     trust_remote_code=trust_remote_code,
-            # )
-        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb"))
+            return FlashRWSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+        raise NotImplementedError(
+            FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
+        )
    else:
        if FLASH_ATTENTION and not config.alibi:
            return FlashRW(

View File

@@ -27,28 +27,30 @@ class RWConfig(PretrainedConfig):
    }

    def __init__(
        self,
        model_type="RefinedWeb",
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        n_head_kv=None,
        multi_query=False,
        alibi=False,
        bias=False,
        parallel_attn=False,
        **kwargs,
    ):
        if alibi:
-            raise NotImplementedError("alibi is not supported by this version of the model")
+            raise NotImplementedError(
+                "alibi is not supported by this version of the model"
+            )

        self.model_type = model_type
        self.alibi = False
@@ -81,13 +83,13 @@ class RWConfig(PretrainedConfig):
class FlashRWAttention(torch.nn.Module):
    def __init__(
        self,
        num_heads,
        num_heads_kv,
        hidden_size,
        bias,
        process_group=None,
        reduce=True,
    ):
        super().__init__()
        self.num_heads = num_heads
@@ -99,33 +101,44 @@ class FlashRWAttention(torch.nn.Module):
        self.softmax_scale = self.head_size ** (-0.5)

        if process_group is None:
-            self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                                              bias=bias)
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
        else:
            self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                                              bias=bias)
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
            self.dense = TensorParallelRowLinear(
-                hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce
+                hidden_size,
+                hidden_size,
+                bias=bias,
+                process_group=process_group,
+                reduce=reduce,
            )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        qkv = self.query_key_value(hidden_states)

        # Split query from key_value
        query, kv = qkv.split(
-            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
+            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
+            dim=1,
        )

        # Prepare query and key_value for indexing
@@ -194,11 +207,149 @@ class FlashRWAttention(torch.nn.Module):
        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


-class FlashMLP(nn.Module):
-    def __init__(
-        self, hidden_size, bias, process_group=None, reduce=True
-    ):
-        super().__init__()
+class FlashRWLargeAttention(torch.nn.Module):
+    def __init__(
+        self,
+        num_heads,
+        num_heads_kv,
+        hidden_size,
+        bias,
+        process_group=None,
+        reduce=True,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.head_size = hidden_size // num_heads
+
+        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.softmax_scale = self.head_size ** (-0.5)
+
+        self.num_groups = num_heads // (num_heads_kv * 2)
+        self.num_heads = num_heads // self.num_groups
+        self.num_heads_kv = num_heads_kv // self.num_groups
+
+        if process_group is None:
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.num_groups
+                * self.head_size
+                * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
+            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
+        else:
+            self.query_key_value = TensorParallelColumnLinear(
+                hidden_size,
+                self.num_groups
+                * self.head_size
+                * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+                process_group=process_group,
+            )
+            self.dense = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                bias=bias,
+                process_group=process_group,
+                reduce=reduce,
+            )
+            self.num_groups = self.num_groups // process_group.size()
+
+    def forward(
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
+    ):
+        qkv = self.query_key_value(hidden_states)
+        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+        # Split query from key_value
+        query, kv = qkv.split(
+            [self.num_heads, 2],
+            dim=2,
+        )
+
+        # Prepare query and key_value for indexing
+        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
+        kv = kv.transpose(1, 2)
+
+        # Inplace rotary
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(kv[:, 0], cos, sin)
+
+        # Prefill
+        if layer_past_present_indices is None:
+            # Copy to layer past
+            layer_past[...] = kv
+            k, v = kv.split(1, dim=1)
+            # Expand to query shape
+            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            # output
+            attn_output = torch.empty_like(query)
+            # flash attention
+            flash_attn_cuda.fwd(
+                query,
+                k,
+                v,
+                attn_output,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0.0,
+                self.softmax_scale,
+                False,
+                True,
+                False,
+                0,
+                None,
+            )
+        # Decode
+        else:
+            # Add present to the layer_past tensor at the correct indices
+            layer_past[layer_past_present_indices] = kv
+            k, v = layer_past.split(1, dim=1)
+            # Expand to query shape
+            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            # output
+            attn_output = torch.empty_like(query)
+            # flash attention
+            flash_attn_cuda.fwd(
+                query,
+                k,
+                v,
+                attn_output,
+                cu_seqlens_q,
+                cu_seqlens,
+                1,
+                max_s,
+                0.0,
+                self.softmax_scale,
+                False,
+                False,
+                False,
+                0,
+                None,
+            )
+
+        return self.dense(attn_output.view(-1, self.num_heads * self.num_groups * self.head_size))
+
+
+class FlashMLP(nn.Module):
+    def __init__(self, hidden_size, bias, process_group=None, reduce=True):
+        super().__init__()

        self.act = torch.nn.functional.gelu

        if process_group is None:
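
As an aside (not part of the commit): a standalone sketch of the head-grouping arithmetic that the new FlashRWLargeAttention relies on. The sizes below are illustrative, Falcon-40B-style assumptions, not values taken from this diff.

# Illustrative sketch of the grouped multi-query layout used by FlashRWLargeAttention.
# Sizes are assumptions (Falcon-40B-like), not taken from the diff.
import torch

n_head, n_head_kv, head_size, tokens = 128, 8, 64, 4

num_groups = n_head // (n_head_kv * 2)   # 8 groups
num_heads = n_head // num_groups         # 16 query heads per group
num_heads_kv = n_head_kv // num_groups   # 1 key/value head per group

# Fused qkv projection output: per group, num_heads query heads plus one K and one V head
qkv = torch.randn(tokens, num_groups * head_size * (num_heads + 2 * num_heads_kv))
qkv = qkv.view(-1, num_groups, num_heads + 2, head_size)          # [4, 8, 18, 64]

query, kv = qkv.split([num_heads, 2], dim=2)                      # [4, 8, 16, 64] / [4, 8, 2, 64]
query = query.reshape(-1, num_groups * num_heads, head_size)      # [4, 128, 64]
kv = kv.transpose(1, 2)                                           # [4, 2, 8, 64]

# Each group's single K (and V) head is broadcast across its 16 query heads
k, v = kv.split(1, dim=1)                                         # [4, 1, 8, 64] each
k = k.expand(-1, num_heads, num_groups, head_size).reshape(
    -1, num_groups * num_heads, head_size
)                                                                 # [4, 128, 64]
print(query.shape, k.shape)
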
@@ -207,12 +358,14 @@ class FlashMLP(nn.Module):
        else:
            self.dense_h_to_4h = TensorParallelColumnLinear(
                hidden_size,
-                4 * hidden_size, bias=bias,
+                4 * hidden_size,
+                bias=bias,
                process_group=process_group,
            )
            self.dense_4h_to_h = TensorParallelRowLinear(
                4 * hidden_size,
-                hidden_size, bias=bias,
+                hidden_size,
+                bias=bias,
                process_group=process_group,
                reduce=reduce,
            )
@@ -227,37 +380,44 @@ class FlashMLP(nn.Module):
class FlashRWLayer(nn.Module):
    def __init__(
        self,
        num_heads,
        num_heads_kv,
        hidden_size,
+        bias,
        layer_norm_eps,
        parallel_attn,
        process_group=None,
    ):
        super().__init__()

        self.parallel_attn = parallel_attn

        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
-        self.self_attention = FlashRWAttention(num_heads, num_heads_kv, hidden_size, process_group, reduce=False)
-        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) if not parallel_attn else None
+        self.self_attention = FlashRWAttention(
+            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+        )
+        self.post_attention_layernorm = (
+            FastLayerNorm(hidden_size, eps=layer_norm_eps)
+            if not parallel_attn
+            else None
+        )

-        self.mlp = FlashMLP(hidden_size, process_group, reduce=False)
+        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)

        self.process_group = process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        if self.parallel_attn:
            ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -303,6 +463,68 @@ class FlashRWLayer(nn.Module):
            return mlp_output, residual


+class FlashRWLargeLayer(nn.Module):
+    def __init__(
+        self,
+        num_heads,
+        num_heads_kv,
+        hidden_size,
+        bias,
+        layer_norm_eps,
+        process_group=None,
+    ):
+        super().__init__()
+        self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+
+        self.self_attention = FlashRWLargeAttention(
+            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+        )
+
+        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+
+        self.process_group = process_group
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
+    ):
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(hidden_states, residual)
+
+        # Self attention.
+        attn_output = self.self_attention(
+            ln_attn,
+            cos,
+            sin,
+            cu_seqlens,
+            max_s,
+            layer_past,
+            layer_past_present_indices,
+            cu_seqlens_q,
+        )
+
+        # MLP.
+        mlp_output = self.mlp(ln_mlp)
+
+        intermediate = attn_output + mlp_output
+
+        # Only reduce once and after the addition instead of once per layer
+        if self.process_group is not None:
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+        return intermediate, residual
+
+
class FlashRWPreTrainedModel(PreTrainedModel):
    config_class = RWConfig
    supports_gradient_checkpointing = False
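
A quick aside on the "only reduce once" comment in FlashRWLargeLayer: because the row-parallel attention and MLP are both built with reduce=False, their outputs are per-rank partial sums, and summing the two partials first then all-reducing once gives the same tensor as reducing each separately. A minimal single-process sketch of that equivalence (no torch.distributed involved; the rank-partial tensors are simulated):

# Sketch only: summing partial results over ranks is linear, so
# sum_r(attn_r) + sum_r(mlp_r) == sum_r(attn_r + mlp_r).
import torch

world_size = 4
attn_partials = [torch.randn(2, 8) for _ in range(world_size)]  # per-rank row-parallel attention outputs
mlp_partials = [torch.randn(2, 8) for _ in range(world_size)]   # per-rank row-parallel MLP outputs

two_reduces = sum(attn_partials) + sum(mlp_partials)            # reduce each, then add
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))  # add, then reduce once

assert torch.allclose(two_reduces, one_reduce)
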
@@ -328,27 +550,47 @@ class FlashRWModel(FlashRWPreTrainedModel):
        else:
            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

-        self.h = nn.ModuleList(
-            [
-                FlashRWLayer(
-                    config.n_head,
-                    config.n_head_kv,
-                    config.hidden_size,
-                    config.layer_norm_epsilon,
-                    config.parallel_attn,
-                    process_group,
-                )
-                for _ in range(config.num_hidden_layers)
-            ]
-        )
-        self.ln_f = FastLayerNorm(
-            config.hidden_size, eps=config.layer_norm_epsilon
-        )
+        if config.model_type == "RefinedWebModel":
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(
+                        config.n_head,
+                        config.n_head_kv,
+                        config.hidden_size,
+                        config.bias,
+                        config.layer_norm_epsilon,
+                        config.parallel_attn,
+                        process_group,
+                    )
+                    for _ in range(config.num_hidden_layers)
+                ]
+            )
+            self.kv_size = self.h[0].self_attention.num_heads_kv
+        elif config.model_type == "RefinedWeb":
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLargeLayer(
+                        config.n_head,
+                        config.n_head_kv,
+                        config.hidden_size,
+                        config.bias,
+                        config.layer_norm_epsilon,
+                        process_group,
+                    )
+                    for _ in range(config.num_hidden_layers)
+                ]
+            )
+            self.kv_size = self.h[0].self_attention.num_groups
+        else:
+            raise NotImplementedError(
+                f"model_type {config.model_type} is not supported."
+            )
+
+        self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        self.head_size = self.h[0].self_attention.head_size
-        self.num_heads_kv = self.h[0].self_attention.num_heads_kv

    def post_load_weights(self, quantize: Optional[str] = None):
        if isinstance(self.word_embeddings, TensorParallelEmbedding):
@@ -373,14 +615,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
        return model

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        cu_seqlens_q,
        max_s,
        past_key_values=None,
        pre_allocate_past_size: Optional[int] = None,
    ):
        hidden_states = self.word_embeddings(input_ids)
@@ -394,7 +636,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                    if pre_allocate_past_size is None
                    else pre_allocate_past_size,
                    2,
-                    self.num_heads_kv,
+                    self.kv_size,
                    self.head_size,
                )
            )
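
For context on the num_heads_kv to kv_size switch above: the pre-allocated cache stores one key and one value slot of kv_size heads per token, and kv_size is whatever count the chosen layer type exposes (n_head_kv for "RefinedWebModel" layers, num_groups for "RefinedWeb" layers). A rough sketch of the resulting tail shape, with illustrative sizes that are assumptions rather than values from this commit:

# Illustrative only: tail shape of the pre-allocated KV cache per cached token is
# [2, kv_size, head_size]; the concrete numbers below are assumed, not from the diff.
import torch

pre_allocate_past_size, head_size = 16, 64

for model_type, kv_size in [("RefinedWebModel", 1), ("RefinedWeb", 8)]:
    cache = torch.empty((pre_allocate_past_size, 2, kv_size, head_size))
    print(model_type, tuple(cache.shape), cache.numel())
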
@@ -457,9 +699,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
                bias=False,
            )
        else:
-            self.lm_head = FastLinear(
-                config.hidden_size, config.vocab_size, bias=False
-            )
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

    def post_load_weights(self, quantize: Optional[str] = None):
        self.transformer.post_load_weights(quantize)
@@ -477,14 +717,14 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
        return model

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        cu_seqlens_q,
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ):
        hidden_states, present = self.transformer(
            input_ids,

View File

@@ -53,8 +53,6 @@ class FlashRW(FlashCausalLM):
            model_id,
            revision=revision,
        )
-        from loguru import logger
-        logger.error(config.model_type)

        # We do not use from_pretrained as we modified the model internal module layout
        try:
@@ -114,133 +112,134 @@ class FlashRW(FlashCausalLM):
        torch.cuda.empty_cache()
        model.post_load_weights(quantize)

-#
-# class FlashNeoXSharded(FlashNeoX):
-#     def __init__(
-#         self,
-#         model_id: str,
-#         revision: Optional[str] = None,
-#         quantize: Optional[str] = None,
-#         trust_remote_code: bool = False,
-#     ):
-#         self.process_group, rank, world_size = initialize_torch_distributed()
-#         if torch.cuda.is_available():
-#             device = torch.device(f"cuda:{rank}")
-#             dtype = torch.float16
-#         else:
-#             raise NotImplementedError("FlashNeoX is only available on GPU")
-#
-#         tokenizer = AutoTokenizer.from_pretrained(
-#             model_id,
-#             revision=revision,
-#             padding_side="left",
-#             truncation_side="left",
-#             trust_remote_code=trust_remote_code,
-#         )
-#
-#         config = AutoConfig.from_pretrained(
-#             model_id, revision=revision, trust_remote_code=trust_remote_code
-#         )
-#
-#         torch.distributed.barrier(group=self.process_group)
-#         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-#
-#         with init_empty_weights():
-#             model = FlashGPTNeoXForCausalLM(config, self.process_group)
-#
-#         torch.distributed.barrier(group=self.process_group)
-#         self.load_weights(
-#             model,
-#             filenames,
-#             quantize=quantize,
-#             device=device,
-#             dtype=dtype,
-#             rank=rank,
-#             world_size=world_size,
-#         )
-#         torch.distributed.barrier(group=self.process_group)
-#         super(FlashCausalLM, self).__init__(
-#             model=model.to(device),
-#             tokenizer=tokenizer,
-#             requires_padding=False,
-#             dtype=dtype,
-#             device=device,
-#             rank=rank,
-#             world_size=world_size,
-#         )
-#
-#     @staticmethod
-#     def load_weights(
-#         model,
-#         filenames: List[str],
-#         quantize: Optional[str],
-#         device: torch.device,
-#         dtype: torch.dtype,
-#         rank: int,
-#         world_size: int,
-#     ):
-#         parameters = dict(model.named_parameters())
-#         for file in filenames:
-#             with safe_open(
-#                 file, framework="pt", device=str(device) if quantize is None else "cpu"
-#             ) as f:
-#                 for name in f.keys():
-#                     module_name, param_name = name.rsplit(".", 1)
-#                     module = model.get_submodule(module_name)
-#
-#                     current_parameter_tensor = parameters.get(name, None)
-#
-#                     slice_ = f.get_slice(name)
-#
-#                     if isinstance(module, TensorParallelColumnLinear):
-#                         size = slice_.get_shape()[0]
-#                         block_size = size // world_size
-#                         start = rank * block_size
-#                         stop = (rank + 1) * block_size
-#                         tensor = slice_[start:stop]
-#                     elif isinstance(module, TensorParallelRowLinear):
-#                         if param_name == "weight":
-#                             size = slice_.get_shape()[1]
-#                             block_size = size // world_size
-#                             start = rank * block_size
-#                             stop = (rank + 1) * block_size
-#                             tensor = slice_[:, start:stop]
-#                         else:
-#                             tensor = slice_[:]
-#                             # XXX: Hack for Rowlinear to add the bias only once.
-#                             if rank != 0:
-#                                 tensor = torch.zeros_like(tensor)
-#                     elif isinstance(module, TensorParallelEmbedding):
-#                         size = slice_.get_shape()[0]
-#                         block_size = size // world_size
-#                         start = rank * block_size
-#                         stop = (rank + 1) * block_size
-#                         tensor = slice_[start:stop]
-#                     elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
-#                         size = slice_.get_shape()[0]
-#                         block_size = size // world_size
-#                         start = rank * block_size
-#                         stop = (rank + 1) * block_size
-#                         tensor = slice_[start:stop]
-#                     else:
-#                         try:
-#                             tensor = slice_[:]
-#                         except:
-#                             tensor = f.get_tensor(name)
-#
-#                     if (
-#                         current_parameter_tensor is not None
-#                         and current_parameter_tensor.shape != tensor.shape
-#                     ):
-#                         raise ValueError(
-#                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-#                         )
-#
-#                     tensor = tensor.contiguous().to(dtype)
-#
-#                     if current_parameter_tensor is not None:
-#                         module._parameters[param_name] = tensor
-#                     else:
-#                         module._buffers[param_name] = tensor
-#
-#         model.post_load_weights(quantize)
+
+class FlashRWSharded(FlashRW):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = torch.float16
+        else:
+            raise NotImplementedError("FlashRW is only available on GPU")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = RWConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashRWForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            quantize=quantize,
+            device=device,
+            dtype=dtype,
+            rank=rank,
+            world_size=world_size,
+        )
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        quantize: Optional[str],
+        device: torch.device,
+        dtype: torch.dtype,
+        rank: int,
+        world_size: int,
+    ):
+        parameters = dict(model.named_parameters())
+        for file in filenames:
+            with safe_open(
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
+            ) as f:
+                for name in f.keys():
+                    module_name, param_name = name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    current_parameter_tensor = parameters.get(name, None)
+
+                    slice_ = f.get_slice(name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            size = slice_.get_shape()[1]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif name == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(name)
+
+                    if (
+                        current_parameter_tensor is not None
+                        and current_parameter_tensor.shape != tensor.shape
+                    ):
+                        raise ValueError(
+                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                        )
+
+                    tensor = tensor.contiguous().to(dtype)
+
+                    if current_parameter_tensor is not None:
+                        module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+        model.post_load_weights(quantize)
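
As a side note on the load_weights logic above: column-parallel modules take a contiguous block of rows (dim 0) of the full weight, row-parallel modules a block of columns (dim 1), so each rank materialises only its own shard. A minimal sketch of that slicing rule, with made-up sizes and a hypothetical 2-rank world:

# Sketch only: mimic the per-rank slicing used in load_weights above.
# Sizes and world_size are illustrative assumptions.
import torch

world_size = 2
full_column_weight = torch.arange(8 * 4, dtype=torch.float32).view(8, 4)  # [out_features, in_features]
full_row_weight = torch.arange(4 * 8, dtype=torch.float32).view(4, 8)

for rank in range(world_size):
    block = full_column_weight.shape[0] // world_size
    col_shard = full_column_weight[rank * block : (rank + 1) * block]      # split along dim 0

    block = full_row_weight.shape[1] // world_size
    row_shard = full_row_weight[:, rank * block : (rank + 1) * block]      # split along dim 1

    print(rank, tuple(col_shard.shape), tuple(row_shard.shape))
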

View File

@@ -8,11 +8,11 @@ from text_generation_server.models import CausalLM
class RW(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -63,7 +63,7 @@ class RW(CausalLM):
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        if past_key_values is not None:
@@ -71,10 +71,16 @@ class RW(CausalLM):
            for layer in past_key_values:
                past_keys, past_values = layer
                reshaped_past_key_values.append(
-                    (past_keys.view(-1, *past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:]))
+                    (
+                        past_keys.view(-1, *past_keys.shape[-2:]),
+                        past_values.view(-1, *past_values.shape[-2:]),
+                    )
                )
            past_key_values = reshaped_past_key_values

-        outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask,
-                                     past_key_values=past_key_values)
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+        )
        return outputs.logits, outputs.past_key_values
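
A small aside on the past_key_values reshape above: view(-1, *shape[-2:]) collapses all leading batch/head dimensions of each cached key and value tensor into one, keeping only the last two. The shapes in the sketch below are made up for illustration only:

# Illustrative only: collapse leading dims of a cached key tensor, keep the last two.
import torch

past_keys = torch.randn(3, 4, 7, 64)               # e.g. [batch, kv_heads, seq, head_size] (assumed)
flat = past_keys.view(-1, *past_keys.shape[-2:])   # -> [12, 7, 64]
print(tuple(flat.shape))
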