diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
index 5ac90351..d6a8b7ae 100644
--- a/docs/source/supported_models.md
+++ b/docs/source/supported_models.md
@@ -4,6 +4,7 @@ Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
 
 - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
+- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 9bfdae5e..ca3e8392 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, Type, Union, List
 import torch
 from loguru import logger
 
-from moe_kernels import w8a8_block_fp8_matmul, per_token_group_quant_fp8
+from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import (
     Weight,
@@ -287,7 +287,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
                     for p in prefixes
                 ]
-                scale = torch.cat(scale, dim=dim).to(weights.device)
+                scale = torch.cat(scale, dim=dim)
+                if scale.device == torch.device("cpu"):
+                    scale = scale.to(weights.device)
                 return Fp8Weight(
                     weight=w,
                     weight_scale=scale,
@@ -489,7 +491,6 @@ class Fp8Linear(torch.nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.weight_block_size is not None:
             qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
-            # logger.info(f"qinput: {qinput.shape} {scale.shape} {self.qweight.shape} {self.scale.shape} {self.weight_block_size}")
             output = w8a8_block_fp8_matmul(
                 qinput,
                 self.qweight,
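
A minimal sketch (not part of the patch) of the device-guard pattern the `HybridFP8UnquantLoader` hunk introduces: scale shards fetched with `to_device=False` may arrive either on CPU or already on the accelerator, so the concatenated scale is only copied to `weights.device` when it is still on CPU. The `DummyWeights` holder and the tensor shapes below are illustrative assumptions, not TGI APIs.

```python
# Illustrative sketch of the device guard added in HybridFP8UnquantLoader.
# "DummyWeights" and the shard shapes are made up for demonstration only.
import torch


class DummyWeights:
    # Hypothetical stand-in for the `weights` object: only exposes `.device`.
    def __init__(self, device: torch.device):
        self.device = device


def cat_scales(scales: list, dim: int, weights: DummyWeights) -> torch.Tensor:
    # Concatenate per-shard block scales, then move them to the target
    # device only if they are still on CPU (mirrors the patched branch).
    scale = torch.cat(scales, dim=dim)
    if scale.device == torch.device("cpu"):
        scale = scale.to(weights.device)
    return scale


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    weights = DummyWeights(device)
    shards = [torch.ones(4, 2), torch.ones(4, 2)]  # fake weight_scale_inv shards
    print(cat_scales(shards, dim=0, weights=weights).device)
```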