Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit 51b5db77f5 (parent 8c8d709994): black
@@ -201,7 +201,10 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi:
+                if config.alibi or (
+                    config.model_type == "RefinedWebModel"
+                    and config.n_head_kv != config.n_head
+                ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
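For readers skimming the hunk: the new guard refuses the sharded FlashRWSharded path whenever the checkpoint uses alibi, or when a RefinedWebModel config has fewer key/value heads than query heads (multi-query attention). A minimal standalone sketch of that check, with a hypothetical DummyConfig standing in for the HuggingFace config object:

from dataclasses import dataclass


@dataclass
class DummyConfig:
    # Hypothetical stand-in for the HuggingFace config, not repository code.
    model_type: str = "RefinedWebModel"
    alibi: bool = False
    n_head: int = 71
    n_head_kv: int = 1


def supports_sharding(config: DummyConfig) -> bool:
    # Sharding is refused for alibi models and for RefinedWebModel
    # checkpoints whose key/value head count differs from the query
    # head count (multi-query attention).
    if config.alibi:
        return False
    if config.model_type == "RefinedWebModel" and config.n_head_kv != config.n_head:
        return False
    return True


print(supports_sharding(DummyConfig()))  # False: multi-query checkpoint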
@@ -107,11 +107,11 @@ class FlashRWAttention(torch.nn.Module):
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
-            self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = FastLinear(
+            self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.head_size * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
+                process_group=process_group,
             )
             self.dense = TensorParallelRowLinear(
                 hidden_size,
@@ -120,6 +120,7 @@ class FlashRWAttention(torch.nn.Module):
                 process_group=process_group,
                 reduce=reduce,
             )
+            self.num_heads = self.num_heads // process_group.size()
 
     def forward(
         self,
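The else branch now builds the fused QKV projection as a TensorParallelColumnLinear and the output projection as a TensorParallelRowLinear, and only divides num_heads by the world size afterwards. Assuming these follow the usual Megatron-style split, the per-rank shapes work out as in the plain-PyTorch sketch below; the sizes are hypothetical and nn.Linear stands in for the repository's sharded layers:

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
hidden_size, num_heads, num_heads_kv, world_size = 4544, 71, 1, 1
head_size = hidden_size // num_heads  # 64

# Column-parallel style: the fused QKV projection is sharded along its output
# dimension, so each rank owns num_heads // world_size query heads plus the
# shared key/value heads.  With world_size == 1 this is just a plain Linear.
local_heads = num_heads // world_size
query_key_value = nn.Linear(hidden_size, head_size * (local_heads + 2 * num_heads_kv))

# Row-parallel style: the output projection is sharded along its input
# dimension; after the local matmul, one all_reduce sums the partial outputs.
dense = nn.Linear(head_size * local_heads, hidden_size)

x = torch.randn(3, hidden_size)
print(query_key_value(x).shape)                              # torch.Size([3, 4672])
print(dense(torch.randn(3, head_size * local_heads)).shape)  # torch.Size([3, 4544])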
@@ -231,13 +232,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         if process_group is None:
             self.query_key_value = FastLinear(
                 hidden_size,
-                self.num_groups *
-                self.head_size
+                self.num_groups
+                * self.head_size
                 * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
+            if process_group.size() > self.num_groups:
+                raise NotImplementedError(
+                    f"Tensor Parallelism is not implemented for world_size > n groups"
+                )
+
             self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.num_groups
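The "large" RefinedWeb attention packs its projection by group: the width num_groups * head_size * (num_heads + 2 * num_heads_kv) leaves room, per group, for that group's query heads plus its key and value heads, which the later .view(-1, num_groups, num_heads + 2, head_size) unpacks (implying one K and one V slot per group). A shape-only sketch with made-up sizes:

import torch

# Made-up sizes for illustration: 8 groups, 16 query heads per group,
# one key and one value head per group, head_size 64.
num_groups, num_heads, num_heads_kv, head_size = 8, 16, 1, 64
seq_len = 4

qkv_width = num_groups * head_size * (num_heads + 2 * num_heads_kv)
qkv = torch.randn(seq_len, qkv_width)

# Same reshape as in FlashRWLargeAttention.forward: one (num_heads + 2) slot
# axis per group, with the two extra slots holding that group's K and V.
qkv = qkv.view(-1, num_groups, num_heads + 2, head_size)
print(qkv.shape)  # torch.Size([4, 8, 18, 64])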
@@ -269,10 +275,13 @@ class FlashRWLargeAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
 
+        # Split on group dimension
         query, kv = qkv.split(
             [self.num_heads, 2],
             dim=2,
         )
+        # Merge groups and heads
+        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
         # Inplace rotary
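After the grouped view, split([num_heads, 2], dim=2) peels the last two slots of each group off as its key/value pair, and the query is flattened to (tokens, num_groups * num_heads, head_size) for the attention kernel. The same shapes, reproduced standalone with made-up sizes:

import torch

num_groups, num_heads, head_size, seq_len = 8, 16, 64, 4
qkv = torch.randn(seq_len, num_groups, num_heads + 2, head_size)

# Split on the per-group slot dimension: num_heads query slots, 2 kv slots.
query, kv = qkv.split([num_heads, 2], dim=2)

# Merge groups and heads for the query, as the forward pass above does.
query = query.reshape(-1, num_groups * num_heads, head_size)

print(query.shape)  # torch.Size([4, 128, 64])
print(kv.shape)     # torch.Size([4, 8, 2, 64])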
@@ -285,8 +294,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[...] = kv
             k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
 
             # output
             attn_output = torch.empty_like(query)
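Since each group shares a single key and value head, expand broadcasts that shared head across the group's query heads as a view (no copy), and reshape then merges groups and heads so K and V line up with the flattened query; the black reformat above only rewraps these calls. A standalone sketch with made-up sizes:

import torch

num_groups, num_heads, head_size, seq_len = 8, 16, 64, 4

# kv holds one key slot and one value slot per group.
kv = torch.randn(seq_len, num_groups, 2, head_size)
k, v = kv.split(1, dim=2)  # each: (seq_len, num_groups, 1, head_size)

# Broadcast the single shared head to all query heads of its group, then
# merge groups and heads so k/v line up with the flattened query.
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
v = v.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
print(k.shape, v.shape)  # torch.Size([4, 128, 64]) twice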
@@ -314,8 +327,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[layer_past_present_indices] = kv
             k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
 
             # output
             attn_output = torch.empty_like(query)
@@ -338,7 +355,9 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
+        return self.dense(
+            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+        )
 
 
 class FlashMLP(nn.Module):
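The reflowed return simply folds the per-head outputs back into one hidden vector per token before the output projection, so the view width has to equal num_groups * num_heads * head_size. A one-line shape check under the same made-up sizes:

import torch

num_groups, num_heads, head_size, seq_len = 8, 16, 64, 4
attn_output = torch.randn(seq_len, num_groups * num_heads, head_size)

# Flatten heads back into the hidden dimension expected by the dense projection.
flat = attn_output.view(-1, num_groups * num_heads * head_size)
print(flat.shape)  # torch.Size([4, 8192])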
@@ -389,7 +408,12 @@ class FlashRWLayer(nn.Module):
 
         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.self_attention = FlashRWAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm(hidden_size, eps=layer_norm_eps)
@@ -397,7 +421,9 @@ class FlashRWLayer(nn.Module):
             else None
         )
 
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
 
         self.process_group = process_group
 
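Passing reduce=False to both FlashRWAttention and FlashMLP reads as the deferred-reduction pattern for parallel attention/MLP blocks: each branch returns its rank-local partial sum, and the layer issues one all_reduce on attn_output + mlp_output instead of one per branch (the FlashRWLargeLayer hunk further down keeps exactly such a single torch.distributed.all_reduce). A rough single-process sketch of the idea; parallel_block and the plain nn.Linear stand-ins are illustrative, not the repository's classes:

import torch
import torch.distributed as dist


def parallel_block(ln_attn, ln_mlp, attention, mlp, process_group=None):
    # Both branches return *partial* sums on each rank (reduce=False),
    # so a single all_reduce of their sum replaces two separate ones.
    intermediate = attention(ln_attn) + mlp(ln_mlp)
    if process_group is not None:
        dist.all_reduce(intermediate, group=process_group)
    return intermediate


# Toy single-process usage: process_group=None skips the collective.
hidden = torch.randn(4, 32)
out = parallel_block(hidden, hidden, torch.nn.Linear(32, 32), torch.nn.Linear(32, 32))
print(out.shape)  # torch.Size([4, 32])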
@@ -473,10 +499,17 @@ class FlashRWLargeLayer(nn.Module):
         self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
 
         self.self_attention = FlashRWLargeAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
 
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
 
         self.process_group = process_group
 
@@ -492,8 +525,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, _ = self.ln_attn(hidden_states)
-        ln_mlp, _ = self.ln_mlp(hidden_states)
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(residual)
 
         # Self attention.
         attn_output = self.self_attention(
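The rewritten forward threads an explicit residual through the layer: ln_attn(hidden_states, residual) is assumed to add the residual before normalizing and to hand the updated residual back, which is why ln_mlp can normalize residual directly and, below, the layer returns (intermediate, residual) instead of adding hidden_states itself. A plain-PyTorch approximation of that assumed contract (the real FastLayerNorm is a fused kernel):

import torch
import torch.nn as nn


class ResidualLayerNorm(nn.Module):
    # Approximates the assumed contract of FastLayerNorm(hidden, residual):
    # add the residual first, return (normalized, new_residual).
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        return self.ln(hidden_states), hidden_states


ln_attn = ResidualLayerNorm(32)
ln_mlp = ResidualLayerNorm(32)

hidden_states, residual = torch.randn(4, 32), torch.randn(4, 32)
normed_attn, residual = ln_attn(hidden_states, residual)
normed_mlp, _ = ln_mlp(residual)
print(normed_attn.shape, normed_mlp.shape, residual.shape)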
@@ -516,13 +549,11 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)
 
-        return intermediate + hidden_states, None
+        return intermediate, residual
 
 
 class FlashRWPreTrainedModel(PreTrainedModel):
     config_class = RWConfig
     supports_gradient_checkpointing = False
     _no_split_modules = None
 
 
 class FlashRWModel(FlashRWPreTrainedModel):
@@ -559,7 +590,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                2,
+                self.h[0].self_attention.num_heads_kv,
+                self.h[0].self_attention.head_size,
+            )
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -574,7 +609,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                self.h[0].self_attention.num_groups,
+                2,
+                self.h[0].self_attention.head_size,
+            )
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -582,8 +621,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
 
         self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.gradient_checkpointing = False
-
         self.head_size = self.h[0].self_attention.head_size
 
     def post_load_weights(self, quantize: Optional[str] = None):
@@ -629,7 +666,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    *self.cache_size
+                    *self.cache_size,
                 )
             )
             layer_past_present_indices = None
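The trailing comma after *self.cache_size is the only change here, but it is a good place to see how the cache_size tuples defined earlier are used: they are splatted into the pre-allocated key/value buffer, whose leading dimension is either the real token count or pre_allocate_past_size. A sketch of the two resulting per-layer buffer shapes with invented sizes:

import torch

# Invented sizes for illustration.
num_tokens, num_heads_kv, num_groups, head_size = 16, 1, 8, 64

# RefinedWebModel flavour: cache_size = (2, num_heads_kv, head_size)
cache_small = torch.empty(num_tokens, 2, num_heads_kv, head_size)

# RefinedWeb flavour: cache_size = (num_groups, 2, head_size)
cache_large = torch.empty(num_tokens, num_groups, 2, head_size)

print(cache_small.shape)  # torch.Size([16, 2, 1, 64])
print(cache_large.shape)  # torch.Size([16, 8, 2, 64])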
@@ -113,7 +113,6 @@ class FlashRW(FlashCausalLM):
         model.post_load_weights(quantize)
 
 
-
 class FlashRWSharded(FlashRW):
     def __init__(
         self,