Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-10 11:54:52 +00:00)

Commit 2362a80a4f: "Black + ruff + T5 w0 quant."
Parent commit: 15bf3d4944
@@ -1,15 +1,14 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 setup(
-    name='custom_kernels',
+    name="custom_kernels",
     ext_modules=[
         CUDAExtension(
             name="custom_kernels.fused_bloom_attention_cuda",
-            sources=['custom_kernels/fused_bloom_attention_cuda.cu'],
+            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
         )
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    }
+    cmdclass={"build_ext": BuildExtension},
 )
@@ -37,7 +37,6 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     TensorParallelHead,
-    FastLinear
 )
 
 CUSTOM_KERNELS_ENABLED = False
@@ -60,12 +60,18 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     weight = weights.get_sharded(f"{prefix}.weight", dim=0)
     bias = weights.get_sharded(f"{prefix}.bias", dim=0)
 
-    weight = weight.view(
-        num_heads, 3, head_size, hidden_size,
-    ).permute(1, 0, 2, 3).reshape(-1, hidden_size)
+    weight = (
+        weight.view(
+            num_heads,
+            3,
+            head_size,
+            hidden_size,
+        )
+        .permute(1, 0, 2, 3)
+        .reshape(-1, hidden_size)
+    )
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
 
-
     linear = get_linear(weight, bias, config.quantize)
     if config.use_parallel_residual:
         return linear
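
The reshuffle above turns GPT-NeoX's per-head interleaved QKV weight into a fused layout with all query rows first, then keys, then values. A small sketch of the same view/permute/reshape chain with toy dimensions (not the real model sizes):

    import torch

    num_heads, head_size, hidden_size = 4, 8, 32

    # Sharded NeoX QKV weight: rows are ordered head-by-head as (q, k, v) triples.
    weight = torch.randn(num_heads * 3 * head_size, hidden_size)

    fused = (
        weight.view(num_heads, 3, head_size, hidden_size)
        .permute(1, 0, 2, 3)  # regroup as all q heads, then all k heads, then all v heads
        .reshape(-1, hidden_size)
    )
    assert fused.shape == (num_heads * 3 * head_size, hidden_size)
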
@@ -88,12 +94,18 @@ class FlashNeoxAttention(torch.nn.Module):
 
         rotary_ndims = int(self.head_size * rotary_pct)
         self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
-        self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq"))
+        self.rotary_emb.inv_freq = nn.Parameter(
+            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
+        )
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.query_key_value = load_qkv(
-            config, prefix=f"{prefix}.query_key_value", weights=weights,
-            num_heads = self.num_heads, head_size = self.head_size, hidden_size = self.hidden_size
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            num_heads=self.num_heads,
+            head_size=self.head_size,
+            hidden_size=self.hidden_size,
         )
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
@@ -3,7 +3,7 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List
+from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
@@ -17,8 +17,9 @@ from text_generation_server.utils.layers import (
 )
 
 
-def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size):
+def load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
@@ -55,30 +56,35 @@ def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_head
         if config.transpose:
             w = [
                 weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
-                weights.get_tensor(f"{prefix}.kv_attn.weight").T
+                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
             ]
             weight = torch.cat(w, dim=0)
         else:
             w = [
                 weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
-                weights.get_tensor(f"{prefix}.kv_attn.weight")
+                weights.get_tensor(f"{prefix}.kv_attn.weight"),
             ]
             weight = torch.cat(w, dim=1)
 
         if bias:
             b = [
                 weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
-                weights.get_tensor(f"{prefix}.kv_attn.bias")
+                weights.get_tensor(f"{prefix}.kv_attn.bias"),
             ]
             bias = torch.cat(b, dim=0)
         else:
             bias = None
 
     weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-    assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
+    assert list(weight.shape) == [
+        (num_heads + 2) * head_size,
+        hidden_size,
+    ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
     if bias is not None:
         bias = bias.to(dtype=weights.dtype).to(device=weights.device)
-        assert list(bias.shape) == [(num_heads + 2) * head_size], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
+        assert list(bias.shape) == [
+            (num_heads + 2) * head_size
+        ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
     return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
 
 
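
The shape asserts above encode the multi-query layout: num_heads query heads per shard plus one shared key head and one shared value head. A toy illustration of the concatenation and the expected fused shape (dummy tensors, not real checkpoint weights):

    import torch

    num_heads, head_size, hidden_size = 4, 8, 32

    q_attn = torch.randn(num_heads * head_size, hidden_size)  # sharded across ranks in the real code
    kv_attn = torch.randn(2 * head_size, hidden_size)         # one k head + one v head, replicated

    weight = torch.cat([q_attn, kv_attn], dim=0)
    assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size]
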
@@ -106,7 +112,9 @@ def load_row(config, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
-    return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
+    return TensorParallelRowLinear(
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+    )
 
 
 class FlashMQAttention(torch.nn.Module):
|
|||||||
bias=True,
|
bias=True,
|
||||||
head_size=self.head_size,
|
head_size=self.head_size,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
num_heads=self.num_heads
|
num_heads=self.num_heads,
|
||||||
)
|
)
|
||||||
self.c_proj = load_row(
|
self.c_proj = load_row(
|
||||||
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
||||||
|
@@ -109,9 +109,21 @@ class T5DenseActDense(nn.Module):
         self.wi = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.wi", weights=weights, bias=False
         )
 
+        ### XXX: T5 models do not handle well both f16 and quantization.
+        ### Overidding specifically this layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
         self.wo = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.wo", weights=weights, bias=False
         )
+        weights.dtype = _dtype
+        config.quantize = _q
+
         self.dropout = nn.Dropout(config.dropout_rate)
         self.act = (
@@ -124,7 +136,10 @@ class T5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
 
 
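
The two T5 hunks above keep `wo` in float32 and unquantized (by temporarily overriding `weights.dtype` and `config.quantize` while that single layer is loaded) and cast the activations up and back around it. A minimal, self-contained sketch of the same pattern with plain `nn.Linear` layers; the class and sizes here are made up for illustration, only the `wo_cast` idea mirrors the diff, and bfloat16 is used so the sketch runs on CPU (the server itself uses float16 on GPU):

    import torch
    from torch import nn

    class DenseActDenseSketch(nn.Module):
        # Toy stand-in: wi runs in the low-precision compute dtype, wo stays float32.
        def __init__(self, d_model=16, d_ff=64, compute_dtype=torch.bfloat16):
            super().__init__()
            self.wi = nn.Linear(d_model, d_ff, bias=False).to(compute_dtype)
            self.wo = nn.Linear(d_ff, d_model, bias=False).float()
            self.wo_cast = (torch.float32, compute_dtype)
            self.act = nn.ReLU()

        def forward(self, hidden_states):
            hidden_states = self.act(self.wi(hidden_states))
            hidden_states = hidden_states.to(dtype=self.wo_cast[0])  # up-cast before the sensitive matmul
            hidden_states = self.wo(hidden_states)
            return hidden_states.to(dtype=self.wo_cast[1])           # back to the compute dtype

    y = DenseActDenseSketch()(torch.randn(2, 16, dtype=torch.bfloat16))
    assert y.dtype == torch.bfloat16
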
@@ -137,9 +152,20 @@ class T5DenseGatedActDense(nn.Module):
         self.wi_1 = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
         )
+        ### XXX: T5 models do not handle well both f16 and quantization.
+        ### Overidding specifically this layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
         self.wo = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.wo", weights=weights, bias=False
         )
+        weights.dtype = _dtype
+        config.quantize = _q
 
         self.dropout = nn.Dropout(config.dropout_rate)
         self.act = (
@@ -154,18 +180,9 @@ class T5DenseGatedActDense(nn.Module):
         hidden_states = hidden_gelu * hidden_linear
         hidden_states = self.dropout(hidden_states)
 
-        # TODO Support this again mayber
-        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
-        # See https://github.com/huggingface/transformers/issues/20287
-        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
-        # if (
-        #     isinstance(self.wo.weight, torch.Tensor)
-        #     and hidden_states.dtype != self.wo.weight.dtype
-        #     and self.wo.weight.dtype != torch.int8
-        # ):
-        #     hidden_states = hidden_states.to(self.wo.weight.dtype)
-
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
 
 
@@ -26,7 +26,7 @@ HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params
-except Exception as e:
+except Exception:
     HAS_BITS_AND_BYTES = False
 
 
@@ -40,8 +40,9 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        )
-        config = AutoConfig.from_pretrained(model_id, revision=revision,
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
@@ -3,6 +3,7 @@ import torch
 
 from datetime import timedelta
 
+
 class FakeBarrier:
     def wait(self):
         pass
@@ -17,7 +18,9 @@ class FakeGroup:
         return FakeBarrier()
 
     def allgather(self, inputs, local_tensor, **kwargs):
-        assert len(inputs[0]) == len(local_tensor) == 1, f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
+        assert (
+            len(inputs[0]) == len(local_tensor) == 1
+        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
         for input_ in inputs:
             input_[0].data = local_tensor[0].data
         return FakeBarrier()
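
For single-GPU runs the `FakeGroup` above stands in for a real process group, and its `allgather` simply copies the local tensor into each gather slot. A tiny sketch of that copy semantics with plain tensors and a world size of one:

    import torch

    local_tensor = [torch.tensor([1.0, 2.0, 3.0])]
    inputs = [[torch.empty(3)]]  # one slot per rank; only one rank here

    for input_ in inputs:
        input_[0].data = local_tensor[0].data

    assert torch.equal(inputs[0][0], local_tensor[0])
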
@@ -10,8 +10,7 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
-    EntryNotFoundError,
-    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
+    EntryNotFoundError,  # Import here to ease try/except in other part of the lib
 )
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
@@ -2,14 +2,14 @@ import torch
 
 from torch import nn
 from torch.nn import functional as F
-from typing import Optional, List
+from typing import List
 
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params
 
-except ImportError as e:
+except ImportError:
     HAS_BITS_AND_BYTES = False
 
 from accelerate import init_empty_weights
@@ -27,13 +27,15 @@ def load_layer_norm(cls, prefix, weights, eps):
     ln.bias = nn.Parameter(bias)
     return ln
 
 
 torch.nn.LayerNorm.load = load_layer_norm
 
+
 class FastLinear(nn.Module):
     def __init__(
         self,
-        weight, bias,
+        weight,
+        bias,
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(weight)
@@ -56,10 +58,19 @@ class FastLinear(nn.Module):
 
 
 class Linear8bitLt(nn.Module):
-    def __init__(self, weight, bias, has_fp16_weights=True,
-        memory_efficient_backward=False, threshold=0.0, index=None):
+    def __init__(
+        self,
+        weight,
+        bias,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
         super().__init__()
-        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+        assert (
+            not memory_efficient_backward
+        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
 
@@ -70,7 +81,11 @@ class Linear8bitLt(nn.Module):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self.weight = Int8Params(
+            weight.data,
+            has_fp16_weights=has_fp16_weights,
+            requires_grad=has_fp16_weights,
+        )
         self.weight.cuda(weight.device)
         self.bias = bias
 
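
The `Int8Params` wrapper above quantizes the float16 weight to int8 when it is moved to the GPU (`has_fp16_weights=False` keeps only the int8 copy). A hedged sketch of that behaviour, guarded because it needs a CUDA device and an installed bitsandbytes; the shapes are arbitrary:

    import torch

    if torch.cuda.is_available():
        from bitsandbytes.nn import Int8Params

        fp16_weight = torch.randn(128, 64, dtype=torch.float16)
        int8_weight = Int8Params(
            fp16_weight.data,
            has_fp16_weights=False,  # as in Linear8bitLt above
            requires_grad=False,
        )
        int8_weight.cuda(0)  # quantization happens on the way to the device
        print(int8_weight.dtype)  # expected: torch.int8
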
@@ -105,7 +120,8 @@ def get_linear(weight, bias, quantize):
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
         linear = Linear8bitLt(
-            weight, bias,
+            weight,
+            bias,
             has_fp16_weights=False,
             threshold=6.0,
         )
@@ -114,7 +130,9 @@ def get_linear(weight, bias, quantize):
     elif quantize == "gptq":
         raise NotImplementedError("Soon")
     else:
-        raise NotImplementedError(f"Quantization `{config.quantize}` is not implemented yet.")
+        raise NotImplementedError(
+            f"Quantization `{config.quantize}` is not implemented yet."
+        )
     return linear
 
 
@@ -126,6 +144,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
        return self.linear.forward(x)
 
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group):
         super().__init__(linear)
@@ -134,12 +153,17 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-        return TensorParallelHead(get_linear(weight, bias=None, quantize=config.quantize), process_group = weights.process_group)
+        return TensorParallelHead(
+            get_linear(weight, bias=None, quantize=config.quantize),
+            process_group=weights.process_group,
+        )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = super().forward(input)
         # Logits are sharded, so we need to gather them
-        world_output = [torch.empty_like(output) for _ in range(self.process_group.size())]
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
         torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output
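
`TensorParallelHead` shards the lm-head over the vocabulary dimension, so each rank computes a slice of the logits and the `all_gather` plus `cat(dim=-1)` above reassembles the full vocabulary. A single-process sketch of that reassembly with two simulated ranks and toy sizes:

    import torch

    hidden_dim, vocab, world_size = 4, 8, 2
    full_weight = torch.randn(vocab, hidden_dim)
    hidden = torch.randn(2, hidden_dim)

    # Each "rank" holds a dim-0 shard of the head weight and produces partial logits.
    world_output = [hidden @ shard.T for shard in full_weight.chunk(world_size, dim=0)]

    logits = torch.cat(world_output, dim=-1)  # what all_gather + cat does across real ranks
    assert torch.allclose(logits, hidden @ full_weight.T, atol=1e-5)
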
@@ -181,13 +205,17 @@ class TensorParallelRowLinear(SuperLayer):
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
+        return TensorParallelRowLinear(
+            get_linear(weight, bias, config.quantize),
+            process_group=weights.process_group,
+        )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
+
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
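
`TensorParallelRowLinear` splits the input features instead, so every rank holds a partial matmul result and the `all_reduce` above sums them. The same arithmetic simulated in one process:

    import torch

    in_features, out_features, world_size = 8, 4, 2
    full_weight = torch.randn(out_features, in_features)
    x = torch.randn(3, in_features)

    w_shards = full_weight.chunk(world_size, dim=1)  # each rank owns a slice of the input features
    x_shards = x.chunk(world_size, dim=1)
    partials = [xs @ ws.T for xs, ws in zip(x_shards, w_shards)]

    out = sum(partials)  # torch.distributed.all_reduce performs this sum across ranks
    assert torch.allclose(out, x @ full_weight.T, atol=1e-5)
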
@@ -222,6 +250,7 @@ class TensorParallelEmbedding(nn.Module):
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
+
 try:
     import dropout_layer_norm
 
@@ -1,7 +1,8 @@
 from pathlib import Path
-from typing import Optional, List
+from typing import List
 from safetensors import safe_open
 
+
 class Weights:
     def __init__(self, filenames: List[Path], device, dtype, process_group):
         routing = {}
@@ -26,8 +27,6 @@ class Weights:
 
         return self._handles[filename]
 
-
-
     def get_filename(self, tensor_name: str) -> str:
         filename = self.routing.get(tensor_name, None)
         if filename is None:
@@ -63,7 +62,9 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
 
-        assert size % world_size == 0, f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
 
         if dim == 0:
             tensor = slice_[start:stop]
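
The divisibility assert above guards the block arithmetic `Weights.get_sharded` uses to slice a tensor across ranks. A toy illustration of the start/stop computation for the dim-0 case:

    import torch

    size, world_size = 12, 4
    tensor = torch.arange(size)

    assert size % world_size == 0
    block_size = size // world_size

    shards = []
    for rank in range(world_size):
        start = rank * block_size
        stop = (rank + 1) * block_size
        shards.append(tensor[start:stop])  # each rank reads only its own block

    assert torch.equal(torch.cat(shards), tensor)
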
@@ -74,5 +75,3 @@ class Weights:
         tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
-
-