mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

commit cbffddcc06 (parent 12ab24ae64): wip

@@ -286,7 +286,9 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens

@@ -285,7 +285,9 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens

@@ -323,7 +323,9 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

@@ -31,7 +31,7 @@ try:
     )

     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_rw import FlashRW
+    from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
     from text_generation_server.models.flash_llama import (
         FlashLlama,
         FlashLlamaSharded,

@@ -71,6 +71,7 @@ if FLASH_ATTENTION:
    __all__.append(FlashNeoX)
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRW)
+   __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoder)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)

@@ -202,13 +203,15 @@ def get_model(
             if FLASH_ATTENTION:
                 if config.alibi:
                     raise NotImplementedError("sharded is not supported for this model")
-                # return FlashRWSharded(
-                #     model_id,
-                #     revision,
-                #     quantize=quantize,
-                #     trust_remote_code=trust_remote_code,
-                # )
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb"))
+                return FlashRWSharded(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
+            )
         else:
             if FLASH_ATTENTION and not config.alibi:
                 return FlashRW(
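
Note on the hunk above: sharded RefinedWeb checkpoints were previously always rejected here; with the commented-out block restored, get_model now returns the new FlashRWSharded whenever flash attention is available and the config does not use alibi, and only raises otherwise. A condensed sketch of that decision (hypothetical helper, not repo code; the real error text comes from FLASH_ATT_ERROR_MESSAGE):

    def pick_sharded_refinedweb(flash_attention: bool, alibi: bool) -> str:
        # Mirrors the control flow above; strings stand in for the real class and message.
        if flash_attention:
            if alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return "FlashRWSharded"
        raise NotImplementedError("flash attention is required for sharded RefinedWeb")

    assert pick_sharded_refinedweb(flash_attention=True, alibi=False) == "FlashRWSharded"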

@@ -27,28 +27,30 @@ class RWConfig(PretrainedConfig):
     }

     def __init__(
         self,
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
         n_layer=2,
         n_head=8,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
         bos_token_id=1,
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
         n_head_kv=None,
         multi_query=False,
         alibi=False,
         bias=False,
         parallel_attn=False,
         **kwargs,
     ):
         if alibi:
-            raise NotImplementedError("alibi is not supported by this version of the model")
+            raise NotImplementedError(
+                "alibi is not supported by this version of the model"
+            )

         self.model_type = model_type
         self.alibi = False

@@ -81,13 +83,13 @@ class RWConfig(PretrainedConfig):

 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
         num_heads,
         num_heads_kv,
         hidden_size,
         bias,
         process_group=None,
         reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads

@@ -99,33 +101,44 @@ class FlashRWAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                                              bias=bias)
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
             self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                                              bias=bias)
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
             self.dense = TensorParallelRowLinear(
-                hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce
+                hidden_size,
+                hidden_size,
+                bias=bias,
+                process_group=process_group,
+                reduce=reduce,
             )

     def forward(
         self,
         hidden_states,
         cos,
         sin,
         cu_seqlens,
         max_s,
         layer_past,
         layer_past_present_indices,
         cu_seqlens_q,
     ):
         qkv = self.query_key_value(hidden_states)

         # Split query from key_value
         query, kv = qkv.split(
-            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
+            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
+            dim=1,
         )

         # Prepare query and key_value for indexing
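
The change above is purely a formatting split of the fused QKV projection and its query/key-value split. For a multi-query configuration the split sizes work out as below (illustrative 7B-style numbers, assumed rather than read from a checkpoint):

    import torch

    num_heads, num_heads_kv, head_size = 71, 1, 64  # assumed multi-query values
    qkv = torch.randn(3, head_size * (num_heads + 2 * num_heads_kv))  # (3 tokens, 4672)
    query, kv = qkv.split(
        [head_size * num_heads, 2 * head_size * num_heads_kv], dim=1
    )
    print(query.shape, kv.shape)  # torch.Size([3, 4544]) torch.Size([3, 128])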

@@ -194,11 +207,149 @@ class FlashRWAttention(torch.nn.Module):
         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


-class FlashMLP(nn.Module):
+class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
-        self, hidden_size, bias, process_group=None, reduce=True
+        self,
+        num_heads,
+        num_heads_kv,
+        hidden_size,
+        bias,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
+
+        self.hidden_size = hidden_size
+        self.head_size = hidden_size // num_heads
+
+        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.softmax_scale = self.head_size ** (-0.5)
+
+        self.num_groups = num_heads // (num_heads_kv * 2)
+        self.num_heads = num_heads // self.num_groups
+        self.num_heads_kv = num_heads_kv // self.num_groups
+
+        if process_group is None:
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.num_groups
+                * self.head_size
+                * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
+            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
+        else:
+            self.query_key_value = TensorParallelColumnLinear(
+                hidden_size,
+                self.num_groups
+                * self.head_size
+                * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+                process_group=process_group,
+            )
+            self.dense = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                bias=bias,
+                process_group=process_group,
+                reduce=reduce,
+            )
+
+            self.num_groups = self.num_groups // process_group.size()
+
+    def forward(
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
+    ):
+        qkv = self.query_key_value(hidden_states)
+        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+        # Split query from key_value
+        query, kv = qkv.split(
+            [self.num_heads, 2],
+            dim=2,
+        )
+
+        # Prepare query and key_value for indexing
+        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
+        kv = kv.transpose(1, 2)
+
+        # Inplace rotary
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(kv[:, 0], cos, sin)
+
+        # Prefill
+        if layer_past_present_indices is None:
+            # Copy to layer past
+            layer_past[...] = kv
+            k, v = kv.split(1, dim=1)
+            # Expand to query shape
+            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            # output
+            attn_output = torch.empty_like(query)
+            # flash attention
+            flash_attn_cuda.fwd(
+                query,
+                k,
+                v,
+                attn_output,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0.0,
+                self.softmax_scale,
+                False,
+                True,
+                False,
+                0,
+                None,
+            )
+        # Decode
+        else:
+            # Add present to the layer_past tensor at the correct indices
+            layer_past[layer_past_present_indices] = kv
+            k, v = layer_past.split(1, dim=1)
+            # Expand to query shape
+            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            # output
+            attn_output = torch.empty_like(query)
+            # flash attention
+            flash_attn_cuda.fwd(
+                query,
+                k,
+                v,
+                attn_output,
+                cu_seqlens_q,
+                cu_seqlens,
+                1,
+                max_s,
+                0.0,
+                self.softmax_scale,
+                False,
+                False,
+                False,
+                0,
+                None,
+            )
+
+        return self.dense(attn_output.view(-1, self.num_heads * self.num_groups * self.head_size))
+
+
+class FlashMLP(nn.Module):
+    def __init__(self, hidden_size, bias, process_group=None, reduce=True):
+        super().__init__()
         self.act = torch.nn.functional.gelu

         if process_group is None:
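
FlashRWLargeAttention introduces the grouped KV layout used by the larger RefinedWeb variant: heads are organised into num_groups = num_heads // (num_heads_kv * 2) groups, each with one shared K head and one shared V head, and the fused QKV output is viewed as (tokens, groups, heads_per_group + 2, head_size) before the split. A shape walk-through with 40B-like numbers (128 query heads, 8 KV heads, hidden size 8192; assumed for illustration):

    import torch

    n_head, n_head_kv, hidden_size = 128, 8, 8192  # assumed 40B-style values
    head_size = hidden_size // n_head               # 64
    num_groups = n_head // (n_head_kv * 2)          # 8
    heads_per_group = n_head // num_groups          # 16
    kv_per_group = n_head_kv // num_groups          # 1 shared K and V per group

    tokens = 4
    qkv = torch.randn(
        tokens, num_groups * head_size * (heads_per_group + 2 * kv_per_group)
    )
    qkv = qkv.view(-1, num_groups, heads_per_group + 2, head_size)
    query, kv = qkv.split([heads_per_group, 2], dim=2)

    print(query.shape)  # torch.Size([4, 8, 16, 64]) -> 8 groups x 16 query heads
    print(kv.shape)     # torch.Size([4, 8, 2, 64])  -> one K and one V head per group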

@@ -207,12 +358,14 @@ class FlashMLP(nn.Module):
         else:
             self.dense_h_to_4h = TensorParallelColumnLinear(
                 hidden_size,
-                4 * hidden_size, bias=bias,
+                4 * hidden_size,
+                bias=bias,
                 process_group=process_group,
             )
             self.dense_4h_to_h = TensorParallelRowLinear(
                 4 * hidden_size,
-                hidden_size, bias=bias,
+                hidden_size,
+                bias=bias,
                 process_group=process_group,
                 reduce=reduce,
             )

@@ -227,37 +380,44 @@ class FlashMLP(nn.Module):

 class FlashRWLayer(nn.Module):
     def __init__(
         self,
         num_heads,
         num_heads_kv,
         hidden_size,
-        layer_norm_eps,
-        parallel_attn,
-        process_group=None,
+        bias,
+        layer_norm_eps,
+        parallel_attn,
+        process_group=None,
     ):
         super().__init__()

         self.parallel_attn = parallel_attn

         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
-        self.self_attention = FlashRWAttention(num_heads, num_heads_kv, hidden_size, process_group, reduce=False)
-        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) if not parallel_attn else None
+        self.self_attention = FlashRWAttention(
+            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+        )
+        self.post_attention_layernorm = (
+            FastLayerNorm(hidden_size, eps=layer_norm_eps)
+            if not parallel_attn
+            else None
+        )

-        self.mlp = FlashMLP(hidden_size, process_group, reduce=False)
+        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)

         self.process_group = process_group

     def forward(
         self,
         hidden_states,
         residual,
         cos,
         sin,
         cu_seqlens,
         max_s,
         layer_past,
         layer_past_present_indices,
         cu_seqlens_q,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
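
Besides the black-style reformatting, the hunk above fixes how FlashRWLayer builds its sub-modules: FlashRWAttention's signature is (num_heads, num_heads_kv, hidden_size, bias, process_group=None, reduce=True) and FlashMLP's is (hidden_size, bias, process_group=None, reduce=True), so the old positional calls were binding process_group to the bias parameter. A minimal repro of that hazard with a stand-in function:

    def attention(num_heads, num_heads_kv, hidden_size, bias, process_group=None, reduce=True):
        return {"bias": bias, "process_group": process_group}

    group = object()  # stands in for a torch.distributed process group
    old_call = attention(8, 1, 512, group)                       # group silently lands in bias
    new_call = attention(8, 1, 512, False, process_group=group)  # explicit keyword, as in the fix
    assert old_call["process_group"] is None and old_call["bias"] is group
    assert new_call["process_group"] is group and new_call["bias"] is False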

@@ -303,6 +463,68 @@ class FlashRWLayer(nn.Module):

         return mlp_output, residual


+class FlashRWLargeLayer(nn.Module):
+    def __init__(
+        self,
+        num_heads,
+        num_heads_kv,
+        hidden_size,
+        bias,
+        layer_norm_eps,
+        process_group=None,
+    ):
+        super().__init__()
+        self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+
+        self.self_attention = FlashRWLargeAttention(
+            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+        )
+
+        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+
+        self.process_group = process_group
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
+    ):
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(hidden_states, residual)
+
+        # Self attention.
+        attn_output = self.self_attention(
+            ln_attn,
+            cos,
+            sin,
+            cu_seqlens,
+            max_s,
+            layer_past,
+            layer_past_present_indices,
+            cu_seqlens_q,
+        )
+
+        # MLP.
+        mlp_output = self.mlp(ln_mlp)
+
+        intermediate = attn_output + mlp_output
+
+        # Only reduce once and after the addition instead of once per layer
+        if self.process_group is not None:
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+        return intermediate, residual
+
+
 class FlashRWPreTrainedModel(PreTrainedModel):
     config_class = RWConfig
     supports_gradient_checkpointing = False
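
FlashRWLargeLayer applies its two layer norms to the same input, runs attention and the MLP in parallel on them, and sums the results, so both sub-modules are built with reduce=False and a single all_reduce is issued on the sum. That is valid because a sum reduction distributes over addition; a toy single-process check (plain tensors, no torch.distributed involved):

    import torch

    partial_attn = [torch.randn(4, 8) for _ in range(2)]  # one partial output per "rank"
    partial_mlp = [torch.randn(4, 8) for _ in range(2)]

    two_reductions = sum(partial_attn) + sum(partial_mlp)                  # reduce each, then add
    one_reduction = sum(a + m for a, m in zip(partial_attn, partial_mlp))  # add, then reduce once
    torch.testing.assert_close(two_reductions, one_reduction)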

@@ -328,27 +550,47 @@ class FlashRWModel(FlashRWPreTrainedModel):
         else:
             self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

-        self.h = nn.ModuleList(
-            [
-                FlashRWLayer(
-                    config.n_head,
-                    config.n_head_kv,
-                    config.hidden_size,
-                    config.layer_norm_epsilon,
-                    config.parallel_attn,
-                    process_group,
-                )
-                for _ in range(config.num_hidden_layers)
-            ]
-        )
-        self.ln_f = FastLayerNorm(
-            config.hidden_size, eps=config.layer_norm_epsilon
-        )
+        if config.model_type == "RefinedWebModel":
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(
+                        config.n_head,
+                        config.n_head_kv,
+                        config.hidden_size,
+                        config.bias,
+                        config.layer_norm_epsilon,
+                        config.parallel_attn,
+                        process_group,
+                    )
+                    for _ in range(config.num_hidden_layers)
+                ]
+            )
+            self.kv_size = self.h[0].self_attention.num_heads_kv
+        elif config.model_type == "RefinedWeb":
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLargeLayer(
+                        config.n_head,
+                        config.n_head_kv,
+                        config.hidden_size,
+                        config.bias,
+                        config.layer_norm_epsilon,
+                        process_group,
+                    )
+                    for _ in range(config.num_hidden_layers)
+                ]
+            )
+            self.kv_size = self.h[0].self_attention.num_groups
+        else:
+            raise NotImplementedError(
+                f"model_type {config.model_type} is not supported."
+            )
+
+        self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

         self.gradient_checkpointing = False

         self.head_size = self.h[0].self_attention.head_size
-        self.num_heads_kv = self.h[0].self_attention.num_heads_kv

     def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.word_embeddings, TensorParallelEmbedding):
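
The constructor now dispatches on config.model_type: "RefinedWebModel" keeps the FlashRWLayer stack and caches n_head_kv KV heads per layer, while "RefinedWeb" uses the new FlashRWLargeLayer stack and caches one K/V pair per head group, which is why a generic self.kv_size replaces self.num_heads_kv below. A condensed, stand-alone view of that choice (hypothetical helper; the head counts are illustrative):

    def kv_size_for(model_type: str, n_head: int, n_head_kv: int) -> int:
        if model_type == "RefinedWebModel":
            return n_head_kv                  # e.g. 1 for a multi-query 7B-style config
        elif model_type == "RefinedWeb":
            return n_head // (n_head_kv * 2)  # num_groups, e.g. 8 for a 40B-style config
        raise NotImplementedError(f"model_type {model_type} is not supported.")

    assert kv_size_for("RefinedWebModel", n_head=71, n_head_kv=1) == 1
    assert kv_size_for("RefinedWeb", n_head=128, n_head_kv=8) == 8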

@@ -373,14 +615,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
         return model

     def forward(
         self,
         input_ids,
         position_ids,
         cu_seqlens,
         cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.word_embeddings(input_ids)


@@ -394,7 +636,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
                     2,
-                    self.num_heads_kv,
+                    self.kv_size,
                     self.head_size,
                 )
             )

@@ -457,9 +699,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
                 bias=False,
             )
         else:
-            self.lm_head = FastLinear(
-                config.hidden_size, config.vocab_size, bias=False
-            )
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

     def post_load_weights(self, quantize: Optional[str] = None):
         self.transformer.post_load_weights(quantize)

@@ -477,14 +717,14 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         return model

     def forward(
         self,
         input_ids,
         position_ids,
         cu_seqlens,
         cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,

@@ -53,8 +53,6 @@ class FlashRW(FlashCausalLM):
             model_id,
             revision=revision,
         )
-        from loguru import logger
-        logger.error(config.model_type)

         # We do not use from_pretrained as we modified the model internal module layout
         try:

@@ -114,133 +112,134 @@ class FlashRW(FlashCausalLM):
             torch.cuda.empty_cache()
         model.post_load_weights(quantize)

-#
-# class FlashNeoXSharded(FlashNeoX):
-#     def __init__(
-#         self,
-#         model_id: str,
-#         revision: Optional[str] = None,
-#         quantize: Optional[str] = None,
-#         trust_remote_code: bool = False,
-#     ):
-#         self.process_group, rank, world_size = initialize_torch_distributed()
-#         if torch.cuda.is_available():
-#             device = torch.device(f"cuda:{rank}")
-#             dtype = torch.float16
-#         else:
-#             raise NotImplementedError("FlashNeoX is only available on GPU")
-#
-#         tokenizer = AutoTokenizer.from_pretrained(
-#             model_id,
-#             revision=revision,
-#             padding_side="left",
-#             truncation_side="left",
-#             trust_remote_code=trust_remote_code,
-#         )
-#
-#         config = AutoConfig.from_pretrained(
-#             model_id, revision=revision, trust_remote_code=trust_remote_code
-#         )
-#
-#         torch.distributed.barrier(group=self.process_group)
-#         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-#
-#         with init_empty_weights():
-#             model = FlashGPTNeoXForCausalLM(config, self.process_group)
-#
-#         torch.distributed.barrier(group=self.process_group)
-#         self.load_weights(
-#             model,
-#             filenames,
-#             quantize=quantize,
-#             device=device,
-#             dtype=dtype,
-#             rank=rank,
-#             world_size=world_size,
-#         )
-#         torch.distributed.barrier(group=self.process_group)
-#         super(FlashCausalLM, self).__init__(
-#             model=model.to(device),
-#             tokenizer=tokenizer,
-#             requires_padding=False,
-#             dtype=dtype,
-#             device=device,
-#             rank=rank,
-#             world_size=world_size,
-#         )
-#
-#     @staticmethod
-#     def load_weights(
-#         model,
-#         filenames: List[str],
-#         quantize: Optional[str],
-#         device: torch.device,
-#         dtype: torch.dtype,
-#         rank: int,
-#         world_size: int,
-#     ):
-#         parameters = dict(model.named_parameters())
-#         for file in filenames:
-#             with safe_open(
-#                 file, framework="pt", device=str(device) if quantize is None else "cpu"
-#             ) as f:
-#                 for name in f.keys():
-#                     module_name, param_name = name.rsplit(".", 1)
-#                     module = model.get_submodule(module_name)
-#
-#                     current_parameter_tensor = parameters.get(name, None)
-#
-#                     slice_ = f.get_slice(name)
-#
-#                     if isinstance(module, TensorParallelColumnLinear):
-#                         size = slice_.get_shape()[0]
-#                         block_size = size // world_size
-#                         start = rank * block_size
-#                         stop = (rank + 1) * block_size
-#                         tensor = slice_[start:stop]
-#                     elif isinstance(module, TensorParallelRowLinear):
-#                         if param_name == "weight":
-#                             size = slice_.get_shape()[1]
-#                             block_size = size // world_size
-#                             start = rank * block_size
-#                             stop = (rank + 1) * block_size
-#                             tensor = slice_[:, start:stop]
-#                         else:
-#                             tensor = slice_[:]
-#                             # XXX: Hack for Rowlinear to add the bias only once.
-#                             if rank != 0:
-#                                 tensor = torch.zeros_like(tensor)
-#                     elif isinstance(module, TensorParallelEmbedding):
-#                         size = slice_.get_shape()[0]
-#                         block_size = size // world_size
-#                         start = rank * block_size
-#                         stop = (rank + 1) * block_size
-#                         tensor = slice_[start:stop]
-#                     elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
-#                         size = slice_.get_shape()[0]
-#                         block_size = size // world_size
-#                         start = rank * block_size
-#                         stop = (rank + 1) * block_size
-#                         tensor = slice_[start:stop]
-#                     else:
-#                         try:
-#                             tensor = slice_[:]
-#                         except:
-#                             tensor = f.get_tensor(name)
-#
-#                     if (
-#                         current_parameter_tensor is not None
-#                         and current_parameter_tensor.shape != tensor.shape
-#                     ):
-#                         raise ValueError(
-#                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-#                         )
-#
-#                     tensor = tensor.contiguous().to(dtype)
-#
-#                     if current_parameter_tensor is not None:
-#                         module._parameters[param_name] = tensor
-#                     else:
-#                         module._buffers[param_name] = tensor
-#
-#         model.post_load_weights(quantize)
+
+
+class FlashRWSharded(FlashRW):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = torch.float16
+        else:
+            raise NotImplementedError("FlashRW is only available on GPU")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = RWConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashRWForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            quantize=quantize,
+            device=device,
+            dtype=dtype,
+            rank=rank,
+            world_size=world_size,
+        )
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        quantize: Optional[str],
+        device: torch.device,
+        dtype: torch.dtype,
+        rank: int,
+        world_size: int,
+    ):
+        parameters = dict(model.named_parameters())
+        for file in filenames:
+            with safe_open(
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
+            ) as f:
+                for name in f.keys():
+                    module_name, param_name = name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    current_parameter_tensor = parameters.get(name, None)
+
+                    slice_ = f.get_slice(name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            size = slice_.get_shape()[1]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif name == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(name)
+
+                    if (
+                        current_parameter_tensor is not None
+                        and current_parameter_tensor.shape != tensor.shape
+                    ):
+                        raise ValueError(
+                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                        )
+
+                    tensor = tensor.contiguous().to(dtype)
+
+                    if current_parameter_tensor is not None:
+                        module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+        model.post_load_weights(quantize)
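
FlashRWSharded is adapted from the previously commented-out FlashNeoXSharded template: it instantiates FlashRWForCausalLM under an empty-weights context, then streams safetensors slices and keeps only this rank's block of each column-parallel (dim 0) or row-parallel (dim 1) weight. The slicing arithmetic reduces to the start/stop computation used above (minimal sketch, sizes assumed divisible by world_size):

    def shard_bounds(size: int, rank: int, world_size: int) -> tuple[int, int]:
        # Each rank keeps one equal, contiguous block of the sharded dimension.
        block_size = size // world_size
        return rank * block_size, (rank + 1) * block_size

    # e.g. an 8192-wide dimension split across 2 GPUs
    assert shard_bounds(8192, rank=0, world_size=2) == (0, 4096)
    assert shard_bounds(8192, rank=1, world_size=2) == (4096, 8192)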

@@ -8,11 +8,11 @@ from text_generation_server.models import CausalLM

 class RW(CausalLM):
     def __init__(
         self,
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -63,7 +63,7 @@ class RW(CausalLM):
         )

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         if past_key_values is not None:

@@ -71,10 +71,16 @@ class RW(CausalLM):
             for layer in past_key_values:
                 past_keys, past_values = layer
                 reshaped_past_key_values.append(
-                    (past_keys.view(-1, *past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:]))
+                    (
+                        past_keys.view(-1, *past_keys.shape[-2:]),
+                        past_values.view(-1, *past_values.shape[-2:]),
+                    )
                 )
             past_key_values = reshaped_past_key_values

-        outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask,
-                                     past_key_values=past_key_values)
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+        )
         return outputs.logits, outputs.past_key_values
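
The RW (non-flash) change above is pure reformatting, but the reshape it wraps is worth spelling out: each cached key and value tensor has its leading dimensions flattened into one with view(-1, *shape[-2:]), presumably to match the (batch * heads, seq_len, head_dim) cache layout the remote-code RW model expects. Toy shapes for illustration (assumed, not taken from a real run):

    import torch

    past_keys = torch.randn(2, 1, 7, 64)              # assumed (batch, kv_heads, seq_len, head_size)
    flat = past_keys.view(-1, *past_keys.shape[-2:])  # merge the leading dimensions
    print(flat.shape)                                 # torch.Size([2, 7, 64])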