Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)
Exclude non-MLP layers when using FP8 quantization with Llama
parent a93b2b5083
commit cf16172a85
@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
         chunk = [c.replace("data:", "") for c in chunk]
         # remove empty strings
         chunk = [c for c in chunk if c]
+        # remove completion marking chunk
+        chunk = [c for c in chunk if c != " [DONE]"]
         # parse json
         chunk = [json.loads(c) for c in chunk]
 
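The new filter above drops the stream-termination chunk before JSON parsing. A minimal sketch of why this is needed (the payload string is made up for illustration; OpenAI-compatible streaming endpoints terminate the event stream with a `data: [DONE]` sentinel that is not valid JSON):

import json

# Hypothetical SSE payload of the shape the test parses.
raw = 'data: {"choices": [{"text": "Hello"}]}\n\ndata: [DONE]\n\n'

chunk = raw.split("\n\n")
chunk = [c.replace("data:", "") for c in chunk]
chunk = [c for c in chunk if c]
# Without this filter, json.loads(" [DONE]") raises json.JSONDecodeError.
chunk = [c for c in chunk if c != " [DONE]"]
chunk = [json.loads(c) for c in chunk]
print(chunk)  # [{'choices': [{'text': 'Hello'}]}]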
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
 import torch
@@ -25,7 +26,6 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -42,10 +42,16 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
 
 if SYSTEM == "rocm":
     try:
@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
     )
 
 
+@contextmanager
+def no_fp8(weights: Weights):
+    weights_loader = weights.weights_loader
+    if (
+        isinstance(weights_loader, DefaultWeightsLoader)
+        and weights_loader.weight_class is Fp8Weight
+    ):
+        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+
+    with weights.use_loader(weights_loader):
+        yield
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
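For readers unfamiliar with the weights-loader machinery, here is a self-contained sketch of the pattern `no_fp8` relies on. The stub classes below only mimic the interfaces the diff uses (`weights_loader`, `weight_class`, `use_loader`); the real implementations live in `text_generation_server.utils.weights` and `text_generation_server.layers.fp8`, and the body of `use_loader` here is an assumption about its save-and-restore behavior.

from contextlib import contextmanager


# Stub stand-ins for the real TGI classes; only the attributes used by
# no_fp8 are modeled here.
class UnquantizedWeight: ...


class Fp8Weight: ...


class DefaultWeightsLoader:
    def __init__(self, weight_class):
        self.weight_class = weight_class


class Weights:
    def __init__(self, weights_loader):
        self.weights_loader = weights_loader

    @contextmanager
    def use_loader(self, weights_loader):
        # Assumed semantics: swap the active loader for the duration of the
        # block, then restore the previous one.
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = old_loader


@contextmanager
def no_fp8(weights: Weights):
    # Same logic as in the diff: if the current loader would produce FP8
    # weights, load unquantized weights inside the block instead.
    weights_loader = weights.weights_loader
    if (
        isinstance(weights_loader, DefaultWeightsLoader)
        and weights_loader.weight_class is Fp8Weight
    ):
        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
    with weights.use_loader(weights_loader):
        yield


weights = Weights(DefaultWeightsLoader(Fp8Weight))
with no_fp8(weights):
    # Attention, embeddings, and lm_head are built here: unquantized.
    assert weights.weights_loader.weight_class is UnquantizedWeight
# MLP weights are loaded outside the block: still FP8.
assert weights.weights_loader.weight_class is Fp8Weight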
@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-        self.self_attn = FlashLlamaAttention(
-            index=index,
-            prefix=f"{prefix}.self_attn",
-            config=config,
-            weights=weights,
-        )
+
+        with no_fp8(weights):
+            self.self_attn = FlashLlamaAttention(
+                index=index,
+                prefix=f"{prefix}.self_attn",
+                config=config,
+                weights=weights,
+            )
+
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
         )
@@ -470,23 +492,27 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
 
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
-            ),
-            weights=weights,
-        )
+        with no_fp8(weights):
+            self.embed_tokens = TensorParallelEmbedding(
+                prefix=(
+                    "model.embed_tokens"
+                    if not prefix
+                    else f"{prefix}.model.embed_tokens"
+                ),
+                weights=weights,
+            )
         self.model = FlashLlamaModel(prefix, config, weights)
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
             suffix = "lm_head"
 
-        self.lm_head = SpeculativeHead.load(
-            config,
-            prefix=suffix if not prefix else f"{prefix}.{suffix}",
-            weights=weights,
-        )
+        with no_fp8(weights):
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix=suffix if not prefix else f"{prefix}.{suffix}",
+                weights=weights,
+            )
 
     def forward(
         self,
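Taken together, the hunks above wrap every non-MLP module of the Llama model in `no_fp8`. A short sketch of the intended outcome when the model is loaded with FP8 quantization enabled (module names follow the prefixes in the diff; the `model.layers.{i}` form is the standard Llama layout, and the loader labels are descriptive, not generated output):

# Which loader each module is constructed under after this change,
# assuming FP8 quantization is requested for the whole model.
i = 0  # any decoder-layer index
loader_per_module = {
    "model.embed_tokens": "unquantized",           # wrapped in no_fp8
    f"model.layers.{i}.self_attn": "unquantized",  # wrapped in no_fp8
    f"model.layers.{i}.mlp": "fp8",                # left on the FP8 loader
    "lm_head": "unquantized",                      # wrapped in no_fp8
}

for module, loader in loader_per_module.items():
    print(f"{module}: {loader}")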