Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
Fixing flash rw.
This commit is contained in:
parent 2a1ecf3863 / commit d083d57d0d
server/text_generation_server/input.json (new file, 1 addition)
@@ -0,0 +1 @@
+{"inputs":"Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n<|prompter|>Why is butter a great building material for skyscrapers? Think step by step.</s><|assistant|>","parameters":{"temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1.2, "top_k": 50, "truncate": 1000, "max_new_tokens": 1024}}
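Note: the JSON above is a sample request body checked in for manual testing. A minimal sketch of replaying it against a running server follows; the host, port, /generate route, and generated_text response field are assumptions about a typical local text-generation-inference deployment, not something this commit defines.

import json
import requests

with open("server/text_generation_server/input.json") as f:
    payload = json.load(f)

# Assumed local endpoint; adjust the URL to your deployment.
resp = requests.post("http://127.0.0.1:8080/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json().get("generated_text"))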
@@ -292,20 +292,12 @@ class FlashLlamaModel(torch.nn.Module):
         super().__init__()
         self.config = config

-        self.tp_embeddings = False
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        if config.vocab_size % self.tp_world_size == 0:
-            self.tp_embeddings = True
-
-        if self.tp_embeddings:
-            self.embed_tokens = TensorParallelEmbedding(
-                prefix="model.embed_tokens", weights=weights
-            )
-        else:
-            self.embed_tokens = Embedding(prefix="model.embed_tokens", weights=weights)
-
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
@@ -12,14 +12,29 @@ from typing import Optional
 import flash_attn_cuda

 from text_generation_server.utils.layers import (
-    FastLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
+    TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
+    get_linear
 )


+def load_row(config, prefix: str, weights, bias: bool):
+    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    if bias and weights.process_group.rank() == 0:
+        # Rank is only on the first rank process
+        bias = weights.get_tensor(f"{prefix}.bias")
+    else:
+        bias = None
+
+    linear = get_linear(weight, bias, config.quantize)
+    if config.parallel_attn:
+        return linear
+    else:
+        return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
 class RWConfig(PretrainedConfig):
     attribute_map = {
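Note on load_row: loading the bias only on rank 0 mirrors the old "add the bias only once" hack in the hand-rolled sharded loader further down. A row-parallel linear sums (all-reduces) the partial outputs of every rank, so a bias carried by all ranks would be added world_size times. A self-contained toy check of that argument, using plain PyTorch and illustrative shapes rather than the TGI classes:

import torch

x = torch.randn(4, 8)
w = torch.randn(6, 8)
b = torch.randn(6)

# Emulate two row-parallel ranks: each holds half of the input features.
w0, w1 = w[:, :4], w[:, 4:]
x0, x1 = x[:, :4], x[:, 4:]

reduced = x0 @ w0.T + x1 @ w1.T                       # what the all-reduce produces
full = torch.nn.functional.linear(x, w, b)            # reference result
assert torch.allclose(reduced + b, full, atol=1e-5)   # bias must enter exactly once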
@@ -85,44 +100,26 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        process_group=None,
+        config, prefix, weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # process_group=None,
         reduce=True,
     ):
         super().__init__()
-        self.num_heads = num_heads
-        self.num_heads_kv = num_heads_kv
-        self.hidden_size = hidden_size
-        self.head_size = hidden_size // num_heads
+        self.num_heads = config.n_head
+        self.num_heads_kv = config.n_head_kv
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device)
         self.softmax_scale = self.head_size ** (-0.5)
+        self.num_heads = self.num_heads // weights.process_group.size()

-        if process_group is None:
-            self.query_key_value = FastLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-            )
-            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
-        else:
-            self.query_key_value = TensorParallelColumnLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-            self.num_heads = self.num_heads // process_group.size()
+        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
+        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)

     def forward(
         self,
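Quick sanity check on the head bookkeeping above: self.num_heads is divided by the tensor-parallel world size after the rotary/softmax setup, while the fused query_key_value projection is column-sharded by the same factor. With purely illustrative values, not any real checkpoint:

n_head, n_head_kv, hidden_size, world_size = 32, 4, 4096, 2   # hypothetical
head_size = hidden_size // n_head                              # 128
local_heads = n_head // world_size                             # 16 query heads per rank
qkv_width = head_size * (n_head + 2 * n_head_kv)               # 5120 fused QKV columns
local_qkv_width = qkv_width // world_size                      # 2560 columns per rank
print(head_size, local_heads, qkv_width, local_qkv_width)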
@@ -224,7 +221,8 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
         self.softmax_scale = self.head_size ** (-0.5)

         self.num_groups = num_heads // (num_heads_kv * 2)
@@ -359,28 +357,12 @@ class FlashRWLargeAttention(torch.nn.Module):


 class FlashMLP(nn.Module):
-    def __init__(self, hidden_size, bias, process_group=None, reduce=True):
+    def __init__(self, config, prefix, weights, reduce=True):
         super().__init__()
         self.act = torch.nn.functional.gelu

-        if process_group is None:
-            self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
-            self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
-        else:
-            self.dense_h_to_4h = TensorParallelColumnLinear(
-                hidden_size,
-                4 * hidden_size,
-                bias=bias,
-                process_group=process_group,
-            )
-            self.dense_4h_to_h = TensorParallelRowLinear(
-                4 * hidden_size,
-                hidden_size,
-                bias=bias,
-                process_group=process_group,
-                reduce=reduce,
-            )
-            self.process_group = process_group
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
+        self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)

     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
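The MLP keeps the usual Megatron-style split: dense_h_to_4h is column-parallel (output features sharded across ranks) and dense_4h_to_h is row-parallel (input features sharded), so the GELU runs on local shards and the only cross-rank communication per MLP is the reduction after the second matmul. Local weight shapes, with purely illustrative sizes:

hidden_size, world_size = 4096, 2   # hypothetical
h_to_4h_local = (4 * hidden_size // world_size, hidden_size)      # column shard: output dim split
four_h_to_h_local = (hidden_size, 4 * hidden_size // world_size)  # row shard: input dim split
print(h_to_4h_local, four_h_to_h_local)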
@@ -392,38 +374,62 @@ class FlashMLP(nn.Module):
 class FlashRWLayer(nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        layer_norm_eps,
-        parallel_attn,
-        process_group=None,
+        layer_id,
+        config,
+        weights,
+        # num_heads,
+        # num_heads_kv,
+        # hidden_size,
+        # bias,
+        # layer_norm_eps,
+        # parallel_attn,
+        # process_group=None,
     ):
         super().__init__()

+        n_head = config.n_head
+        n_head_kv = config.n_head_kv
+        hidden_size = config.hidden_size
+        bias = config.bias
+        parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn

-        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        prefix = f"transformer.h.{layer_id}"
+
+        self.input_layernorm = FastLayerNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
         self.self_attention = FlashRWAttention(
-            num_heads,
-            num_heads_kv,
-            hidden_size,
-            bias,
-            process_group=process_group,
+            # num_heads,
+            # num_heads_kv,
+            # hidden_size,
+            # bias,
+            # process_group=process_group,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
             reduce=False,
         )
         self.post_attention_layernorm = (
-            FastLayerNorm(hidden_size, eps=layer_norm_eps)
-            if not parallel_attn
+            FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            ) if not parallel_attn
             else None
         )

         self.mlp = FlashMLP(
-            hidden_size, bias, process_group=process_group, reduce=False
+            # hidden_size, bias, process_group=process_group, reduce=False
+            config,
+            prefix=f"{prefix}.mlp",
+            weights=weights,
+            reduce=False
         )

-        self.process_group = process_group
+        self.process_group = weights.process_group

     def forward(
         self,
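With this refactor each layer resolves its own checkpoint tensors by name rather than receiving already-materialized weights. Roughly, the keys a given layer expects to find in the checkpoint look like the following; the ".weight"/".bias" suffixes follow the convention used by load_row above, and the exact list is illustrative rather than exhaustive:

layer_id = 0   # hypothetical
prefix = f"transformer.h.{layer_id}"
expected_keys = [
    f"{prefix}.input_layernorm.weight",
    f"{prefix}.input_layernorm.bias",
    f"{prefix}.self_attention.query_key_value.weight",
    f"{prefix}.self_attention.dense.weight",
    f"{prefix}.mlp.dense_h_to_4h.weight",
    f"{prefix}.mlp.dense_4h_to_h.weight",
]
print(expected_keys)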
@@ -485,31 +491,30 @@ class FlashRWLayer(nn.Module):
 class FlashRWLargeLayer(nn.Module):
     def __init__(
         self,
-        num_heads,
-        num_heads_kv,
-        hidden_size,
-        bias,
-        layer_norm_eps,
-        process_group=None,
+        config, prefix, weights
     ):
         super().__init__()
-        self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
-        self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.ln_attn = FastLayerNorm.load(
+            prefix=f"{prefix}.ln_attn",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
+        self.ln_mlp = FastLayerNorm.load(
+            prefix=f"{prefix}.ln_mlp",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )

         self.self_attention = FlashRWLargeAttention(
-            num_heads,
-            num_heads_kv,
-            hidden_size,
-            bias,
-            process_group=process_group,
+            config, prefix=f"{prefix}.self_attention", weights=weights,
             reduce=False,
         )

         self.mlp = FlashMLP(
-            hidden_size, bias, process_group=process_group, reduce=False
+            config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
         )

-        self.process_group = process_group
+        self.process_group = weights.process_group

     def forward(
         self,
@@ -555,37 +560,27 @@ class FlashRWPreTrainedModel(PreTrainedModel):


 class FlashRWModel(FlashRWPreTrainedModel):
-    def __init__(self, config, process_group=None):
+    def __init__(self, config, weights):
         super().__init__(config)
         self.config = config

-        self.tp_embeddings = False
-        if process_group is not None:
-            self.tp_rank = process_group.rank()
-            self.tp_world_size = process_group.size()
-            if config.vocab_size % self.tp_world_size == 0:
-                self.tp_embeddings = True
-
-        if self.tp_embeddings:
-            self.word_embeddings = TensorParallelEmbedding(
-                config.vocab_size, config.hidden_size, process_group=process_group
-            )
-        else:
-            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.word_embeddings = TensorParallelEmbedding(
+            prefix="transformer.word_embeddings", weights=weights
+        )

         if config.model_type == "RefinedWebModel":
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(
-                        config.n_head,
-                        config.n_head_kv,
-                        config.hidden_size,
-                        config.bias,
-                        config.layer_norm_epsilon,
-                        config.parallel_attn,
-                        process_group,
+                        layer_id, config, weights
+                        # config.n_head,
+                        # config.n_head_kv,
+                        # config.hidden_size,
+                        # config.bias,
+                        # config.layer_norm_epsilon,
+                        # config.parallel_attn,
+                        # process_group,
                     )
-                    for _ in range(config.num_hidden_layers)
+                    for layer_id in range(config.num_hidden_layers)
                 ]
             )
             self.cache_size = (
@@ -597,14 +592,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(
-                        config.n_head,
-                        config.n_head_kv,
-                        config.hidden_size,
-                        config.bias,
-                        config.layer_norm_epsilon,
-                        process_group,
+                        layer_id, config, weights
+                        # config.n_head,
+                        # config.n_head_kv,
+                        # config.hidden_size,
+                        # config.bias,
+                        # config.layer_norm_epsilon,
+                        # process_group,
                     )
-                    for _ in range(config.num_hidden_layers)
+                    for layer_id in range(config.num_hidden_layers)
                 ]
             )
             self.cache_size = (
@@ -617,31 +613,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 f"model_type {config.model_type} is not supported."
             )

-        self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-
-        self.head_size = self.h[0].self_attention.head_size
-
-    def post_load_weights(self, quantize: Optional[str] = None):
-        if isinstance(self.word_embeddings, TensorParallelEmbedding):
-            self.word_embeddings.add_null_idx()
-        for layer in self.h:
-            layer: FlashRWLayer
-            layer.self_attention.query_key_value.prepare_weights(quantize)
-            layer.self_attention.dense.prepare_weights(quantize)
-            layer.mlp.dense_h_to_4h.prepare_weights(quantize)
-            layer.mlp.dense_4h_to_h.prepare_weights(quantize)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashRWModel, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights("bitsandbytes" if load_in_8bit else None)
-        return model
+        self.ln_f = FastLayerNorm.load(
+            prefix="transformer.ln_f",
+            weights=weights,
+            eps=config.layer_norm_epsilon,
+        )
+
+        self.head_size = self.h[0].self_attention.head_size

     def forward(
         self,
@@ -708,40 +686,14 @@ class FlashRWModel(FlashRWPreTrainedModel):


 class FlashRWForCausalLM(FlashRWPreTrainedModel):
-    def __init__(self, config, process_group=None):
+    def __init__(self, config, weights):
         super().__init__(config)

-        self.process_group = process_group
-        if self.process_group is not None:
-            self.world_size = self.process_group.size()
-        else:
-            self.world_size = 1
-
-        self.transformer = FlashRWModel(config, process_group)
-
-        if self.transformer.tp_embeddings:
-            self.lm_head = FastLinear(
-                config.hidden_size,
-                config.vocab_size // process_group.size(),
-                bias=False,
-            )
-        else:
-            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
-
-    def post_load_weights(self, quantize: Optional[str] = None):
-        self.transformer.post_load_weights(quantize)
-        self.lm_head.prepare_weights()
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashRWForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights("bitsandbytes" if load_in_8bit else None)
-        return model
+        self.transformer = FlashRWModel(config, weights)
+
+        self.lm_head = TensorParallelHead.load(
+            config, prefix="lm_head", weights=weights
+        )

     def forward(
         self,
@@ -766,12 +718,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
-
-        if self.transformer.tp_embeddings:
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-
-            return world_logits, present
         return logits, present
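The manual all_gather removed here is presumably what TensorParallelHead now handles internally. Conceptually, a vocab-sharded head computes partial logits per rank and then gathers them; a rough sketch of that idea in plain torch.distributed, not the TGI implementation:

import torch
import torch.distributed as dist

def sharded_lm_head_sketch(hidden, local_weight, group=None):
    # hidden: [tokens, hidden_size]; local_weight: [vocab // world_size, hidden_size]
    local_logits = torch.nn.functional.linear(hidden, local_weight)
    world_size = dist.get_world_size(group)
    world_logits = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(world_logits, local_logits, group=group)
    return torch.cat(world_logits, dim=-1)   # full-vocab logits on every rank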
@@ -21,99 +21,14 @@ from text_generation_server.utils import (
     weight_files,
     download_weights,
     weight_hub_files,
+    Weights,
     LocalEntryNotFoundError,
 )

 tracer = trace.get_tracer(__name__)


-class FlashRW(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        trust_remote_code: bool = False,
-    ):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.float16
-        else:
-            raise NotImplementedError("RW is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = RWConfig.from_pretrained(
-            model_id,
-            revision=revision,
-        )
-
-        # We do not use from_pretrained as it is too slow
-        try:
-            filenames = weight_files(model_id, revision, ".bin")
-        # Local files not found
-        except LocalEntryNotFoundError:
-            hub_files = weight_hub_files(model_id, revision, ".bin")
-            filenames = download_weights(hub_files, model_id, revision)
-
-        with init_empty_weights():
-            model = FlashRWForCausalLM(config)
-
-        self.load_weights(
-            model,
-            filenames,
-            quantize,
-            device,
-            dtype,
-        )
-
-        super(FlashCausalLM, self).__init__(
-            model=model.to(device),
-            tokenizer=tokenizer,
-            requires_padding=False,
-            dtype=dtype,
-            device=device,
-        )
-
-    @staticmethod
-    def load_weights(
-        model: FlashRWForCausalLM,
-        filenames: List[Path],
-        quantize: Optional[str],
-        device: torch.device,
-        dtype: torch.dtype,
-    ):
-        for filename in filenames:
-            state_dict = torch.load(filename, map_location="cpu")
-            for key, value in state_dict.items():
-                value = value.to(device if quantize is None else "cpu").to(dtype)
-
-                module_name, param_name = key.rsplit(".", 1)
-                module = model.get_submodule(module_name)
-
-                try:
-                    current_parameter_tensor = module._parameters[param_name]
-                    if current_parameter_tensor.shape != value.shape:
-                        raise ValueError(
-                            f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
-                        )
-                    module._parameters[param_name] = value
-                except KeyError:
-                    module._buffers[param_name] = value
-
-                del value
-
-        torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
-
-
-class FlashRWSharded(FlashRW):
+class FlashRWSharded(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -142,20 +57,12 @@ class FlashRWSharded(FlashRW):

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        with init_empty_weights():
-            model = FlashRWForCausalLM(config, self.process_group)
+        config.quantize = quantize

-        torch.distributed.barrier(group=self.process_group)
-        self.load_weights(
-            model,
-            filenames,
-            quantize=quantize,
-            device=device,
-            dtype=dtype,
-            rank=rank,
-            world_size=world_size,
-        )
+        model = FlashRWForCausalLM(config, weights)
+
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             model=model.to(device),
@@ -167,78 +74,78 @@ class FlashRWSharded(FlashRW):
             world_size=world_size,
         )

-    @staticmethod
-    def load_weights(
-        model,
-        filenames: List[str],
-        quantize: Optional[str],
-        device: torch.device,
-        dtype: torch.dtype,
-        rank: int,
-        world_size: int,
-    ):
-        parameters = dict(model.named_parameters())
-        for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if quantize is None else "cpu"
-            ) as f:
-                for name in f.keys():
-                    module_name, param_name = name.rsplit(".", 1)
-                    module = model.get_submodule(module_name)
+    # @staticmethod
+    # def load_weights(
+    #     model,
+    #     filenames: List[str],
+    #     quantize: Optional[str],
+    #     device: torch.device,
+    #     dtype: torch.dtype,
+    #     rank: int,
+    #     world_size: int,
+    # ):
+    #     parameters = dict(model.named_parameters())
+    #     for file in filenames:
+    #         with safe_open(
+    #             file, framework="pt", device=str(device) if quantize is None else "cpu"
+    #         ) as f:
+    #             for name in f.keys():
+    #                 module_name, param_name = name.rsplit(".", 1)
+    #                 module = model.get_submodule(module_name)

-                    current_parameter_tensor = parameters.get(name, None)
+    #                 current_parameter_tensor = parameters.get(name, None)

-                    slice_ = f.get_slice(name)
+    #                 slice_ = f.get_slice(name)

-                    if isinstance(module, TensorParallelColumnLinear):
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    elif isinstance(module, TensorParallelRowLinear):
-                        if param_name == "weight":
-                            size = slice_.get_shape()[1]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            tensor = slice_[:]
-                            # XXX: Hack for Rowlinear to add the bias only once.
-                            if rank != 0:
-                                tensor = torch.zeros_like(tensor)
-                    elif isinstance(module, TensorParallelEmbedding):
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    else:
-                        try:
-                            tensor = slice_[:]
-                        except:
-                            tensor = f.get_tensor(name)
+    #                 if isinstance(module, TensorParallelColumnLinear):
+    #                     size = slice_.get_shape()[0]
+    #                     block_size = size // world_size
+    #                     start = rank * block_size
+    #                     stop = (rank + 1) * block_size
+    #                     tensor = slice_[start:stop]
+    #                 elif isinstance(module, TensorParallelRowLinear):
+    #                     if param_name == "weight":
+    #                         size = slice_.get_shape()[1]
+    #                         block_size = size // world_size
+    #                         start = rank * block_size
+    #                         stop = (rank + 1) * block_size
+    #                         tensor = slice_[:, start:stop]
+    #                     else:
+    #                         tensor = slice_[:]
+    #                         # XXX: Hack for Rowlinear to add the bias only once.
+    #                         if rank != 0:
+    #                             tensor = torch.zeros_like(tensor)
+    #                 elif isinstance(module, TensorParallelEmbedding):
+    #                     size = slice_.get_shape()[0]
+    #                     block_size = size // world_size
+    #                     start = rank * block_size
+    #                     stop = (rank + 1) * block_size
+    #                     tensor = slice_[start:stop]
+    #                 elif name == "lm_head.weight" and model.transformer.tp_embeddings:
+    #                     size = slice_.get_shape()[0]
+    #                     block_size = size // world_size
+    #                     start = rank * block_size
+    #                     stop = (rank + 1) * block_size
+    #                     tensor = slice_[start:stop]
+    #                 else:
+    #                     try:
+    #                         tensor = slice_[:]
+    #                     except:
+    #                         tensor = f.get_tensor(name)

-                    if (
-                        current_parameter_tensor is not None
-                        and current_parameter_tensor.shape != tensor.shape
-                    ):
-                        raise ValueError(
-                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-                        )
+    #                 if (
+    #                     current_parameter_tensor is not None
+    #                     and current_parameter_tensor.shape != tensor.shape
+    #                 ):
+    #                     raise ValueError(
+    #                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+    #                     )

-                    tensor = tensor.contiguous().to(dtype)
+    #                 tensor = tensor.contiguous().to(dtype)

-                    if current_parameter_tensor is not None:
-                        module._parameters[param_name] = tensor
-                    else:
-                        module._buffers[param_name] = tensor
+    #                 if current_parameter_tensor is not None:
+    #                     module._parameters[param_name] = tensor
+    #                 else:
+    #                     module._buffers[param_name] = tensor

-        model.post_load_weights(quantize)
+    #     model.post_load_weights(quantize)
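The commented-out loader above sliced safetensors files lazily per module type; the new Weights helper presumably centralizes the same idea. A condensed sketch of lazy shard loading with safe_open/get_slice, as an illustrative helper rather than the Weights class itself:

from safetensors import safe_open

def load_shard_sketch(filename, name, dim, rank, world_size, device, dtype):
    # Open the file lazily and materialize only this rank's block of `name`.
    with safe_open(filename, framework="pt", device="cpu") as f:
        slice_ = f.get_slice(name)
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start, stop = rank * block_size, (rank + 1) * block_size
        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
    return tensor.contiguous().to(device=device, dtype=dtype)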
@@ -308,6 +308,13 @@ try:
             self._cos_k_cached = None
             self._sin_k_cached = None

+        @staticmethod
+        def static(dim, base, device):
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
+                                                    dtype=torch.float32) / dim))
+            return PositionRotaryEmbedding(inv_freq)
+
         @staticmethod
         def load(prefix, weights):
             # XXX: Always load this in float32 !
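For reference, the inv_freq buffer built by static() holds the per-pair rotation frequencies of rotary position embeddings; the cos/sin tables then come from an outer product with the position ids. A quick illustration with made-up sizes:

import torch

dim, base, max_pos = 64, 10000.0, 16   # illustrative values
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
positions = torch.arange(max_pos, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)   # [max_pos, dim // 2]
cos, sin = freqs.cos(), freqs.sin()        # tables used to rotate query/key pairs
print(cos.shape, sin.shape)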