Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

black + cleanup

Parent: 5e0a6ea1b7
Commit: b027f5f129
@@ -138,7 +138,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.lin
 # Copy build artifacts from transformers builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
 
-# Install transformers dependencies
+# Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
 
 # Install server
@@ -249,7 +249,6 @@ def launcher(event_loop):
         ) as process:
             yield ProcessLauncherHandle(process, port)
 
-
             process.terminate()
             process.wait(60)
 
@@ -261,6 +260,7 @@ def launcher(event_loop):
 
         if not use_flash_attention:
             del env["USE_FLASH_ATTENTION"]
 
     @contextlib.contextmanager
     def docker_launcher(
         model_id: str,
@@ -3,7 +3,9 @@ import pytest
 
 @pytest.fixture(scope="module")
 def neox_handle(launcher):
-    with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
+    with launcher(
+        "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
+    ) as handle:
         yield handle
 
 
@@ -3,7 +3,9 @@ import pytest
 
 @pytest.fixture(scope="module")
 def neox_sharded_handle(launcher):
-    with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
+    with launcher(
+        "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
+    ) as handle:
         yield handle
 
 
@@ -13,7 +13,7 @@ setup(
             name="custom_kernels.fused_attention_cuda",
             sources=["custom_kernels/fused_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
-        )
+        ),
     ],
     cmdclass={"build_ext": BuildExtension},
 )
@@ -19,7 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 
 try:
-    if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    if (
+        torch.cuda.is_available()
+        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
+    ):
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0
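For context on the condition being wrapped above, here is a minimal standalone sketch of the same env-gated capability check (assuming only `torch`; the helper name is illustrative and not from the repository):

import os

import torch


def flash_attention_requested() -> bool:
    # Mirrors the wrapped condition: CUDA must be available and the
    # USE_FLASH_ATTENTION environment variable must not be set to "false".
    return (
        torch.cuda.is_available()
        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
    )


if flash_attention_requested():
    # The server then inspects the compute capability (sm_75 / sm_8x)
    # before importing the flash-attention model classes.
    major, minor = torch.cuda.get_device_capability()
    print(f"flash attention candidate, compute capability {major}.{minor}")
else:
    print("flash attention disabled or CUDA unavailable")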
@@ -46,7 +46,6 @@ class LlamaRMSNorm(nn.Module):
         super().__init__()
 
         weight = weights.get_tensor(f"{prefix}.weight")
-        # assert weight.shape == (hidden_size,)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
 
@@ -103,7 +102,9 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )
 
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -90,10 +90,9 @@ class FlashNeoxAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads
         self.num_heads = self.num_heads // weights.process_group.size()
 
-        rotary_pct = config.rotary_pct
-
-        rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )
 
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -1,5 +1,3 @@
-import os
-
 import torch
 import torch.distributed
 
@@ -104,7 +102,6 @@ class FlashRWAttention(torch.nn.Module):
         config,
         prefix,
         weights,
-        reduce=True,
     ):
         super().__init__()
         self.num_heads = config.n_head
@@ -395,7 +392,6 @@ class FlashRWLayer(nn.Module):
             config,
             prefix=f"{prefix}.self_attention",
             weights=weights,
-            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm.load(
@@ -548,18 +544,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         if config.model_type == "RefinedWebModel":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(
-                        layer_id,
-                        config,
-                        weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # config.parallel_attn,
-                        # process_group,
-                    )
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
@@ -48,7 +48,6 @@ from text_generation_server.utils.layers import (
 )
 
 
-
 CUSTOM_KERNELS_ENABLED = False
 if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
@@ -62,7 +61,6 @@ if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")
 
 
-
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
@@ -70,10 +68,16 @@ def make_causal_mask(
     Make causal mask used for self-attention.
     """
     batch_size, target_length = input_ids_shape
-    mask = torch.ones((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    mask = torch.ones(
+        (target_length, target_length + past_key_values_length),
+        dtype=torch.bool,
+        device=device,
+    )
     mask = mask.triu(1 + past_key_values_length)
 
-    expanded_mask = mask.unsqueeze(0).expand(batch_size, target_length, target_length + past_key_values_length)
+    expanded_mask = mask.unsqueeze(0).expand(
+        batch_size, target_length, target_length + past_key_values_length
+    )
     return expanded_mask
 
 
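As a rough illustration of what the reformatted mask builder produces, a standalone sketch (assuming only `torch`; shapes chosen arbitrarily) that follows the same construction and prints the resulting shape:

import torch


def make_causal_mask_sketch(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    # Same construction as in the hunk above: True marks future positions
    # that must be masked out for each query position.
    batch_size, target_length = input_ids_shape
    mask = torch.ones(
        (target_length, target_length + past_key_values_length),
        dtype=torch.bool,
        device=device,
    )
    mask = mask.triu(1 + past_key_values_length)
    return mask.unsqueeze(0).expand(
        batch_size, target_length, target_length + past_key_values_length
    )


mask = make_causal_mask_sketch(torch.Size([2, 4]), torch.device("cpu"), 3)
print(mask.shape)  # torch.Size([2, 4, 7]); only keys after each query are True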
@@ -89,7 +93,9 @@ def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
 
 
 def prepare_attn_mask(
-    attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    attention_mask: torch.Tensor,
+    input_shape: Tuple[int, int],
+    past_key_values_length: int,
 ) -> torch.BoolTensor:
     # create causal mask
     # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
@@ -105,7 +111,9 @@ def prepare_attn_mask(
     # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
     expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
     combined_attention_mask = (
-        expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+        expanded_attn_mask
+        if combined_attention_mask is None
+        else expanded_attn_mask | combined_attention_mask
     )
 
     return combined_attention_mask
@@ -118,7 +126,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     """
 
 
-
 class GPTNeoXAttention(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -136,17 +143,21 @@ class GPTNeoXAttention(nn.Module):
         # )
         # self.register_buffer("masked_bias", torch.tensor(-1e9))
         self.rotary_emb = RotaryEmbedding(
-            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+            self.rotary_ndims,
+            config.max_position_embeddings,
+            base=config.rotary_emb_base,
         )
         self.rotary_emb.inv_freq = nn.Parameter(
             weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
         )
-        self.inv_norm_factor = 1.0 / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(
-            torch.get_default_dtype()
-        )
+        self.inv_norm_factor = 1.0 / torch.sqrt(
+            torch.tensor(self.head_size, dtype=torch.float32)
+        ).to(torch.get_default_dtype())
 
         assert self.num_attention_heads % weights.process_group.size() == 0
-        self.num_attention_heads = self.num_attention_heads // weights.process_group.size()
+        self.num_attention_heads = (
+            self.num_attention_heads // weights.process_group.size()
+        )
         self.query_key_value = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
         )
@@ -214,10 +225,14 @@ class GPTNeoXAttention(nn.Module):
         present = (key, value) if use_cache else None
 
         # Compute attention
-        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        attn_output, attn_weights = self._attn(
+            query, key, value, attention_mask, head_mask
+        )
 
         # Reshape outputs
-        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+        attn_output = self._merge_heads(
+            attn_output, self.num_attention_heads, self.head_size
+        )
 
         attn_output = self.dense(attn_output)
 
@@ -248,7 +263,9 @@ class GPTNeoXAttention(nn.Module):
         # tensor [bs, num_attention_heads, seq_len, attn_head_size]
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         # -> [bs, seq_len, num_attention_heads, attn_head_size]
-        tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+        tensor = tensor.view(
+            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
+        )
         # -> [bs, seq_len, hidden_size]
         return tensor
 
@@ -258,7 +275,9 @@ class GPTNeoXAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)
 
-        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+        query = query.view(
+            batch_size * num_attention_heads, query_length, attn_head_size
+        )
         key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
         attn_scores = torch.zeros(
             1,
@@ -277,8 +296,12 @@ class GPTNeoXAttention(nn.Module):
         input_dtype = attn_scores.dtype
         if input_dtype in [torch.float16, torch.bfloat16]:
             attn_scores = attn_scores.to(torch.float)
-        attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores)
-        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        attn_scores = torch.where(
+            attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
+        )
+        attn_scores = attn_scores.view(
+            batch_size, num_attention_heads, query_length, key_length
+        )
 
         attn_weights = nn.functional.softmax(attn_scores, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -294,7 +317,9 @@ class GPTNeoXAttention(nn.Module):
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()
-        self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.true_inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+        )
         self.register_buffer("inv_freq", self.true_inv_freq)
 
         # Build here to make `torch.jit.trace` work.
@@ -311,7 +336,9 @@ class RotaryEmbedding(torch.nn.Module):
 
     @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
-        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
+        t = torch.arange(
+            max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
+        )
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
@@ -319,7 +346,11 @@ class RotaryEmbedding(torch.nn.Module):
 
     def forward(self, q, k, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
+        if (
+            seq_len > self.max_seq_len_cached
+            or self.cos_cached is None
+            or self.sin_cached is None
+        ):
             if seq_len > self.max_seq_len_cached:
                 self.max_seq_len_cached = seq_len
                 self.cos_cached, self.sin_cached = self._create_cos_sin(
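The wrapped condition above is the usual lazy-cache pattern: recompute the cos/sin tables only when the requested length exceeds what is cached, or nothing has been cached yet. A stripped-down sketch of that pattern, independent of the model code (class and method names are illustrative, assuming only `torch`):

import torch


class CosSinCache:
    def __init__(self, dim: int, base: float = 10000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.max_seq_len_cached = 0
        self.cos_cached = None
        self.sin_cached = None

    def get(self, seq_len: int):
        # Rebuild the tables only when they are missing or too short.
        if (
            seq_len > self.max_seq_len_cached
            or self.cos_cached is None
            or self.sin_cached is None
        ):
            self.max_seq_len_cached = max(seq_len, self.max_seq_len_cached)
            t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached, self.sin_cached = emb.cos(), emb.sin()
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


cache = CosSinCache(dim=8)
cos, sin = cache.get(4)
print(cos.shape, sin.shape)  # torch.Size([4, 8]) torch.Size([4, 8])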
@@ -371,11 +402,22 @@ class GPTNeoXLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
-        self.input_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps)
-        self.attention = GPTNeoXAttention(config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights)
-        self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights)
-
+        self.input_layernorm = nn.LayerNorm.load(
+            prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+        self.post_attention_layernorm = nn.LayerNorm.load(
+            prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+        self.attention = GPTNeoXAttention(
+            config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
+        )
+        self.mlp = GPTNeoXMLP(
+            config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
+        )
 
     def forward(
         self,
@@ -396,7 +438,9 @@ class GPTNeoXLayer(nn.Module):
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
+        attn_output = attention_layer_outputs[
+            0
+        ]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]
 
         if self.use_parallel_residual:
@@ -413,7 +457,9 @@ class GPTNeoXLayer(nn.Module):
             hidden_states = mlp_output + attn_output
 
         if use_cache:
-            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
+            outputs = (
+                hidden_states,
+            ) + outputs  # hidden_states, present, (attn_weights)
         else:
             outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)
 
@@ -427,12 +473,22 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 
         self.num_attention_heads = config.num_attention_heads
 
-        self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights)
-        self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps)
+        self.embed_in = TensorParallelEmbedding(
+            prefix="gpt_neox.embed_in", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                GPTNeoXLayer(layer_id, config, weights)
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.final_layer_norm = nn.LayerNorm.load(
+            prefix="gpt_neox.final_layer_norm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
         self.tp_world_size = weights.process_group.size()
-
 
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -456,15 +512,25 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
         `past_key_values`).
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -482,7 +548,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(
+                past_length, seq_length + past_length, dtype=torch.long, device=device
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
             position_ids = position_ids.view(-1, seq_length).long()
@@ -499,7 +567,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             past_key_values_length = past_key_values[0][0].shape[-1]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), device=hidden_states.device
+            )
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
@@ -548,7 +618,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_attentions]
+                if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -564,7 +638,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights)
+        self.embed_out = TensorParallelHead.load(
+            config, prefix="embed_out", weights=weights
+        )
 
     def forward(
         self,
@@ -619,7 +695,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
 
         >>> prediction_logits = outputs.logits
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         outputs = self.gpt_neox(
             input_ids,
@@ -645,7 +723,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
             shift_logits = lm_logits[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+            lm_loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+            )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -660,7 +740,12 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
     ):
         input_shape = input_ids.shape
 
@@ -700,6 +785,10 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(
+                    past_state.index_select(0, beam_idx)
+                    for past_state in layer_past[:2]
+                )
+                + layer_past[2:],
             )
         return reordered_past
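The `_reorder_cache` hunk above only re-wraps the `index_select` expression; as a sanity check, a small sketch of what that reordering does to one cached tensor (assuming only `torch`; shapes are arbitrary):

import torch

past_key = torch.arange(3 * 2 * 4).reshape(3, 2, 4)  # [batch, heads, head_dim]
beam_idx = torch.tensor([2, 0, 0])  # beam 0 now continues old batch entry 2, etc.

reordered = past_key.index_select(0, beam_idx)
print(torch.equal(reordered[0], past_key[2]))  # True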
@@ -845,7 +845,6 @@ class T5Stack(T5PreTrainedModel):
             ), "You have to initialize the model with valid token embeddings"
             inputs_embeds = self.embed_tokens(input_ids)
 
-
         batch_size, seq_length = input_shape
 
         # required mask seq length can be calculated via length of past
@@ -1026,7 +1025,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             embed_tokens=self.shared,
         )
 
-        self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
+        self.lm_head = TensorParallelHead.load(
+            config, prefix="lm_head", weights=weights
+        )
 
     def forward(
         self,
@@ -1,28 +1,19 @@
 import torch
 import torch.distributed
 
-from pathlib import Path
-from accelerate import init_empty_weights
 from opentelemetry import trace
-from safetensors import safe_open
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, List
+from transformers import AutoTokenizer
+from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_rw_modeling import (
     RWConfig,
     FlashRWForCausalLM,
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
-    weight_hub_files,
     Weights,
-    LocalEntryNotFoundError,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -73,79 +64,3 @@ class FlashRWSharded(FlashCausalLM):
             rank=rank,
             world_size=world_size,
         )
-
-    # @staticmethod
-    # def load_weights(
-    #     model,
-    #     filenames: List[str],
-    #     quantize: Optional[str],
-    #     device: torch.device,
-    #     dtype: torch.dtype,
-    #     rank: int,
-    #     world_size: int,
-    # ):
-    #     parameters = dict(model.named_parameters())
-    #     for file in filenames:
-    #         with safe_open(
-    #             file, framework="pt", device=str(device) if quantize is None else "cpu"
-    #         ) as f:
-    #             for name in f.keys():
-    #                 module_name, param_name = name.rsplit(".", 1)
-    #                 module = model.get_submodule(module_name)
-
-    #                 current_parameter_tensor = parameters.get(name, None)
-
-    #                 slice_ = f.get_slice(name)
-
-    #                 if isinstance(module, TensorParallelColumnLinear):
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 elif isinstance(module, TensorParallelRowLinear):
-    #                     if param_name == "weight":
-    #                         size = slice_.get_shape()[1]
-    #                         block_size = size // world_size
-    #                         start = rank * block_size
-    #                         stop = (rank + 1) * block_size
-    #                         tensor = slice_[:, start:stop]
-    #                     else:
-    #                         tensor = slice_[:]
-    #                         # XXX: Hack for Rowlinear to add the bias only once.
-    #                         if rank != 0:
-    #                             tensor = torch.zeros_like(tensor)
-    #                 elif isinstance(module, TensorParallelEmbedding):
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 else:
-    #                     try:
-    #                         tensor = slice_[:]
-    #                     except:
-    #                         tensor = f.get_tensor(name)
-
-    #                 if (
-    #                     current_parameter_tensor is not None
-    #                     and current_parameter_tensor.shape != tensor.shape
-    #                 ):
-    #                     raise ValueError(
-    #                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-    #                     )
-
-    #                 tensor = tensor.contiguous().to(dtype)
-
-    #                 if current_parameter_tensor is not None:
-    #                     module._parameters[param_name] = tensor
-    #                 else:
-    #                     module._buffers[param_name] = tensor
-
-    #     model.post_load_weights(quantize)
@@ -182,6 +182,7 @@ class GalacticaSharded(CausalLM):
             tp_parallel=True,
             trust_remote_code=trust_remote_code,
         )
+        config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id
 
         torch.distributed.barrier(group=self.process_group)
@@ -1,13 +1,10 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional
+from typing import Optional
-
-from accelerate import init_empty_weights
-from safetensors import safe_open
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     AutoConfig,
 )
 from text_generation_server.models import CausalLM
@@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
             device = torch.device("cpu")
             dtype = torch.float32
 
-        config = AutoConfig.from_pretrained(model_id, revision=revision,
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
@@ -10,8 +10,8 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
-    EntryNotFoundError,  # Import here to ease try/except in other part of the lib
-    RevisionNotFoundError
+    EntryNotFoundError,
+    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
 )
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed
 
 from torch import nn
 from torch.nn import functional as F
@@ -44,14 +45,14 @@ class FastLinear(nn.Module):
         else:
             self.bias = None
 
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
         if bias:
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return FastLinear(weight, bias)
+        return cls(weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight, self.bias)
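A hedged aside on the `@staticmethod` to `@classmethod` change above (and the matching ones further down in this file): returning `cls(...)` instead of a hard-coded class name lets subclasses inherit `load` and still get instances of themselves. A toy sketch of the pattern, independent of the actual layer classes:

class Base:
    def __init__(self, weight):
        self.weight = weight

    @classmethod
    def load(cls, weight):
        # `cls` is the class the call was made on, so subclasses are preserved.
        return cls(weight)


class Child(Base):
    pass


print(type(Base.load(1.0)).__name__)   # Base
print(type(Child.load(2.0)).__name__)  # Child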
@@ -130,9 +131,7 @@ def get_linear(weight, bias, quantize):
     elif quantize == "gptq":
         raise NotImplementedError("Soon")
     else:
-        raise NotImplementedError(
-            f"Quantization `{config.quantize}` is not implemented yet."
-        )
+        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
 
 
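Beyond formatting, the hunk above also swaps `config.quantize` for the local `quantize` argument, presumably because `config` is not in scope inside `get_linear`, so the old error path would itself crash. A minimal sketch of that failure mode (hypothetical function, not the repository's code):

def report_unsupported(quantize):
    # `config` is intentionally undefined here: evaluating the f-string
    # raises NameError instead of the intended NotImplementedError.
    raise NotImplementedError(
        f"Quantization `{config.quantize}` is not implemented yet."
    )


try:
    report_unsupported("gptq")
except NameError as exc:
    print("old error path fails:", exc)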
@@ -170,17 +169,17 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         if bias:
             bias = weights.get_sharded(f"{prefix}.bias", dim=0)
         else:
             bias = None
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return cls(get_linear(weight, bias, config.quantize))
 
-    @staticmethod
-    def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int):
+    @classmethod
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)
 
@@ -189,7 +188,7 @@ class TensorParallelColumnLinear(SuperLayer):
             bias = torch.cat(b, dim=0)
         else:
             bias = None
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return cls(get_linear(weight, bias, config.quantize))
 
 
 class TensorParallelRowLinear(SuperLayer):
@@ -197,15 +196,15 @@ class TensorParallelRowLinear(SuperLayer):
         super().__init__(linear)
         self.process_group = process_group
 
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_sharded(f"{prefix}.weight", dim=1)
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return TensorParallelRowLinear(
+        return cls(
             get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
         )
@@ -308,22 +307,22 @@ try:
             self._cos_k_cached = None
             self._sin_k_cached = None
 
-        @staticmethod
-        def static(dim, base, device):
+        @classmethod
+        def static(cls, dim, base, device):
             inv_freq = 1.0 / (
                 base
                 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
             )
-            return PositionRotaryEmbedding(inv_freq)
+            return cls(inv_freq)
 
-        @staticmethod
-        def load(prefix, weights):
+        @classmethod
+        def load(cls, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return PositionRotaryEmbedding(inv_freq)
+            return cls(inv_freq)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,