mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
feat(server): support RefinedWeb models
This commit is contained in:
parent
951930fbff
commit
63a18c1414
@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
|
|||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
|
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
|
||||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||||
|
from text_generation_server.models.rw import RW
|
||||||
from text_generation_server.models.opt import OPT, OPTSharded
|
from text_generation_server.models.opt import OPT, OPTSharded
|
||||||
from text_generation_server.models.galactica import Galactica, GalacticaSharded
|
from text_generation_server.models.galactica import Galactica, GalacticaSharded
|
||||||
from text_generation_server.models.santacoder import SantaCoder
|
from text_generation_server.models.santacoder import SantaCoder
|
||||||
@ -30,6 +31,7 @@ try:
|
|||||||
)
|
)
|
||||||
|
|
||||||
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
|
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
|
||||||
|
from text_generation_server.models.flash_rw import FlashRW
|
||||||
from text_generation_server.models.flash_llama import (
|
from text_generation_server.models.flash_llama import (
|
||||||
FlashLlama,
|
FlashLlama,
|
||||||
FlashLlamaSharded,
|
FlashLlamaSharded,
|
||||||
@ -68,6 +70,7 @@ __all__ = [
|
|||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
__all__.append(FlashNeoX)
|
__all__.append(FlashNeoX)
|
||||||
__all__.append(FlashNeoXSharded)
|
__all__.append(FlashNeoXSharded)
|
||||||
|
__all__.append(FlashRW)
|
||||||
__all__.append(FlashSantacoder)
|
__all__.append(FlashSantacoder)
|
||||||
__all__.append(FlashSantacoderSharded)
|
__all__.append(FlashSantacoderSharded)
|
||||||
__all__.append(FlashLlama)
|
__all__.append(FlashLlama)
|
||||||
@ -194,6 +197,34 @@ def get_model(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_type in ["RefinedWeb", "RefinedWebModel"]:
|
||||||
|
if sharded:
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
if config.alibi:
|
||||||
|
raise NotImplementedError("sharded is not supported for this model")
|
||||||
|
# return FlashRWSharded(
|
||||||
|
# model_id,
|
||||||
|
# revision,
|
||||||
|
# quantize=quantize,
|
||||||
|
# trust_remote_code=trust_remote_code,
|
||||||
|
# )
|
||||||
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb"))
|
||||||
|
else:
|
||||||
|
if FLASH_ATTENTION and not config.alibi:
|
||||||
|
return FlashRW(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return RW(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == "llama":
|
if model_type == "llama":
|
||||||
if sharded:
|
if sharded:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
|
||||||
|
# Inplace rotary
|
||||||
|
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||||
|
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if layer_past_present_indices is None:
|
if layer_past_present_indices is None:
|
||||||
# Copy to layer past
|
# Copy to layer past
|
||||||
layer_past[...] = qkv_rot[:, 1:]
|
layer_past[...] = qkv[:, 1:]
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(qkv_rot[:, 0])
|
attn_output = torch.empty_like(qkv[:, 0])
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
qkv_rot[:, 0],
|
qkv[:, 0],
|
||||||
qkv_rot[:, 1],
|
qkv[:, 1],
|
||||||
qkv_rot[:, 2],
|
qkv[:, 2],
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
query = qkv_rot[:, 0]
|
query = qkv[:, 0]
|
||||||
# Add present to the layer_past tensor at the correct indices
|
# Add present to the layer_past tensor at the correct indices
|
||||||
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
|
layer_past[layer_past_present_indices] = qkv[:, 1:]
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
|
||||||
|
# Inplace rotary
|
||||||
|
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||||
|
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if layer_past_present_indices is None:
|
if layer_past_present_indices is None:
|
||||||
# Copy to layer past
|
# Copy to layer past
|
||||||
layer_past[...] = qkv_rot[:, 1:]
|
layer_past[...] = qkv[:, 1:]
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(qkv_rot[:, 0])
|
attn_output = torch.empty_like(qkv[:, 0])
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
qkv_rot[:, 0],
|
qkv[:, 0],
|
||||||
qkv_rot[:, 1],
|
qkv[:, 1],
|
||||||
qkv_rot[:, 2],
|
qkv[:, 2],
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
query = qkv_rot[:, 0]
|
query = qkv[:, 0]
|
||||||
# Add present to the layer_past tensor at the correct indices
|
# Add present to the layer_past tensor at the correct indices
|
||||||
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
|
layer_past[layer_past_present_indices] = qkv[:, 1:]
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -0,0 +1,507 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from torch import nn
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Flash attention imports
|
||||||
|
import flash_attn_cuda
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
FastLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
FastLayerNorm,
|
||||||
|
PositionRotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RWConfig(PretrainedConfig):
|
||||||
|
attribute_map = {
|
||||||
|
"num_hidden_layers": "n_layer",
|
||||||
|
"num_attention_heads": "n_head",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_type="RefinedWeb",
|
||||||
|
vocab_size=250880,
|
||||||
|
hidden_size=64,
|
||||||
|
n_layer=2,
|
||||||
|
n_head=8,
|
||||||
|
layer_norm_epsilon=1e-5,
|
||||||
|
initializer_range=0.02,
|
||||||
|
use_cache=True,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
hidden_dropout=0.0,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
n_head_kv=None,
|
||||||
|
multi_query=False,
|
||||||
|
alibi=False,
|
||||||
|
bias=False,
|
||||||
|
parallel_attn=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if alibi:
|
||||||
|
raise NotImplementedError("alibi is not supported by this version of the model")
|
||||||
|
|
||||||
|
self.model_type = model_type
|
||||||
|
self.alibi = False
|
||||||
|
self.rotary = True
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
# Backward compatibility with n_embed kwarg
|
||||||
|
n_embed = kwargs.pop("n_embed", None)
|
||||||
|
self.hidden_size = hidden_size if n_embed is None else n_embed
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_head = n_head
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.hidden_dropout = hidden_dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.bias = bias
|
||||||
|
self.parallel_attn = parallel_attn
|
||||||
|
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
if n_head_kv is not None:
|
||||||
|
self.n_head_kv = n_head_kv
|
||||||
|
else:
|
||||||
|
self.n_head_kv = 1 if multi_query else n_head
|
||||||
|
|
||||||
|
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashRWAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads,
|
||||||
|
num_heads_kv,
|
||||||
|
hidden_size,
|
||||||
|
bias,
|
||||||
|
process_group=None,
|
||||||
|
reduce=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_heads_kv = num_heads_kv
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_size = hidden_size // num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||||
|
self.softmax_scale = self.head_size ** (-0.5)
|
||||||
|
|
||||||
|
if process_group is None:
|
||||||
|
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||||
|
bias=bias)
|
||||||
|
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||||
|
else:
|
||||||
|
self.num_heads = self.num_heads // process_group.size()
|
||||||
|
self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||||
|
bias=bias)
|
||||||
|
self.dense = TensorParallelRowLinear(
|
||||||
|
hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
layer_past,
|
||||||
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
|
||||||
|
# Split query from key_value
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[self.head_size * self.num_heads, 2 * self.head_size], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare query and key_value for indexing
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, 1, self.head_size)
|
||||||
|
|
||||||
|
# Inplace rotary
|
||||||
|
self.rotary_emb(query, cos, sin)
|
||||||
|
self.rotary_emb(kv[:, 0], cos, sin)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if layer_past_present_indices is None:
|
||||||
|
# Copy to layer past
|
||||||
|
layer_past[...] = kv
|
||||||
|
# Expand to query shape
|
||||||
|
kv = kv.expand(-1, 2, query.shape[1], self.head_size)
|
||||||
|
|
||||||
|
# output
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
# flash attention
|
||||||
|
flash_attn_cuda.fwd(
|
||||||
|
query,
|
||||||
|
kv[:, 0],
|
||||||
|
kv[:, 1],
|
||||||
|
attn_output,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
self.softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
# Add present to the layer_past tensor at the correct indices
|
||||||
|
layer_past[layer_past_present_indices] = kv
|
||||||
|
# Expand to query shape
|
||||||
|
kv = layer_past.expand(-1, 2, query.shape[1], self.head_size)
|
||||||
|
|
||||||
|
# output
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
# flash attention
|
||||||
|
flash_attn_cuda.fwd(
|
||||||
|
query,
|
||||||
|
kv[:, 0],
|
||||||
|
kv[:, 1],
|
||||||
|
attn_output,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens,
|
||||||
|
1,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
self.softmax_scale,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, hidden_size, bias, process_group=None, reduce=True
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.act = torch.nn.functional.gelu
|
||||||
|
|
||||||
|
if process_group is None:
|
||||||
|
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
|
||||||
|
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
|
||||||
|
else:
|
||||||
|
self.dense_h_to_4h = TensorParallelColumnLinear(
|
||||||
|
hidden_size,
|
||||||
|
4 * hidden_size, bias=bias,
|
||||||
|
process_group=process_group,
|
||||||
|
)
|
||||||
|
self.dense_4h_to_h = TensorParallelRowLinear(
|
||||||
|
4 * hidden_size,
|
||||||
|
hidden_size, bias=bias,
|
||||||
|
process_group=process_group,
|
||||||
|
reduce=reduce,
|
||||||
|
)
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.dense_4h_to_h(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashRWLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads,
|
||||||
|
num_heads_kv,
|
||||||
|
hidden_size,
|
||||||
|
layer_norm_eps,
|
||||||
|
parallel_attn,
|
||||||
|
process_group=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.parallel_attn = parallel_attn
|
||||||
|
|
||||||
|
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||||
|
self.self_attention = FlashRWAttention(num_heads, num_heads_kv, hidden_size, process_group, reduce=False)
|
||||||
|
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) if not parallel_attn else None
|
||||||
|
|
||||||
|
self.mlp = FlashMLP(hidden_size, process_group, reduce=False)
|
||||||
|
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
layer_past,
|
||||||
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q,
|
||||||
|
):
|
||||||
|
if self.parallel_attn:
|
||||||
|
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
attn_output = self.self_attention(
|
||||||
|
ln_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
layer_past,
|
||||||
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q,
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(ln_hidden_states)
|
||||||
|
intermediate = mlp_output + attn_output
|
||||||
|
|
||||||
|
# Only reduce once and after the addition instead of once per layer
|
||||||
|
if self.process_group is not None:
|
||||||
|
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||||
|
|
||||||
|
return intermediate, residual
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
hidden_states = self.self_attention(
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
layer_past,
|
||||||
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
return mlp_output, residual
|
||||||
|
|
||||||
|
class FlashRWPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = RWConfig
|
||||||
|
supports_gradient_checkpointing = False
|
||||||
|
_no_split_modules = None
|
||||||
|
|
||||||
|
|
||||||
|
class FlashRWModel(FlashRWPreTrainedModel):
|
||||||
|
def __init__(self, config, process_group=None):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.tp_embeddings = False
|
||||||
|
if process_group is not None:
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
if config.vocab_size % self.tp_world_size == 0:
|
||||||
|
self.tp_embeddings = True
|
||||||
|
|
||||||
|
if self.tp_embeddings:
|
||||||
|
self.word_embeddings = TensorParallelEmbedding(
|
||||||
|
config.vocab_size, config.hidden_size, process_group=process_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
|
|
||||||
|
self.h = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FlashRWLayer(
|
||||||
|
config.n_head,
|
||||||
|
config.n_head_kv,
|
||||||
|
config.hidden_size,
|
||||||
|
config.layer_norm_epsilon,
|
||||||
|
config.parallel_attn,
|
||||||
|
process_group,
|
||||||
|
)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ln_f = FastLayerNorm(
|
||||||
|
config.hidden_size, eps=config.layer_norm_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.head_size = self.h[0].self_attention.head_size
|
||||||
|
self.num_heads_kv = self.h[0].self_attention.num_heads_kv
|
||||||
|
|
||||||
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
|
if isinstance(self.word_embeddings, TensorParallelEmbedding):
|
||||||
|
self.word_embeddings.add_null_idx()
|
||||||
|
for layer in self.h:
|
||||||
|
layer: FlashRWLayer
|
||||||
|
layer.self_attention.query_key_value.prepare_weights(quantize)
|
||||||
|
layer.self_attention.dense.prepare_weights(quantize)
|
||||||
|
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
|
||||||
|
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||||
|
# to do it for us
|
||||||
|
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||||
|
model = super(FlashRWModel, cls).from_pretrained(
|
||||||
|
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_s,
|
||||||
|
past_key_values=None,
|
||||||
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
hidden_states = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if past_key_values is None:
|
||||||
|
# Create past tensor
|
||||||
|
past_key_values = hidden_states.new_empty(
|
||||||
|
(
|
||||||
|
len(self.h),
|
||||||
|
len(hidden_states)
|
||||||
|
if pre_allocate_past_size is None
|
||||||
|
else pre_allocate_past_size,
|
||||||
|
2,
|
||||||
|
self.num_heads_kv,
|
||||||
|
self.head_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
layer_past_present_indices = None
|
||||||
|
slice_past_index = len(hidden_states)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
# Create indices from cumulative sequence lengths
|
||||||
|
layer_past_present_indices = cu_seqlens[1:] - 1
|
||||||
|
slice_past_index = None
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid to index in each layer
|
||||||
|
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
|
||||||
|
position_ids, max_s, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.h):
|
||||||
|
# We added padding that we now need to slice
|
||||||
|
layer_past_key_values = (
|
||||||
|
past_key_values[i]
|
||||||
|
if slice_past_index is None
|
||||||
|
else past_key_values[i, :slice_past_index]
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
layer_past_key_values,
|
||||||
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states, past_key_values
|
||||||
|
|
||||||
|
|
||||||
|
class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||||
|
def __init__(self, config, process_group=None):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.process_group = process_group
|
||||||
|
if self.process_group is not None:
|
||||||
|
self.world_size = self.process_group.size()
|
||||||
|
else:
|
||||||
|
self.world_size = 1
|
||||||
|
|
||||||
|
self.transformer = FlashRWModel(config, process_group)
|
||||||
|
|
||||||
|
if self.transformer.tp_embeddings:
|
||||||
|
self.lm_head = FastLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.vocab_size // process_group.size(),
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lm_head = FastLinear(
|
||||||
|
config.hidden_size, config.vocab_size, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
|
self.transformer.post_load_weights(quantize)
|
||||||
|
self.lm_head.prepare_weights()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||||
|
# to do it for us
|
||||||
|
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||||
|
model = super(FlashRWForCausalLM, cls).from_pretrained(
|
||||||
|
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_s,
|
||||||
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
hidden_states, present = self.transformer(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_s,
|
||||||
|
past_key_values,
|
||||||
|
pre_allocate_past_size,
|
||||||
|
)
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if self.transformer.tp_embeddings:
|
||||||
|
# Logits are sharded, so we need to gather them
|
||||||
|
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
|
||||||
|
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
|
||||||
|
world_logits = torch.cat(world_logits, dim=1)
|
||||||
|
|
||||||
|
return world_logits, present
|
||||||
|
return logits, present
|
246
server/text_generation_server/models/flash_rw.py
Normal file
246
server/text_generation_server/models/flash_rw.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from opentelemetry import trace
|
||||||
|
from safetensors import safe_open
|
||||||
|
from transformers import AutoTokenizer, AutoConfig
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from text_generation_server.models import FlashCausalLM
|
||||||
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||||
|
RWConfig,
|
||||||
|
FlashRWForCausalLM,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
download_weights,
|
||||||
|
weight_hub_files,
|
||||||
|
LocalEntryNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashRW(FlashCausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("RW is only available on GPU")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = RWConfig.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
from loguru import logger
|
||||||
|
logger.error(config.model_type)
|
||||||
|
|
||||||
|
# We do not use from_pretrained as we modified the model internal module layout
|
||||||
|
try:
|
||||||
|
filenames = weight_files(model_id, revision, ".bin")
|
||||||
|
# Local files not found
|
||||||
|
except LocalEntryNotFoundError:
|
||||||
|
hub_files = weight_hub_files(model_id, revision, ".bin")
|
||||||
|
filenames = download_weights(hub_files, model_id, revision)
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
model = FlashRWForCausalLM(config)
|
||||||
|
|
||||||
|
self.load_weights(
|
||||||
|
model,
|
||||||
|
filenames,
|
||||||
|
quantize,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model.to(device),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_weights(
|
||||||
|
model: FlashRWForCausalLM,
|
||||||
|
filenames: List[Path],
|
||||||
|
quantize: Optional[str],
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
for filename in filenames:
|
||||||
|
state_dict = torch.load(filename, map_location="cpu")
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||||
|
|
||||||
|
module_name, param_name = key.rsplit(".", 1)
|
||||||
|
module = model.get_submodule(module_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_parameter_tensor = module._parameters[param_name]
|
||||||
|
if current_parameter_tensor.shape != value.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||||
|
)
|
||||||
|
module._parameters[param_name] = value
|
||||||
|
except KeyError:
|
||||||
|
module._buffers[param_name] = value
|
||||||
|
|
||||||
|
del value
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.post_load_weights(quantize)
|
||||||
|
|
||||||
|
#
|
||||||
|
# class FlashNeoXSharded(FlashNeoX):
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# model_id: str,
|
||||||
|
# revision: Optional[str] = None,
|
||||||
|
# quantize: Optional[str] = None,
|
||||||
|
# trust_remote_code: bool = False,
|
||||||
|
# ):
|
||||||
|
# self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
# if torch.cuda.is_available():
|
||||||
|
# device = torch.device(f"cuda:{rank}")
|
||||||
|
# dtype = torch.float16
|
||||||
|
# else:
|
||||||
|
# raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||||
|
#
|
||||||
|
# tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
# model_id,
|
||||||
|
# revision=revision,
|
||||||
|
# padding_side="left",
|
||||||
|
# truncation_side="left",
|
||||||
|
# trust_remote_code=trust_remote_code,
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# config = AutoConfig.from_pretrained(
|
||||||
|
# model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# torch.distributed.barrier(group=self.process_group)
|
||||||
|
# filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
#
|
||||||
|
# with init_empty_weights():
|
||||||
|
# model = FlashGPTNeoXForCausalLM(config, self.process_group)
|
||||||
|
#
|
||||||
|
# torch.distributed.barrier(group=self.process_group)
|
||||||
|
# self.load_weights(
|
||||||
|
# model,
|
||||||
|
# filenames,
|
||||||
|
# quantize=quantize,
|
||||||
|
# device=device,
|
||||||
|
# dtype=dtype,
|
||||||
|
# rank=rank,
|
||||||
|
# world_size=world_size,
|
||||||
|
# )
|
||||||
|
# torch.distributed.barrier(group=self.process_group)
|
||||||
|
# super(FlashCausalLM, self).__init__(
|
||||||
|
# model=model.to(device),
|
||||||
|
# tokenizer=tokenizer,
|
||||||
|
# requires_padding=False,
|
||||||
|
# dtype=dtype,
|
||||||
|
# device=device,
|
||||||
|
# rank=rank,
|
||||||
|
# world_size=world_size,
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# def load_weights(
|
||||||
|
# model,
|
||||||
|
# filenames: List[str],
|
||||||
|
# quantize: Optional[str],
|
||||||
|
# device: torch.device,
|
||||||
|
# dtype: torch.dtype,
|
||||||
|
# rank: int,
|
||||||
|
# world_size: int,
|
||||||
|
# ):
|
||||||
|
# parameters = dict(model.named_parameters())
|
||||||
|
# for file in filenames:
|
||||||
|
# with safe_open(
|
||||||
|
# file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
|
# ) as f:
|
||||||
|
# for name in f.keys():
|
||||||
|
# module_name, param_name = name.rsplit(".", 1)
|
||||||
|
# module = model.get_submodule(module_name)
|
||||||
|
#
|
||||||
|
# current_parameter_tensor = parameters.get(name, None)
|
||||||
|
#
|
||||||
|
# slice_ = f.get_slice(name)
|
||||||
|
#
|
||||||
|
# if isinstance(module, TensorParallelColumnLinear):
|
||||||
|
# size = slice_.get_shape()[0]
|
||||||
|
# block_size = size // world_size
|
||||||
|
# start = rank * block_size
|
||||||
|
# stop = (rank + 1) * block_size
|
||||||
|
# tensor = slice_[start:stop]
|
||||||
|
# elif isinstance(module, TensorParallelRowLinear):
|
||||||
|
# if param_name == "weight":
|
||||||
|
# size = slice_.get_shape()[1]
|
||||||
|
# block_size = size // world_size
|
||||||
|
# start = rank * block_size
|
||||||
|
# stop = (rank + 1) * block_size
|
||||||
|
# tensor = slice_[:, start:stop]
|
||||||
|
# else:
|
||||||
|
# tensor = slice_[:]
|
||||||
|
# # XXX: Hack for Rowlinear to add the bias only once.
|
||||||
|
# if rank != 0:
|
||||||
|
# tensor = torch.zeros_like(tensor)
|
||||||
|
# elif isinstance(module, TensorParallelEmbedding):
|
||||||
|
# size = slice_.get_shape()[0]
|
||||||
|
# block_size = size // world_size
|
||||||
|
# start = rank * block_size
|
||||||
|
# stop = (rank + 1) * block_size
|
||||||
|
# tensor = slice_[start:stop]
|
||||||
|
# elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
|
||||||
|
# size = slice_.get_shape()[0]
|
||||||
|
# block_size = size // world_size
|
||||||
|
# start = rank * block_size
|
||||||
|
# stop = (rank + 1) * block_size
|
||||||
|
# tensor = slice_[start:stop]
|
||||||
|
# else:
|
||||||
|
# try:
|
||||||
|
# tensor = slice_[:]
|
||||||
|
# except:
|
||||||
|
# tensor = f.get_tensor(name)
|
||||||
|
#
|
||||||
|
# if (
|
||||||
|
# current_parameter_tensor is not None
|
||||||
|
# and current_parameter_tensor.shape != tensor.shape
|
||||||
|
# ):
|
||||||
|
# raise ValueError(
|
||||||
|
# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# tensor = tensor.contiguous().to(dtype)
|
||||||
|
#
|
||||||
|
# if current_parameter_tensor is not None:
|
||||||
|
# module._parameters[param_name] = tensor
|
||||||
|
# else:
|
||||||
|
# module._buffers[param_name] = tensor
|
||||||
|
#
|
||||||
|
# model.post_load_weights(quantize)
|
80
server/text_generation_server/models/rw.py
Normal file
80
server/text_generation_server/models/rw.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.models import CausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class RW(CausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
if quantize:
|
||||||
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
device_map="auto"
|
||||||
|
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||||
|
else None,
|
||||||
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||||
|
model = model.cuda()
|
||||||
|
|
||||||
|
if tokenizer.pad_token_id is None:
|
||||||
|
if model.config.pad_token_id is not None:
|
||||||
|
tokenizer.pad_token_id = model.config.pad_token_id
|
||||||
|
elif model.config.eos_token_id is not None:
|
||||||
|
tokenizer.pad_token_id = model.config.eos_token_id
|
||||||
|
elif tokenizer.eos_token_id is not None:
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
else:
|
||||||
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
|
# Model Forward
|
||||||
|
if past_key_values is not None:
|
||||||
|
reshaped_past_key_values = []
|
||||||
|
for layer in past_key_values:
|
||||||
|
past_keys, past_values = layer
|
||||||
|
reshaped_past_key_values.append(
|
||||||
|
(past_keys.view(-1, *past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:]))
|
||||||
|
)
|
||||||
|
past_key_values = reshaped_past_key_values
|
||||||
|
|
||||||
|
outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values)
|
||||||
|
return outputs.logits, outputs.past_key_values
|
@ -262,16 +262,13 @@ try:
|
|||||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||||
|
|
||||||
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||||
rotary_dim = cos.shape[-1]
|
rotary_dim = cos.shape[-1]
|
||||||
q1 = qkv[:, 0, :, :rotary_dim]
|
x1 = x[..., :rotary_dim]
|
||||||
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
|
x2 = x[..., rotary_dim : 2 * rotary_dim]
|
||||||
k1 = qkv[:, 1, :, :rotary_dim]
|
|
||||||
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
return x
|
||||||
return qkv
|
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user