mirror of https://github.com/huggingface/text-generation-inference.git
commit 51b5db77f5
parent 8c8d709994

    black
@@ -201,7 +201,10 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi:
+                if config.alibi or (
+                    config.model_type == "RefinedWebModel"
+                    and config.n_head_kv != config.n_head
+                ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
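Note on the hunk above: sharding is now rejected not only for alibi checkpoints but also for "RefinedWebModel" configs whose KV head count differs from the query head count. A minimal standalone sketch of that guard, reusing the config attribute names from the diff; the function name is illustrative and the rationale in the comment is an assumption, not taken from the commit:

def sharded_flash_rw_supported(config) -> bool:
    # Alibi position biases are not handled by the sharded flash path.
    if config.alibi:
        return False
    # Multi-query "RefinedWebModel" checkpoints (n_head_kv != n_head) are also
    # rejected; presumably their few KV heads cannot be split cleanly across
    # tensor-parallel ranks.
    if config.model_type == "RefinedWebModel" and config.n_head_kv != config.n_head:
        return False
    return True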
@@ -107,11 +107,11 @@ class FlashRWAttention(torch.nn.Module):
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
-            self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = FastLinear(
+            self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.head_size * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
+                process_group=process_group,
             )
             self.dense = TensorParallelRowLinear(
                 hidden_size,
@@ -120,6 +120,7 @@ class FlashRWAttention(torch.nn.Module):
                 process_group=process_group,
                 reduce=reduce,
             )
+            self.num_heads = self.num_heads // process_group.size()

     def forward(
         self,
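The two FlashRWAttention hunks above replace the plain FastLinear projections with TensorParallelColumnLinear / TensorParallelRowLinear and move the num_heads division to after the layers are built, since the column-parallel QKV is now declared with the global head count. A toy shape-accounting sketch with hypothetical sizes (not the library implementation):

# Hypothetical sizes, for illustration only.
head_size, num_heads, num_heads_kv, world_size = 64, 32, 8, 4

# The column-parallel QKV is declared with the global head count ...
qkv_out_features = head_size * (num_heads + 2 * num_heads_kv)
assert qkv_out_features == 3072

# ... and only afterwards does each rank switch to its per-rank head count,
# mirroring the relocated `self.num_heads = self.num_heads // process_group.size()`.
local_num_heads = num_heads // world_size
assert local_num_heads == 8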
@@ -231,13 +232,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         if process_group is None:
             self.query_key_value = FastLinear(
                 hidden_size,
-                self.num_groups *
-                self.head_size
+                self.num_groups
+                * self.head_size
                 * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
+            if process_group.size() > self.num_groups:
+                raise NotImplementedError(
+                    f"Tensor Parallelism is not implemented for world_size > n groups"
+                )
+
             self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.num_groups
@@ -269,10 +275,13 @@ class FlashRWLargeAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+        # Split on group dimension
         query, kv = qkv.split(
             [self.num_heads, 2],
             dim=2,
         )
+        # Merge groups and heads
         query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)

         # Inplace rotary
@@ -285,8 +294,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[...] = kv
             k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )

             # output
             attn_output = torch.empty_like(query)
@@ -314,8 +327,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[layer_past_present_indices] = kv
             k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )

             # output
             attn_output = torch.empty_like(query)
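The two expand/reshape blocks reformatted above broadcast the single K and V of each group across that group's query heads before flash attention. A runnable shape sketch with hypothetical sizes:

import torch

# Hypothetical sizes, matching the layout used in the hunks above.
seq, num_groups, num_heads, head_size = 5, 8, 16, 64

kv = torch.randn(seq, num_groups, 2, head_size)
k, v = kv.split(1, dim=2)  # each: (seq, num_groups, 1, head_size)

# Broadcast the per-group K over that group's query heads, then merge the
# group and head dimensions to match the query layout.
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
assert k.shape == (seq, num_groups * num_heads, head_size)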
@@ -338,7 +355,9 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )

-        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
+        return self.dense(
+            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+        )


 class FlashMLP(nn.Module):
@@ -389,7 +408,12 @@ class FlashRWLayer(nn.Module):

         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.self_attention = FlashRWAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm(hidden_size, eps=layer_norm_eps)
@@ -397,7 +421,9 @@ class FlashRWLayer(nn.Module):
             else None
         )

-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )

         self.process_group = process_group

@@ -473,10 +499,17 @@ class FlashRWLargeLayer(nn.Module):
         self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)

         self.self_attention = FlashRWLargeAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )

-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )

         self.process_group = process_group

@@ -492,8 +525,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, _ = self.ln_attn(hidden_states)
-        ln_mlp, _ = self.ln_mlp(hidden_states)
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(residual)

         # Self attention.
         attn_output = self.self_attention(
@@ -516,13 +549,11 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)

-        return intermediate + hidden_states, None
+        return intermediate, residual


 class FlashRWPreTrainedModel(PreTrainedModel):
     config_class = RWConfig
-    supports_gradient_checkpointing = False
-    _no_split_modules = None


 class FlashRWModel(FlashRWPreTrainedModel):
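The FlashRWLargeLayer change above moves residual bookkeeping into the layer norm: ln_attn now receives the running residual and returns the updated one, and the layer returns (intermediate, residual) rather than adding hidden_states itself. A minimal sketch of the add-and-norm contract this implies, written as an assumption about FastLayerNorm's call pattern rather than its actual implementation:

import torch

def add_and_layer_norm(hidden_states, residual, weight, bias, eps=1e-5):
    # Fold the previous residual in first (skipped on the first layer), keep
    # the pre-norm sum as the new residual, and return both values.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states
    normed = torch.nn.functional.layer_norm(
        hidden_states, hidden_states.shape[-1:], weight, bias, eps
    )
    return normed, residual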
@@ -559,7 +590,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                2,
+                self.h[0].self_attention.num_heads_kv,
+                self.h[0].self_attention.head_size,
+            )
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -574,7 +609,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                self.h[0].self_attention.num_groups,
+                2,
+                self.h[0].self_attention.head_size,
+            )
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -582,8 +621,6 @@ class FlashRWModel(FlashRWPreTrainedModel):

         self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

-        self.gradient_checkpointing = False
-
         self.head_size = self.h[0].self_attention.head_size

     def post_load_weights(self, quantize: Optional[str] = None):
@@ -629,7 +666,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    *self.cache_size
+                    *self.cache_size,
                 )
             )
             layer_past_present_indices = None
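The cache_size tuples introduced above describe one KV cache entry per token for each architecture, and the trailing comma added to *self.cache_size, is purely a black formatting fix. A small sketch of the resulting cache shapes, with hypothetical sizes:

import torch

# Hypothetical sizes, for illustration only.
num_heads_kv, num_groups, head_size, tokens = 1, 8, 64, 16

# "RefinedWebModel" layers: (2, num_heads_kv, head_size) per token.
cache_size = (2, num_heads_kv, head_size)
layer_past = torch.zeros(tokens, *cache_size)
assert layer_past.shape == (tokens, 2, num_heads_kv, head_size)

# "RefinedWeb" layers: (num_groups, 2, head_size) per token.
cache_size = (num_groups, 2, head_size)
layer_past = torch.zeros(tokens, *cache_size)
assert layer_past.shape == (tokens, num_groups, 2, head_size)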
@@ -113,7 +113,6 @@ class FlashRW(FlashCausalLM):
         model.post_load_weights(quantize)


-
 class FlashRWSharded(FlashRW):
     def __init__(
         self,