Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 11:54:52 +00:00

Commit b027f5f129 ("black + cleanup"), parent 5e0a6ea1b7
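
The diff below is almost entirely mechanical black formatting (long calls wrapped to black's default 88-column limit, trailing commas added) plus small cleanup: dead imports, commented-out code, and unused arguments are dropped, and several `load` factories switch from `@staticmethod` to `@classmethod`. As a rough illustration of the mechanical part, and not the exact command history of this commit, black's public format_str/FileMode API reproduces the wrapping seen throughout:

# A minimal sketch; assumes only that the `black` package is installed.
import black

src = (
    "self.rotary_emb = PositionRotaryEmbedding.load("
    'prefix=f"{prefix}.rotary_emb", weights=weights)\n'
)
# The statement is longer than 88 characters, so black splits the call
# arguments onto their own line, exactly as in the hunks below.
print(black.format_str(src, mode=black.FileMode()))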
@@ -138,7 +138,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.lin
 # Copy build artifacts from transformers builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
 
-# Install transformers dependencies
+# Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
 
 # Install server
@@ -249,7 +249,6 @@ def launcher(event_loop):
         ) as process:
             yield ProcessLauncherHandle(process, port)
 
-
             process.terminate()
             process.wait(60)
 
@@ -261,6 +260,7 @@ def launcher(event_loop):
 
         if not use_flash_attention:
             del env["USE_FLASH_ATTENTION"]
+
     @contextlib.contextmanager
     def docker_launcher(
         model_id: str,
@@ -3,7 +3,9 @@ import pytest
 
 @pytest.fixture(scope="module")
 def neox_handle(launcher):
-    with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
+    with launcher(
+        "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
+    ) as handle:
         yield handle
 
 
@@ -3,7 +3,9 @@ import pytest
 
 @pytest.fixture(scope="module")
 def neox_sharded_handle(launcher):
-    with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
+    with launcher(
+        "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
+    ) as handle:
         yield handle
 
 
@@ -13,7 +13,7 @@ setup(
             name="custom_kernels.fused_attention_cuda",
             sources=["custom_kernels/fused_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
-        )
+        ),
     ],
     cmdclass={"build_ext": BuildExtension},
 )
@@ -19,7 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 
 try:
-    if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    if (
+        torch.cuda.is_available()
+        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
+    ):
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0
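
For reference, the condition wrapped in the hunk above gates flash attention on CUDA availability and on the USE_FLASH_ATTENTION environment variable. A standalone sketch of just that check (a simplification: the real module goes on to inspect the device's compute capability, as the context lines show) might look like:

import os
import torch

def flash_attention_allowed() -> bool:
    # Any value other than the literal string "false" leaves the feature enabled.
    if not torch.cuda.is_available():
        return False
    return os.getenv("USE_FLASH_ATTENTION", "").lower() != "false"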
@@ -46,7 +46,6 @@ class LlamaRMSNorm(nn.Module):
         super().__init__()
 
         weight = weights.get_tensor(f"{prefix}.weight")
-        # assert weight.shape == (hidden_size,)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
 
@@ -103,7 +102,9 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )
 
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -90,10 +90,9 @@ class FlashNeoxAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads
         self.num_heads = self.num_heads // weights.process_group.size()
 
-        rotary_pct = config.rotary_pct
-        rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )
 
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -1,5 +1,3 @@
-import os
-
 import torch
 import torch.distributed
 
@@ -104,7 +102,6 @@ class FlashRWAttention(torch.nn.Module):
         config,
         prefix,
         weights,
-        reduce=True,
     ):
         super().__init__()
         self.num_heads = config.n_head
@@ -395,7 +392,6 @@ class FlashRWLayer(nn.Module):
             config,
             prefix=f"{prefix}.self_attention",
             weights=weights,
-            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm.load(
@@ -548,18 +544,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         if config.model_type == "RefinedWebModel":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(
-                        layer_id,
-                        config,
-                        weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # config.parallel_attn,
-                        # process_group,
-                    )
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
@@ -48,7 +48,6 @@ from text_generation_server.utils.layers import (
 )
 
 
-
 CUSTOM_KERNELS_ENABLED = False
 if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
@@ -62,7 +61,6 @@ if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")
 
 
-
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
@@ -70,10 +68,16 @@ def make_causal_mask(
     Make causal mask used for self-attention.
     """
     batch_size, target_length = input_ids_shape
-    mask = torch.ones((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    mask = torch.ones(
+        (target_length, target_length + past_key_values_length),
+        dtype=torch.bool,
+        device=device,
+    )
     mask = mask.triu(1 + past_key_values_length)
 
-    expanded_mask = mask.unsqueeze(0).expand(batch_size, target_length, target_length + past_key_values_length)
+    expanded_mask = mask.unsqueeze(0).expand(
+        batch_size, target_length, target_length + past_key_values_length
+    )
     return expanded_mask
 
 
@@ -89,7 +93,9 @@ def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
 
 
 def prepare_attn_mask(
-    attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    attention_mask: torch.Tensor,
+    input_shape: Tuple[int, int],
+    past_key_values_length: int,
 ) -> torch.BoolTensor:
     # create causal mask
     # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
@@ -105,7 +111,9 @@ def prepare_attn_mask(
         # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
         expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
         combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+            expanded_attn_mask
+            if combined_attention_mask is None
+            else expanded_attn_mask | combined_attention_mask
         )
 
     return combined_attention_mask
@@ -118,7 +126,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     """
 
 
-
 class GPTNeoXAttention(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -136,17 +143,21 @@ class GPTNeoXAttention(nn.Module):
         # )
         # self.register_buffer("masked_bias", torch.tensor(-1e9))
         self.rotary_emb = RotaryEmbedding(
-            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+            self.rotary_ndims,
+            config.max_position_embeddings,
+            base=config.rotary_emb_base,
         )
         self.rotary_emb.inv_freq = nn.Parameter(
             weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
         )
-        self.inv_norm_factor = 1.0 / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(
-            torch.get_default_dtype()
-        )
+        self.inv_norm_factor = 1.0 / torch.sqrt(
+            torch.tensor(self.head_size, dtype=torch.float32)
+        ).to(torch.get_default_dtype())
 
         assert self.num_attention_heads % weights.process_group.size() == 0
-        self.num_attention_heads = self.num_attention_heads // weights.process_group.size()
+        self.num_attention_heads = (
+            self.num_attention_heads // weights.process_group.size()
+        )
         self.query_key_value = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
         )
@@ -214,10 +225,14 @@ class GPTNeoXAttention(nn.Module):
         present = (key, value) if use_cache else None
 
         # Compute attention
-        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        attn_output, attn_weights = self._attn(
+            query, key, value, attention_mask, head_mask
+        )
 
         # Reshape outputs
-        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+        attn_output = self._merge_heads(
+            attn_output, self.num_attention_heads, self.head_size
+        )
 
         attn_output = self.dense(attn_output)
 
@@ -248,7 +263,9 @@ class GPTNeoXAttention(nn.Module):
         # tensor [bs, num_attention_heads, seq_len, attn_head_size]
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         # -> [bs, seq_len, num_attention_heads, attn_head_size]
-        tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+        tensor = tensor.view(
+            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
+        )
         # -> [bs, seq_len, hidden_size]
         return tensor
 
@@ -258,7 +275,9 @@ class GPTNeoXAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)
 
-        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+        query = query.view(
+            batch_size * num_attention_heads, query_length, attn_head_size
+        )
         key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
         attn_scores = torch.zeros(
             1,
@@ -277,8 +296,12 @@ class GPTNeoXAttention(nn.Module):
         input_dtype = attn_scores.dtype
         if input_dtype in [torch.float16, torch.bfloat16]:
             attn_scores = attn_scores.to(torch.float)
-        attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores)
-        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        attn_scores = torch.where(
+            attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
+        )
+        attn_scores = attn_scores.view(
+            batch_size, num_attention_heads, query_length, key_length
+        )
 
         attn_weights = nn.functional.softmax(attn_scores, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -294,7 +317,9 @@ class GPTNeoXAttention(nn.Module):
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()
-        self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.true_inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+        )
         self.register_buffer("inv_freq", self.true_inv_freq)
 
         # Build here to make `torch.jit.trace` work.
@@ -311,7 +336,9 @@ class RotaryEmbedding(torch.nn.Module):
 
     @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
-        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
+        t = torch.arange(
+            max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
+        )
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
@@ -319,7 +346,11 @@ class RotaryEmbedding(torch.nn.Module):
 
     def forward(self, q, k, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
+        if (
+            seq_len > self.max_seq_len_cached
+            or self.cos_cached is None
+            or self.sin_cached is None
+        ):
             if seq_len > self.max_seq_len_cached:
                 self.max_seq_len_cached = seq_len
                 self.cos_cached, self.sin_cached = self._create_cos_sin(
@@ -371,11 +402,22 @@ class GPTNeoXLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
-        self.input_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps)
-        self.attention = GPTNeoXAttention(config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights)
-        self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights)
+        self.input_layernorm = nn.LayerNorm.load(
+            prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+        self.post_attention_layernorm = nn.LayerNorm.load(
+            prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+        self.attention = GPTNeoXAttention(
+            config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
+        )
+        self.mlp = GPTNeoXMLP(
+            config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
+        )
 
     def forward(
         self,
@@ -396,7 +438,9 @@ class GPTNeoXLayer(nn.Module):
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
+        attn_output = attention_layer_outputs[
+            0
+        ]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]
 
         if self.use_parallel_residual:
@@ -413,7 +457,9 @@ class GPTNeoXLayer(nn.Module):
             hidden_states = mlp_output + attn_output
 
         if use_cache:
-            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
+            outputs = (
+                hidden_states,
+            ) + outputs  # hidden_states, present, (attn_weights)
         else:
             outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)
 
@@ -427,12 +473,22 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 
         self.num_attention_heads = config.num_attention_heads
 
-        self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights)
-        self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps)
+        self.embed_in = TensorParallelEmbedding(
+            prefix="gpt_neox.embed_in", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                GPTNeoXLayer(layer_id, config, weights)
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.final_layer_norm = nn.LayerNorm.load(
+            prefix="gpt_neox.final_layer_norm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
         self.tp_world_size = weights.process_group.size()
 
 
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -456,15 +512,25 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -482,7 +548,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 
         if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(
+                past_length, seq_length + past_length, dtype=torch.long, device=device
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
             position_ids = position_ids.view(-1, seq_length).long()
@@ -499,7 +567,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             past_key_values_length = past_key_values[0][0].shape[-1]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), device=hidden_states.device
+            )
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
@@ -548,7 +618,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_attentions]
+                if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -564,7 +638,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights)
+        self.embed_out = TensorParallelHead.load(
+            config, prefix="embed_out", weights=weights
+        )
 
     def forward(
         self,
@@ -619,7 +695,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 
         >>> prediction_logits = outputs.logits
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         outputs = self.gpt_neox(
             input_ids,
@@ -645,7 +723,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
             shift_logits = lm_logits[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+            lm_loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+            )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -660,7 +740,12 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
     ):
         input_shape = input_ids.shape
 
@@ -700,6 +785,10 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(
+                    past_state.index_select(0, beam_idx)
+                    for past_state in layer_past[:2]
+                )
+                + layer_past[2:],
             )
         return reordered_past
@@ -845,7 +845,6 @@ class T5Stack(T5PreTrainedModel):
             ), "You have to initialize the model with valid token embeddings"
             inputs_embeds = self.embed_tokens(input_ids)
 
-
         batch_size, seq_length = input_shape
 
         # required mask seq length can be calculated via length of past
@@ -1026,7 +1025,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             embed_tokens=self.shared,
         )
 
-        self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
+        self.lm_head = TensorParallelHead.load(
+            config, prefix="lm_head", weights=weights
+        )
 
     def forward(
         self,
@@ -1,28 +1,19 @@
 import torch
 import torch.distributed
 
-from pathlib import Path
-from accelerate import init_empty_weights
 from opentelemetry import trace
-from safetensors import safe_open
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, List
+from transformers import AutoTokenizer
+from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_rw_modeling import (
     RWConfig,
     FlashRWForCausalLM,
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
-    weight_hub_files,
     Weights,
-    LocalEntryNotFoundError,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -73,79 +64,3 @@ class FlashRWSharded(FlashCausalLM):
             rank=rank,
             world_size=world_size,
         )
-
-    # @staticmethod
-    # def load_weights(
-    #     model,
-    #     filenames: List[str],
-    #     quantize: Optional[str],
-    #     device: torch.device,
-    #     dtype: torch.dtype,
-    #     rank: int,
-    #     world_size: int,
-    # ):
-    #     parameters = dict(model.named_parameters())
-    #     for file in filenames:
-    #         with safe_open(
-    #             file, framework="pt", device=str(device) if quantize is None else "cpu"
-    #         ) as f:
-    #             for name in f.keys():
-    #                 module_name, param_name = name.rsplit(".", 1)
-    #                 module = model.get_submodule(module_name)
-
-    #                 current_parameter_tensor = parameters.get(name, None)
-
-    #                 slice_ = f.get_slice(name)
-
-    #                 if isinstance(module, TensorParallelColumnLinear):
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 elif isinstance(module, TensorParallelRowLinear):
-    #                     if param_name == "weight":
-    #                         size = slice_.get_shape()[1]
-    #                         block_size = size // world_size
-    #                         start = rank * block_size
-    #                         stop = (rank + 1) * block_size
-    #                         tensor = slice_[:, start:stop]
-    #                     else:
-    #                         tensor = slice_[:]
-    #                         # XXX: Hack for Rowlinear to add the bias only once.
-    #                         if rank != 0:
-    #                             tensor = torch.zeros_like(tensor)
-    #                 elif isinstance(module, TensorParallelEmbedding):
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 else:
-    #                     try:
-    #                         tensor = slice_[:]
-    #                     except:
-    #                         tensor = f.get_tensor(name)
-
-    #                 if (
-    #                     current_parameter_tensor is not None
-    #                     and current_parameter_tensor.shape != tensor.shape
-    #                 ):
-    #                     raise ValueError(
-    #                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-    #                     )
-
-    #                 tensor = tensor.contiguous().to(dtype)
-
-    #                 if current_parameter_tensor is not None:
-    #                     module._parameters[param_name] = tensor
-    #                 else:
-    #                     module._buffers[param_name] = tensor
-
-    #     model.post_load_weights(quantize)
@@ -182,6 +182,7 @@ class GalacticaSharded(CausalLM):
             tp_parallel=True,
             trust_remote_code=trust_remote_code,
         )
+        config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id
 
         torch.distributed.barrier(group=self.process_group)
@@ -1,13 +1,10 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional
+from typing import Optional
 
-from accelerate import init_empty_weights
-from safetensors import safe_open
 from transformers import (
     AutoTokenizer,
-    AutoModelForCausalLM,
     AutoConfig,
 )
 from text_generation_server.models import CausalLM
@@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
             device = torch.device("cpu")
             dtype = torch.float32
 
-        config = AutoConfig.from_pretrained(model_id, revision=revision,
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
@@ -10,8 +10,8 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
-    EntryNotFoundError,  # Import here to ease try/except in other part of the lib
-    RevisionNotFoundError
+    EntryNotFoundError,
+    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
 )
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed
 
 from torch import nn
 from torch.nn import functional as F
@@ -44,14 +45,14 @@ class FastLinear(nn.Module):
         else:
             self.bias = None
 
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
         if bias:
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return FastLinear(weight, bias)
+        return cls(weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight, self.bias)
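
The change above is more than formatting: FastLinear.load switches from a @staticmethod that returned a hard-coded FastLinear to a @classmethod that returns cls(...), so subclasses inherit a factory that builds instances of themselves. The same pattern is applied to the TensorParallel* layers and PositionRotaryEmbedding further down. A minimal sketch of the difference, with made-up class names:

import torch

class Linearish:
    def __init__(self, weight: torch.Tensor):
        self.weight = weight

    @classmethod
    def load(cls, weight: torch.Tensor):
        # `cls` is whatever class the call was made on, so subclasses get
        # instances of themselves without overriding `load`.
        return cls(weight)

class QuantizedLinearish(Linearish):
    pass

layer = QuantizedLinearish.load(torch.zeros(4, 4))
# With a @staticmethod returning Linearish(...), this would be a plain Linearish.
assert isinstance(layer, QuantizedLinearish)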
@@ -130,9 +131,7 @@ def get_linear(weight, bias, quantize):
     elif quantize == "gptq":
         raise NotImplementedError("Soon")
     else:
-        raise NotImplementedError(
-            f"Quantization `{config.quantize}` is not implemented yet."
-        )
+        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
 
 
@@ -170,17 +169,17 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         if bias:
             bias = weights.get_sharded(f"{prefix}.bias", dim=0)
         else:
             bias = None
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return cls(get_linear(weight, bias, config.quantize))
 
-    @staticmethod
-    def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int):
+    @classmethod
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)
 
@@ -189,7 +188,7 @@ class TensorParallelColumnLinear(SuperLayer):
             bias = torch.cat(b, dim=0)
         else:
             bias = None
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return cls(get_linear(weight, bias, config.quantize))
 
 
 class TensorParallelRowLinear(SuperLayer):
@@ -197,15 +196,15 @@ class TensorParallelRowLinear(SuperLayer):
         super().__init__(linear)
         self.process_group = process_group
 
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_sharded(f"{prefix}.weight", dim=1)
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return TensorParallelRowLinear(
+        return cls(
             get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
         )
@@ -308,22 +307,22 @@ try:
             self._cos_k_cached = None
             self._sin_k_cached = None
 
-        @staticmethod
-        def static(dim, base, device):
+        @classmethod
+        def static(cls, dim, base, device):
             inv_freq = 1.0 / (
                 base
                 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
             )
-            return PositionRotaryEmbedding(inv_freq)
+            return cls(inv_freq)
 
-        @staticmethod
-        def load(prefix, weights):
+        @classmethod
+        def load(cls, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return PositionRotaryEmbedding(inv_freq)
+            return cls(inv_freq)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,