black + cleanup

OlivierDehaene 2023-06-08 11:47:59 +02:00
parent 5e0a6ea1b7
commit b027f5f129
17 changed files with 186 additions and 190 deletions


@@ -138,7 +138,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.lin
 # Copy build artifacts from transformers builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
-# Install transformers dependencies
+# Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
 # Install server


@@ -249,7 +249,6 @@ def launcher(event_loop):
 ) as process:
 yield ProcessLauncherHandle(process, port)
 process.terminate()
 process.wait(60)
@@ -261,6 +260,7 @@ def launcher(event_loop):
 if not use_flash_attention:
 del env["USE_FLASH_ATTENTION"]
 @contextlib.contextmanager
 def docker_launcher(
 model_id: str,
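Note: the launcher fixture touched in these hunks follows the usual contextmanager shape: start the server process, yield a handle to the test, then terminate the process and wait (here up to 60 seconds) once the test is done. A minimal sketch under that assumption (argument building, port selection and the ProcessLauncherHandle wrapper are omitted and hypothetical here):

    import contextlib
    import subprocess

    # Hedged sketch of the fixture shape above, not the project's exact code.
    @contextlib.contextmanager
    def local_launcher(args):
        with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as process:
            yield process

            # teardown mirrors the diff: stop the server, give it 60s to exit
            process.terminate()
            process.wait(60)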


@@ -3,7 +3,9 @@ import pytest
 @pytest.fixture(scope="module")
 def neox_handle(launcher):
-with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
+with launcher(
+"stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
+) as handle:
 yield handle


@@ -3,7 +3,9 @@ import pytest
 @pytest.fixture(scope="module")
 def neox_sharded_handle(launcher):
-with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
+with launcher(
+"OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
+) as handle:
 yield handle


@@ -13,7 +13,7 @@ setup(
 name="custom_kernels.fused_attention_cuda",
 sources=["custom_kernels/fused_attention_cuda.cu"],
 extra_compile_args=["-arch=compute_80", "-std=c++17"],
-)
+),
 ],
 cmdclass={"build_ext": BuildExtension},
 )


@@ -19,7 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 try:
-if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+if (
+torch.cuda.is_available()
+and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
+):
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 is_sm8x = major == 8 and minor >= 0
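The wrapped condition above is the gate for the flash-attention code path: CUDA must be available and USE_FLASH_ATTENTION must not be set to "false"; the compute-capability checks (sm75, sm8x) then decide whether the kernels can actually run. A self-contained sketch of that gate; the FLASH_ATTENTION flag name is illustrative, not necessarily what the server exports:

    import os

    import torch

    # Hedged sketch mirroring the capability gate above.
    FLASH_ATTENTION = False
    if (
        torch.cuda.is_available()
        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
    ):
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        FLASH_ATTENTION = is_sm75 or is_sm8x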


@@ -46,7 +46,6 @@ class LlamaRMSNorm(nn.Module):
 super().__init__()
 weight = weights.get_tensor(f"{prefix}.weight")
-# assert weight.shape == (hidden_size,)
 self.weight = nn.Parameter(weight)
 self.variance_epsilon = eps
@@ -103,7 +102,9 @@ class FlashLlamaAttention(torch.nn.Module):
 self.hidden_size = config.hidden_size
 self.head_size = self.hidden_size // self.num_heads
-self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+self.rotary_emb = PositionRotaryEmbedding.load(
+prefix=f"{prefix}.rotary_emb", weights=weights
+)
 self.softmax_scale = self.head_size ** (-0.5)


@@ -90,10 +90,9 @@ class FlashNeoxAttention(torch.nn.Module):
 self.head_size = hidden_size // num_heads
 self.num_heads = self.num_heads // weights.process_group.size()
-rotary_pct = config.rotary_pct
-rotary_ndims = int(self.head_size * rotary_pct)
-self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+self.rotary_emb = PositionRotaryEmbedding.load(
+prefix=f"{prefix}.rotary_emb", weights=weights
+)
 self.softmax_scale = self.head_size ** (-0.5)


@@ -1,5 +1,3 @@
-import os
 import torch
 import torch.distributed
@@ -104,7 +102,6 @@ class FlashRWAttention(torch.nn.Module):
 config,
 prefix,
 weights,
-reduce=True,
 ):
 super().__init__()
 self.num_heads = config.n_head
@@ -395,7 +392,6 @@ class FlashRWLayer(nn.Module):
 config,
 prefix=f"{prefix}.self_attention",
 weights=weights,
-reduce=False,
 )
 self.post_attention_layernorm = (
 FastLayerNorm.load(
@@ -548,18 +544,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 if config.model_type == "RefinedWebModel":
 self.h = nn.ModuleList(
 [
-FlashRWLayer(
-layer_id,
-config,
-weights
-# config.n_head,
-# config.n_head_kv,
-# config.hidden_size,
-# config.bias,
-# config.layer_norm_epsilon,
-# config.parallel_attn,
-# process_group,
-)
+FlashRWLayer(layer_id, config, weights)
 for layer_id in range(config.num_hidden_layers)
 ]
 )


@@ -48,7 +48,6 @@ from text_generation_server.utils.layers import (
 )
 CUSTOM_KERNELS_ENABLED = False
 if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
 try:
@@ -62,7 +61,6 @@ if not CUSTOM_KERNELS_ENABLED:
 logger.warning("We're not using custom kernels.")
 def make_causal_mask(
 input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
@@ -70,10 +68,16 @@ def make_causal_mask(
 Make causal mask used for self-attention.
 """
 batch_size, target_length = input_ids_shape
-mask = torch.ones((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+mask = torch.ones(
+(target_length, target_length + past_key_values_length),
+dtype=torch.bool,
+device=device,
+)
 mask = mask.triu(1 + past_key_values_length)
-expanded_mask = mask.unsqueeze(0).expand(batch_size, target_length, target_length + past_key_values_length)
+expanded_mask = mask.unsqueeze(0).expand(
+batch_size, target_length, target_length + past_key_values_length
+)
 return expanded_mask
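The mask built here is boolean, with True marking key positions a query must not attend to: a (target_length, target_length + past) block of ones is kept only strictly above the diagonal shifted by the past length, then broadcast over the batch. A small worked example under those semantics:

    import torch

    # Hedged sketch: 3 new tokens attending over 2 cached + 3 new positions.
    past, target = 2, 3
    mask = torch.ones((target, target + past), dtype=torch.bool).triu(1 + past)
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False,  True],
    #         [False, False, False, False, False]])
    expanded = mask.unsqueeze(0).expand(4, target, target + past)  # batch_size = 4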
@@ -89,7 +93,9 @@ def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
 def prepare_attn_mask(
-attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+attention_mask: torch.Tensor,
+input_shape: Tuple[int, int],
+past_key_values_length: int,
 ) -> torch.BoolTensor:
 # create causal mask
 # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
@@ -105,7 +111,9 @@ def prepare_attn_mask(
 # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
 expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
 combined_attention_mask = (
-expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+expanded_attn_mask
+if combined_attention_mask is None
+else expanded_attn_mask | combined_attention_mask
 )
 return combined_attention_mask
@@ -118,7 +126,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
 """
 class GPTNeoXAttention(nn.Module):
 def __init__(self, config, prefix, weights):
 super().__init__()
@@ -136,17 +143,21 @@ class GPTNeoXAttention(nn.Module):
 # )
 # self.register_buffer("masked_bias", torch.tensor(-1e9))
 self.rotary_emb = RotaryEmbedding(
-self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+self.rotary_ndims,
+config.max_position_embeddings,
+base=config.rotary_emb_base,
 )
 self.rotary_emb.inv_freq = nn.Parameter(
 weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
 )
-self.inv_norm_factor = 1.0 / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(
-torch.get_default_dtype()
-)
+self.inv_norm_factor = 1.0 / torch.sqrt(
+torch.tensor(self.head_size, dtype=torch.float32)
+).to(torch.get_default_dtype())
 assert self.num_attention_heads % weights.process_group.size() == 0
-self.num_attention_heads = self.num_attention_heads // weights.process_group.size()
+self.num_attention_heads = (
+self.num_attention_heads // weights.process_group.size()
+)
 self.query_key_value = TensorParallelColumnLinear.load(
 config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
 )
@@ -214,10 +225,14 @@ class GPTNeoXAttention(nn.Module):
 present = (key, value) if use_cache else None
 # Compute attention
-attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+attn_output, attn_weights = self._attn(
+query, key, value, attention_mask, head_mask
+)
 # Reshape outputs
-attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+attn_output = self._merge_heads(
+attn_output, self.num_attention_heads, self.head_size
+)
 attn_output = self.dense(attn_output)
@@ -248,7 +263,9 @@ class GPTNeoXAttention(nn.Module):
 # tensor [bs, num_attention_heads, seq_len, attn_head_size]
 tensor = tensor.permute(0, 2, 1, 3).contiguous()
 # -> [bs, seq_len, num_attention_heads, attn_head_size]
-tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+tensor = tensor.view(
+tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
+)
 # -> [bs, seq_len, hidden_size]
 return tensor
@@ -258,7 +275,9 @@ class GPTNeoXAttention(nn.Module):
 batch_size, num_attention_heads, query_length, attn_head_size = query.size()
 key_length = key.size(-2)
-query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+query = query.view(
+batch_size * num_attention_heads, query_length, attn_head_size
+)
 key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
 attn_scores = torch.zeros(
 1,
@@ -277,8 +296,12 @@ class GPTNeoXAttention(nn.Module):
 input_dtype = attn_scores.dtype
 if input_dtype in [torch.float16, torch.bfloat16]:
 attn_scores = attn_scores.to(torch.float)
-attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores)
-attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+attn_scores = torch.where(
+attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
+)
+attn_scores = attn_scores.view(
+batch_size, num_attention_heads, query_length, key_length
+)
 attn_weights = nn.functional.softmax(attn_scores, dim=-1)
 attn_weights = attn_weights.to(value.dtype)
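The masking step reformatted above is the standard softmax trick: wherever the boolean mask is True, the score is replaced by the dtype's most negative finite value, so those positions get essentially zero weight after the softmax. A tiny self-contained sketch of just that step, with made-up shapes:

    import torch

    # Hedged sketch: [batch * heads, q_len, k_len] scores, causal boolean mask.
    scores = torch.randn(2, 4, 4)
    mask = torch.ones(4, 4, dtype=torch.bool).triu(1)      # True = blocked
    scores = torch.where(mask, torch.finfo(scores.dtype).min, scores)
    weights = torch.nn.functional.softmax(scores, dim=-1)  # blocked keys ~ 0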
@@ -294,7 +317,9 @@ class GPTNeoXAttention(nn.Module):
 class RotaryEmbedding(torch.nn.Module):
 def __init__(self, dim, max_position_embeddings, base=10000, device=None):
 super().__init__()
-self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+self.true_inv_freq = 1.0 / (
+base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+)
 self.register_buffer("inv_freq", self.true_inv_freq)
 # Build here to make `torch.jit.trace` work.
@@ -311,7 +336,9 @@ class RotaryEmbedding(torch.nn.Module):
 @staticmethod
 def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
-t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
+t = torch.arange(
+max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
+)
 freqs = torch.einsum("i,j->ij", t, inv_freq)
 # Different from paper, but it uses a different permutation in order to obtain the same calculation
 emb = torch.cat((freqs, freqs), dim=-1)
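These two wrapped calls are the whole recipe for the rotary cache: per-dimension inverse frequencies, an outer product with the position indices via einsum, and a duplicated table whose cos/sin are later gathered per position. A standalone sketch with arbitrary small sizes:

    import torch

    # Hedged sketch of the cache construction above (dim, base, max_pos are arbitrary).
    dim, base, max_pos = 8, 10000, 16
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
    t = torch.arange(max_pos, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)                        # [max_pos, dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                             # [max_pos, dim]
    cos_cached, sin_cached = emb.cos(), emb.sin()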
@@ -319,7 +346,11 @@ class RotaryEmbedding(torch.nn.Module):
 def forward(self, q, k, position_ids, seq_len=None):
 # x: [bs, num_attention_heads, seq_len, head_size]
-if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
+if (
+seq_len > self.max_seq_len_cached
+or self.cos_cached is None
+or self.sin_cached is None
+):
 if seq_len > self.max_seq_len_cached:
 self.max_seq_len_cached = seq_len
 self.cos_cached, self.sin_cached = self._create_cos_sin(
@@ -371,11 +402,22 @@ class GPTNeoXLayer(nn.Module):
 def __init__(self, layer_id, config, weights):
 super().__init__()
 self.use_parallel_residual = config.use_parallel_residual
-self.input_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps)
-self.post_attention_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps)
-self.attention = GPTNeoXAttention(config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights)
-self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights)
+self.input_layernorm = nn.LayerNorm.load(
+prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
+weights=weights,
+eps=config.layer_norm_eps,
+)
+self.post_attention_layernorm = nn.LayerNorm.load(
+prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
+weights=weights,
+eps=config.layer_norm_eps,
+)
+self.attention = GPTNeoXAttention(
+config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
+)
+self.mlp = GPTNeoXMLP(
+config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
+)
 def forward(
 self,
@@ -396,7 +438,9 @@ class GPTNeoXLayer(nn.Module):
 use_cache=use_cache,
 output_attentions=output_attentions,
 )
-attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
+attn_output = attention_layer_outputs[
+0
+] # output_attn: attn_output, present, (attn_weights)
 outputs = attention_layer_outputs[1:]
 if self.use_parallel_residual:
@@ -413,7 +457,9 @@ class GPTNeoXLayer(nn.Module):
 hidden_states = mlp_output + attn_output
 if use_cache:
-outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
+outputs = (
+hidden_states,
+) + outputs # hidden_states, present, (attn_weights)
 else:
 outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
@@ -427,12 +473,22 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 self.num_attention_heads = config.num_attention_heads
-self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights)
-self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)])
-self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps)
+self.embed_in = TensorParallelEmbedding(
+prefix="gpt_neox.embed_in", weights=weights
+)
+self.layers = nn.ModuleList(
+[
+GPTNeoXLayer(layer_id, config, weights)
+for layer_id in range(config.num_hidden_layers)
+]
+)
+self.final_layer_norm = nn.LayerNorm.load(
+prefix="gpt_neox.final_layer_norm",
+weights=weights,
+eps=config.layer_norm_eps,
+)
 self.tp_world_size = weights.process_group.size()
 def forward(
 self,
 input_ids: Optional[torch.LongTensor] = None,
@@ -456,15 +512,25 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
 `past_key_values`).
 """
-output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-output_hidden_states = (
-output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-)
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+output_attentions = (
+output_attentions
+if output_attentions is not None
+else self.config.output_attentions
+)
+output_hidden_states = (
+output_hidden_states
+if output_hidden_states is not None
+else self.config.output_hidden_states
+)
+return_dict = (
+return_dict if return_dict is not None else self.config.use_return_dict
+)
 use_cache = use_cache if use_cache is not None else self.config.use_cache
 if input_ids is not None and inputs_embeds is not None:
-raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+raise ValueError(
+"You cannot specify both input_ids and inputs_embeds at the same time"
+)
 elif input_ids is not None:
 input_shape = input_ids.size()
 elif inputs_embeds is not None:
@@ -482,7 +548,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 if position_ids is None:
 device = input_ids.device if input_ids is not None else inputs_embeds.device
-position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+position_ids = torch.arange(
+past_length, seq_length + past_length, dtype=torch.long, device=device
+)
 position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 else:
 position_ids = position_ids.view(-1, seq_length).long()
@@ -499,7 +567,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 past_key_values_length = past_key_values[0][0].shape[-1]
 seq_length_with_past = seq_length_with_past + past_key_values_length
 if attention_mask is None:
-attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+attention_mask = torch.ones(
+(batch_size, seq_length_with_past), device=hidden_states.device
+)
 else:
 attention_mask = attention_mask.to(hidden_states.device)
@@ -548,7 +618,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 all_hidden_states = all_hidden_states + (hidden_states,)
 if not return_dict:
-return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+return tuple(
+v
+for v in [hidden_states, presents, all_hidden_states, all_attentions]
+if v is not None
+)
 return BaseModelOutputWithPast(
 last_hidden_state=hidden_states,
@@ -564,7 +638,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 def __init__(self, config, weights):
 super().__init__(config)
 self.gpt_neox = GPTNeoXModel(config, weights)
-self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights)
+self.embed_out = TensorParallelHead.load(
+config, prefix="embed_out", weights=weights
+)
 def forward(
 self,
@@ -619,7 +695,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 >>> prediction_logits = outputs.logits
 ```"""
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+return_dict = (
+return_dict if return_dict is not None else self.config.use_return_dict
+)
 outputs = self.gpt_neox(
 input_ids,
@@ -645,7 +723,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 shift_logits = lm_logits[:, :-1, :].contiguous()
 labels = labels[:, 1:].contiguous()
 loss_fct = CrossEntropyLoss()
-lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+lm_loss = loss_fct(
+shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+)
 if not return_dict:
 output = (lm_logits,) + outputs[1:]
@@ -660,7 +740,12 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 )
 def prepare_inputs_for_generation(
-self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+self,
+input_ids,
+past_key_values=None,
+attention_mask=None,
+inputs_embeds=None,
+**kwargs,
 ):
 input_shape = input_ids.shape
@@ -700,6 +785,10 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 reordered_past = ()
 for layer_past in past_key_values:
 reordered_past += (
-tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+tuple(
+past_state.index_select(0, beam_idx)
+for past_state in layer_past[:2]
+)
++ layer_past[2:],
 )
 return reordered_past
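The loop wrapped in this last hunk is the beam-search cache reordering: each cached key/value tensor is re-indexed along the batch dimension with beam_idx so every surviving beam keeps the history it was actually generated from. A small sketch of the same index_select pattern with made-up shapes:

    import torch

    # Hedged sketch: one layer of cache, batch of 4 beams.
    layer_past = (torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8))  # (key, value)
    beam_idx = torch.tensor([2, 2, 0, 1])                            # chosen parent beams
    reordered = tuple(
        past_state.index_select(0, beam_idx) for past_state in layer_past
    )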


@@ -845,7 +845,6 @@ class T5Stack(T5PreTrainedModel):
 ), "You have to initialize the model with valid token embeddings"
 inputs_embeds = self.embed_tokens(input_ids)
 batch_size, seq_length = input_shape
 # required mask seq length can be calculated via length of past
@@ -1026,7 +1025,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
 embed_tokens=self.shared,
 )
-self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
+self.lm_head = TensorParallelHead.load(
+config, prefix="lm_head", weights=weights
+)
 def forward(
 self,


@@ -1,28 +1,19 @@
 import torch
 import torch.distributed
-from pathlib import Path
-from accelerate import init_empty_weights
 from opentelemetry import trace
-from safetensors import safe_open
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, List
+from transformers import AutoTokenizer
+from typing import Optional
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_rw_modeling import (
 RWConfig,
 FlashRWForCausalLM,
-TensorParallelEmbedding,
-TensorParallelRowLinear,
-TensorParallelColumnLinear,
 )
 from text_generation_server.utils import (
 initialize_torch_distributed,
 weight_files,
-download_weights,
-weight_hub_files,
 Weights,
-LocalEntryNotFoundError,
 )
 tracer = trace.get_tracer(__name__)
@@ -73,79 +64,3 @@ class FlashRWSharded(FlashCausalLM):
 rank=rank,
 world_size=world_size,
 )
-# @staticmethod
-# def load_weights(
-# model,
-# filenames: List[str],
-# quantize: Optional[str],
-# device: torch.device,
-# dtype: torch.dtype,
-# rank: int,
-# world_size: int,
-# ):
-# parameters = dict(model.named_parameters())
-# for file in filenames:
-# with safe_open(
-# file, framework="pt", device=str(device) if quantize is None else "cpu"
-# ) as f:
-# for name in f.keys():
-# module_name, param_name = name.rsplit(".", 1)
-# module = model.get_submodule(module_name)
-# current_parameter_tensor = parameters.get(name, None)
-# slice_ = f.get_slice(name)
-# if isinstance(module, TensorParallelColumnLinear):
-# size = slice_.get_shape()[0]
-# block_size = size // world_size
-# start = rank * block_size
-# stop = (rank + 1) * block_size
-# tensor = slice_[start:stop]
-# elif isinstance(module, TensorParallelRowLinear):
-# if param_name == "weight":
-# size = slice_.get_shape()[1]
-# block_size = size // world_size
-# start = rank * block_size
-# stop = (rank + 1) * block_size
-# tensor = slice_[:, start:stop]
-# else:
-# tensor = slice_[:]
-# # XXX: Hack for Rowlinear to add the bias only once.
-# if rank != 0:
-# tensor = torch.zeros_like(tensor)
-# elif isinstance(module, TensorParallelEmbedding):
-# size = slice_.get_shape()[0]
-# block_size = size // world_size
-# start = rank * block_size
-# stop = (rank + 1) * block_size
-# tensor = slice_[start:stop]
-# elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-# size = slice_.get_shape()[0]
-# block_size = size // world_size
-# start = rank * block_size
-# stop = (rank + 1) * block_size
-# tensor = slice_[start:stop]
-# else:
-# try:
-# tensor = slice_[:]
-# except:
-# tensor = f.get_tensor(name)
-# if (
-# current_parameter_tensor is not None
-# and current_parameter_tensor.shape != tensor.shape
-# ):
-# raise ValueError(
-# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-# )
-# tensor = tensor.contiguous().to(dtype)
-# if current_parameter_tensor is not None:
-# module._parameters[param_name] = tensor
-# else:
-# module._buffers[param_name] = tensor
-# model.post_load_weights(quantize)


@@ -182,6 +182,7 @@ class GalacticaSharded(CausalLM):
 tp_parallel=True,
 trust_remote_code=trust_remote_code,
 )
+config.quantize = quantize
 tokenizer.pad_token_id = config.pad_token_id
 torch.distributed.barrier(group=self.process_group)


@@ -1,13 +1,10 @@
 import torch
 import torch.distributed
-from typing import List, Optional
-from accelerate import init_empty_weights
-from safetensors import safe_open
+from typing import Optional
 from transformers import (
 AutoTokenizer,
-AutoModelForCausalLM,
 AutoConfig,
 )
 from text_generation_server.models import CausalLM


@@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
 device = torch.device("cpu")
 dtype = torch.float32
-config = AutoConfig.from_pretrained(model_id, revision=revision,
+config = AutoConfig.from_pretrained(
+model_id,
+revision=revision,
 trust_remote_code=trust_remote_code,
 )
 config.quantize = quantize


@@ -10,8 +10,8 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
 LocalEntryNotFoundError,
-EntryNotFoundError, # Import here to ease try/except in other part of the lib
-RevisionNotFoundError
+EntryNotFoundError,
+RevisionNotFoundError, # Import here to ease try/except in other part of the lib
 )
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


@@ -1,4 +1,5 @@
 import torch
+import torch.distributed
 from torch import nn
 from torch.nn import functional as F
@@ -44,14 +45,14 @@ class FastLinear(nn.Module):
 else:
 self.bias = None
-@staticmethod
-def load(config, prefix: str, weights, bias: bool):
+@classmethod
+def load(cls, config, prefix: str, weights, bias: bool):
 weight = weights.get_tensor(f"{prefix}.weight")
 if bias:
 bias = weights.get_tensor(f"{prefix}.bias")
 else:
 bias = None
-return FastLinear(weight, bias)
+return cls(weight, bias)
 def forward(self, input: torch.Tensor) -> torch.Tensor:
 return F.linear(input, self.weight, self.bias)
@@ -130,9 +131,7 @@ def get_linear(weight, bias, quantize):
 elif quantize == "gptq":
 raise NotImplementedError("Soon")
 else:
-raise NotImplementedError(
-f"Quantization `{config.quantize}` is not implemented yet."
-)
+raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
 return linear
@@ -170,17 +169,17 @@ class TensorParallelHead(SuperLayer):
 class TensorParallelColumnLinear(SuperLayer):
-@staticmethod
-def load(config, prefix: str, weights, bias: bool):
+@classmethod
+def load(cls, config, prefix: str, weights, bias: bool):
 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
 if bias:
 bias = weights.get_sharded(f"{prefix}.bias", dim=0)
 else:
 bias = None
-return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+return cls(get_linear(weight, bias, config.quantize))
-@staticmethod
-def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int):
+@classmethod
+def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
 w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
 weight = torch.cat(w, dim=dim)
@@ -189,7 +188,7 @@ class TensorParallelColumnLinear(SuperLayer):
 bias = torch.cat(b, dim=0)
 else:
 bias = None
-return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+return cls(get_linear(weight, bias, config.quantize))
 class TensorParallelRowLinear(SuperLayer):
@@ -197,15 +196,15 @@ class TensorParallelRowLinear(SuperLayer):
 super().__init__(linear)
 self.process_group = process_group
-@staticmethod
-def load(config, prefix: str, weights, bias: bool):
+@classmethod
+def load(cls, config, prefix: str, weights, bias: bool):
 weight = weights.get_sharded(f"{prefix}.weight", dim=1)
 if bias and weights.process_group.rank() == 0:
 # Rank is only on the first rank process
 bias = weights.get_tensor(f"{prefix}.bias")
 else:
 bias = None
-return TensorParallelRowLinear(
+return cls(
 get_linear(weight, bias, config.quantize),
 process_group=weights.process_group,
 )
@@ -308,22 +307,22 @@ try:
 self._cos_k_cached = None
 self._sin_k_cached = None
-@staticmethod
-def static(dim, base, device):
+@classmethod
+def static(cls, dim, base, device):
 inv_freq = 1.0 / (
 base
 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
 )
-return PositionRotaryEmbedding(inv_freq)
+return cls(inv_freq)
-@staticmethod
-def load(prefix, weights):
+@classmethod
+def load(cls, prefix, weights):
 # XXX: Always load this in float32 !
 dtype = weights.dtype
 weights.dtype = torch.float32
 inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
 weights.dtype = dtype
-return PositionRotaryEmbedding(inv_freq)
+return cls(inv_freq)
 def _update_cos_sin_cache(self, dtype, device, seqlen):
 # Reset the tables if the sequence length has changed,
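The @staticmethod to @classmethod changes in this file all serve one purpose: load(...) now returns cls(...) instead of naming a concrete class, so a subclass that inherits the loader gets an instance of itself rather than of its parent. A minimal sketch of the pattern with illustrative names (not the library's real classes):

    import torch
    from torch import nn
    from torch.nn import functional as F

    class LinearFromDict(nn.Module):
        # Hedged sketch: toy stand-in for the layers above.
        def __init__(self, weight, bias):
            super().__init__()
            self.weight = nn.Parameter(weight)
            self.bias = nn.Parameter(bias) if bias is not None else None

        @classmethod
        def load(cls, weights, prefix):
            weight = weights[f"{prefix}.weight"]
            bias = weights.get(f"{prefix}.bias")
            return cls(weight, bias)  # builds the calling subclass, not the base

        def forward(self, x):
            return F.linear(x, self.weight, self.bias)

    class ShardedLinearFromDict(LinearFromDict):
        pass

    layer = ShardedLinearFromDict.load({"fc.weight": torch.randn(4, 4)}, "fc")
    assert isinstance(layer, ShardedLinearFromDict)  # the old @staticmethod would return the base class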