mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
wip
This commit is contained in:
parent
12ab24ae64
commit
cbffddcc06
@ -286,7 +286,9 @@ def test_batch_concatenate(
|
||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
|
@ -285,7 +285,9 @@ def test_batch_concatenate(
|
||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
|
@ -323,7 +323,9 @@ def test_batch_concatenate(
|
||||
)
|
||||
assert generations[2].generated_text.generated_tokens == 5
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
@ -31,7 +31,7 @@ try:
|
||||
)
|
||||
|
||||
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
|
||||
from text_generation_server.models.flash_rw import FlashRW
|
||||
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
FlashLlamaSharded,
|
||||
@ -71,6 +71,7 @@ if FLASH_ATTENTION:
|
||||
__all__.append(FlashNeoX)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRW)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoder)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
__all__.append(FlashLlama)
|
||||
@ -202,13 +203,15 @@ def get_model(
|
||||
if FLASH_ATTENTION:
|
||||
if config.alibi:
|
||||
raise NotImplementedError("sharded is not supported for this model")
|
||||
# return FlashRWSharded(
|
||||
# model_id,
|
||||
# revision,
|
||||
# quantize=quantize,
|
||||
# trust_remote_code=trust_remote_code,
|
||||
# )
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb"))
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
|
||||
)
|
||||
else:
|
||||
if FLASH_ATTENTION and not config.alibi:
|
||||
return FlashRW(
|
||||
|
@ -27,28 +27,30 @@ class RWConfig(PretrainedConfig):
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type="RefinedWeb",
|
||||
vocab_size=250880,
|
||||
hidden_size=64,
|
||||
n_layer=2,
|
||||
n_head=8,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
hidden_dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
n_head_kv=None,
|
||||
multi_query=False,
|
||||
alibi=False,
|
||||
bias=False,
|
||||
parallel_attn=False,
|
||||
**kwargs,
|
||||
self,
|
||||
model_type="RefinedWeb",
|
||||
vocab_size=250880,
|
||||
hidden_size=64,
|
||||
n_layer=2,
|
||||
n_head=8,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
hidden_dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
n_head_kv=None,
|
||||
multi_query=False,
|
||||
alibi=False,
|
||||
bias=False,
|
||||
parallel_attn=False,
|
||||
**kwargs,
|
||||
):
|
||||
if alibi:
|
||||
raise NotImplementedError("alibi is not supported by this version of the model")
|
||||
raise NotImplementedError(
|
||||
"alibi is not supported by this version of the model"
|
||||
)
|
||||
|
||||
self.model_type = model_type
|
||||
self.alibi = False
|
||||
@ -81,13 +83,13 @@ class RWConfig(PretrainedConfig):
|
||||
|
||||
class FlashRWAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
@ -99,33 +101,44 @@ class FlashRWAttention(torch.nn.Module):
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias)
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
)
|
||||
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias)
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
# Split query from key_value
|
||||
query, kv = qkv.split(
|
||||
[self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
|
||||
[self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Prepare query and key_value for indexing
|
||||
@ -194,11 +207,149 @@ class FlashRWAttention(torch.nn.Module):
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
class FlashRWLargeAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size, bias, process_group=None, reduce=True
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
self.num_groups = num_heads // (num_heads_kv * 2)
|
||||
self.num_heads = num_heads // self.num_groups
|
||||
self.num_heads_kv = num_heads_kv // self.num_groups
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.num_groups *
|
||||
self.head_size
|
||||
* (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
)
|
||||
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
self.num_groups
|
||||
* self.head_size
|
||||
* (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_groups // process_group.size()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
||||
|
||||
# Split query from key_value
|
||||
query, kv = qkv.split(
|
||||
[self.num_heads, 2],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
# Prepare query and key_value for indexing
|
||||
query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
kv = kv.transpose(1, 2)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(kv[:, 0], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = kv
|
||||
k, v = kv.split(1, dim=1)
|
||||
# Expand to query shape
|
||||
k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
k, v = layer_past.split(1, dim=1)
|
||||
# Expand to query shape
|
||||
k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
1,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.num_groups * self.head_size))
|
||||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
|
||||
super().__init__()
|
||||
self.act = torch.nn.functional.gelu
|
||||
|
||||
if process_group is None:
|
||||
@ -207,12 +358,14 @@ class FlashMLP(nn.Module):
|
||||
else:
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size, bias=bias,
|
||||
4 * hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense_4h_to_h = TensorParallelRowLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size, bias=bias,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
@ -227,37 +380,44 @@ class FlashMLP(nn.Module):
|
||||
|
||||
class FlashRWLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
layer_norm_eps,
|
||||
parallel_attn,
|
||||
process_group=None,
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
layer_norm_eps,
|
||||
parallel_attn,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.parallel_attn = parallel_attn
|
||||
|
||||
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.self_attention = FlashRWAttention(num_heads, num_heads_kv, hidden_size, process_group, reduce=False)
|
||||
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) if not parallel_attn else None
|
||||
self.self_attention = FlashRWAttention(
|
||||
num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
|
||||
)
|
||||
self.post_attention_layernorm = (
|
||||
FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
if not parallel_attn
|
||||
else None
|
||||
)
|
||||
|
||||
self.mlp = FlashMLP(hidden_size, process_group, reduce=False)
|
||||
self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
|
||||
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
if self.parallel_attn:
|
||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
@ -303,6 +463,68 @@ class FlashRWLayer(nn.Module):
|
||||
|
||||
return mlp_output, residual
|
||||
|
||||
|
||||
class FlashRWLargeLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
self.self_attention = FlashRWLargeAttention(
|
||||
num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
|
||||
)
|
||||
|
||||
self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
|
||||
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||
ln_mlp, _ = self.ln_mlp(hidden_states, residual)
|
||||
|
||||
# Self attention.
|
||||
attn_output = self.self_attention(
|
||||
ln_attn,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
# MLP.
|
||||
mlp_output = self.mlp(ln_mlp)
|
||||
|
||||
intermediate = attn_output + mlp_output
|
||||
|
||||
# Only reduce once and after the addition instead of once per layer
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
|
||||
return intermediate, residual
|
||||
|
||||
|
||||
class FlashRWPreTrainedModel(PreTrainedModel):
|
||||
config_class = RWConfig
|
||||
supports_gradient_checkpointing = False
|
||||
@ -328,27 +550,47 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
else:
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.layer_norm_epsilon,
|
||||
config.parallel_attn,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.ln_f = FastLayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_epsilon
|
||||
)
|
||||
if config.model_type == "RefinedWebModel":
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.bias,
|
||||
config.layer_norm_epsilon,
|
||||
config.parallel_attn,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.kv_size = self.h[0].self_attention.num_heads_kv
|
||||
elif config.model_type == "RefinedWeb":
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLargeLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.bias,
|
||||
config.layer_norm_epsilon,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.kv_size = self.h[0].self_attention.num_groups
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"model_type {config.model_type} is not supported."
|
||||
)
|
||||
|
||||
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.h[0].self_attention.head_size
|
||||
self.num_heads_kv = self.h[0].self_attention.num_heads_kv
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
if isinstance(self.word_embeddings, TensorParallelEmbedding):
|
||||
@ -373,14 +615,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
|
||||
@ -394,7 +636,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
2,
|
||||
self.num_heads_kv,
|
||||
self.kv_size,
|
||||
self.head_size,
|
||||
)
|
||||
)
|
||||
@ -457,9 +699,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size, config.vocab_size, bias=False
|
||||
)
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
self.transformer.post_load_weights(quantize)
|
||||
@ -477,14 +717,14 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states, present = self.transformer(
|
||||
input_ids,
|
||||
|
@ -53,8 +53,6 @@ class FlashRW(FlashCausalLM):
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
from loguru import logger
|
||||
logger.error(config.model_type)
|
||||
|
||||
# We do not use from_pretrained as we modified the model internal module layout
|
||||
try:
|
||||
@ -114,133 +112,134 @@ class FlashRW(FlashCausalLM):
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
#
|
||||
# class FlashNeoXSharded(FlashNeoX):
|
||||
# def __init__(
|
||||
# self,
|
||||
# model_id: str,
|
||||
# revision: Optional[str] = None,
|
||||
# quantize: Optional[str] = None,
|
||||
# trust_remote_code: bool = False,
|
||||
# ):
|
||||
# self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
# if torch.cuda.is_available():
|
||||
# device = torch.device(f"cuda:{rank}")
|
||||
# dtype = torch.float16
|
||||
# else:
|
||||
# raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||
#
|
||||
# tokenizer = AutoTokenizer.from_pretrained(
|
||||
# model_id,
|
||||
# revision=revision,
|
||||
# padding_side="left",
|
||||
# truncation_side="left",
|
||||
# trust_remote_code=trust_remote_code,
|
||||
# )
|
||||
#
|
||||
# config = AutoConfig.from_pretrained(
|
||||
# model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
# )
|
||||
#
|
||||
# torch.distributed.barrier(group=self.process_group)
|
||||
# filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
#
|
||||
# with init_empty_weights():
|
||||
# model = FlashGPTNeoXForCausalLM(config, self.process_group)
|
||||
#
|
||||
# torch.distributed.barrier(group=self.process_group)
|
||||
# self.load_weights(
|
||||
# model,
|
||||
# filenames,
|
||||
# quantize=quantize,
|
||||
# device=device,
|
||||
# dtype=dtype,
|
||||
# rank=rank,
|
||||
# world_size=world_size,
|
||||
# )
|
||||
# torch.distributed.barrier(group=self.process_group)
|
||||
# super(FlashCausalLM, self).__init__(
|
||||
# model=model.to(device),
|
||||
# tokenizer=tokenizer,
|
||||
# requires_padding=False,
|
||||
# dtype=dtype,
|
||||
# device=device,
|
||||
# rank=rank,
|
||||
# world_size=world_size,
|
||||
# )
|
||||
#
|
||||
# @staticmethod
|
||||
# def load_weights(
|
||||
# model,
|
||||
# filenames: List[str],
|
||||
# quantize: Optional[str],
|
||||
# device: torch.device,
|
||||
# dtype: torch.dtype,
|
||||
# rank: int,
|
||||
# world_size: int,
|
||||
# ):
|
||||
# parameters = dict(model.named_parameters())
|
||||
# for file in filenames:
|
||||
# with safe_open(
|
||||
# file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
# ) as f:
|
||||
# for name in f.keys():
|
||||
# module_name, param_name = name.rsplit(".", 1)
|
||||
# module = model.get_submodule(module_name)
|
||||
#
|
||||
# current_parameter_tensor = parameters.get(name, None)
|
||||
#
|
||||
# slice_ = f.get_slice(name)
|
||||
#
|
||||
# if isinstance(module, TensorParallelColumnLinear):
|
||||
# size = slice_.get_shape()[0]
|
||||
# block_size = size // world_size
|
||||
# start = rank * block_size
|
||||
# stop = (rank + 1) * block_size
|
||||
# tensor = slice_[start:stop]
|
||||
# elif isinstance(module, TensorParallelRowLinear):
|
||||
# if param_name == "weight":
|
||||
# size = slice_.get_shape()[1]
|
||||
# block_size = size // world_size
|
||||
# start = rank * block_size
|
||||
# stop = (rank + 1) * block_size
|
||||
# tensor = slice_[:, start:stop]
|
||||
# else:
|
||||
# tensor = slice_[:]
|
||||
# # XXX: Hack for Rowlinear to add the bias only once.
|
||||
# if rank != 0:
|
||||
# tensor = torch.zeros_like(tensor)
|
||||
# elif isinstance(module, TensorParallelEmbedding):
|
||||
# size = slice_.get_shape()[0]
|
||||
# block_size = size // world_size
|
||||
# start = rank * block_size
|
||||
# stop = (rank + 1) * block_size
|
||||
# tensor = slice_[start:stop]
|
||||
# elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
|
||||
# size = slice_.get_shape()[0]
|
||||
# block_size = size // world_size
|
||||
# start = rank * block_size
|
||||
# stop = (rank + 1) * block_size
|
||||
# tensor = slice_[start:stop]
|
||||
# else:
|
||||
# try:
|
||||
# tensor = slice_[:]
|
||||
# except:
|
||||
# tensor = f.get_tensor(name)
|
||||
#
|
||||
# if (
|
||||
# current_parameter_tensor is not None
|
||||
# and current_parameter_tensor.shape != tensor.shape
|
||||
# ):
|
||||
# raise ValueError(
|
||||
# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
# )
|
||||
#
|
||||
# tensor = tensor.contiguous().to(dtype)
|
||||
#
|
||||
# if current_parameter_tensor is not None:
|
||||
# module._parameters[param_name] = tensor
|
||||
# else:
|
||||
# module._buffers[param_name] = tensor
|
||||
#
|
||||
# model.post_load_weights(quantize)
|
||||
|
||||
|
||||
class FlashRWSharded(FlashRW):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashRW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashRWForCausalLM(config, self.process_group)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -8,11 +8,11 @@ from text_generation_server.models import CausalLM
|
||||
|
||||
class RW(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
@ -63,7 +63,7 @@ class RW(CausalLM):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Model Forward
|
||||
if past_key_values is not None:
|
||||
@ -71,10 +71,16 @@ class RW(CausalLM):
|
||||
for layer in past_key_values:
|
||||
past_keys, past_values = layer
|
||||
reshaped_past_key_values.append(
|
||||
(past_keys.view(-1, *past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:]))
|
||||
(
|
||||
past_keys.view(-1, *past_keys.shape[-2:]),
|
||||
past_values.view(-1, *past_values.shape[-2:]),
|
||||
)
|
||||
)
|
||||
past_key_values = reshaped_past_key_values
|
||||
|
||||
outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
Loading…
Reference in New Issue
Block a user