mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Update vLLM dependency to 0.5.3.post1
This commit is contained in:
parent
5d85a958c9
commit
b890c8c47d
@ -1,9 +1,9 @@
|
|||||||
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
|
commit_cuda := 1ce677148e4d0252c7f668b87905bcec210207a6
|
||||||
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
|
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
|
||||||
build-vllm-cuda:
|
build-vllm-cuda:
|
||||||
if [ ! -d 'vllm' ]; then \
|
if [ ! -d 'vllm' ]; then \
|
||||||
pip install -U ninja packaging --no-cache-dir && \
|
pip install -U ninja packaging --no-cache-dir && \
|
||||||
git clone https://github.com/Narsil/vllm.git vllm; \
|
git clone https://github.com/danieldk/vllm.git vllm; \
|
||||||
fi
|
fi
|
||||||
cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
|
cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
|
||||||
|
|
||||||
|
@ -9,8 +9,7 @@ is_sm75 = major == 7 and minor == 5
|
|||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm._C import cache_ops
|
import vllm._custom_ops as ops
|
||||||
from vllm._C import ops
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
@ -29,8 +28,8 @@ def reshape_and_cache(
|
|||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||||
else:
|
else:
|
||||||
cache_ops.reshape_and_cache(
|
ops.reshape_and_cache(
|
||||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -114,7 +113,7 @@ def paged_attention(
|
|||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths
|
||||||
from vllm._C import ops
|
import vllm._custom_ops as ops
|
||||||
|
|
||||||
use_v1 = max_s <= 8192 and (
|
use_v1 = max_s <= 8192 and (
|
||||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||||
@ -134,6 +133,7 @@ def paged_attention(
|
|||||||
None,
|
None,
|
||||||
"auto",
|
"auto",
|
||||||
1.0,
|
1.0,
|
||||||
|
1.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Run PagedAttention V2.
|
# Run PagedAttention V2.
|
||||||
@ -167,6 +167,7 @@ def paged_attention(
|
|||||||
None,
|
None,
|
||||||
"auto",
|
"auto",
|
||||||
1.0,
|
1.0,
|
||||||
|
1.0,
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
@ -38,6 +38,7 @@ from text_generation_server.utils.weights import Weights
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts, grouped_topk
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Config(PretrainedConfig):
|
class DeepseekV2Config(PretrainedConfig):
|
||||||
@ -798,183 +799,3 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
|
|||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
|
||||||
# Functions below are from vLLM:
|
|
||||||
#
|
|
||||||
# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
|
|
||||||
#
|
|
||||||
# Remove after we have synced our version with upstream.
|
|
||||||
|
|
||||||
|
|
||||||
def grouped_topk(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
num_expert_group: int = 0,
|
|
||||||
topk_group: int = 0,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
scores = torch.softmax(gating_output, dim=-1)
|
|
||||||
num_token = scores.shape[0]
|
|
||||||
group_scores = (
|
|
||||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
|
||||||
) # [n, n_group]
|
|
||||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
|
||||||
1
|
|
||||||
] # [n, top_k_group]
|
|
||||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
|
||||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
|
||||||
score_mask = (
|
|
||||||
group_mask.unsqueeze(-1)
|
|
||||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
|
||||||
.reshape(num_token, -1)
|
|
||||||
) # [n, e]
|
|
||||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
|
||||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
|
||||||
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_config(
|
|
||||||
M: int,
|
|
||||||
E: int,
|
|
||||||
N: int,
|
|
||||||
K: int,
|
|
||||||
topk: int,
|
|
||||||
dtype: Optional[str],
|
|
||||||
) -> Dict[str, int]:
|
|
||||||
config = {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
}
|
|
||||||
if M <= E:
|
|
||||||
config = {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 32,
|
|
||||||
"BLOCK_SIZE_K": 64,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def fused_experts(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
inplace: bool = False,
|
|
||||||
override_config: Optional[Dict[str, Any]] = None,
|
|
||||||
use_fp8: bool = False,
|
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
# Check constraints.
|
|
||||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
|
||||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
|
||||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
|
||||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
|
||||||
|
|
||||||
import triton.language as tl
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
||||||
get_moe_configs,
|
|
||||||
invoke_fused_moe_kernel,
|
|
||||||
moe_align_block_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
M, _ = hidden_states.shape
|
|
||||||
E, N, _ = w1.shape
|
|
||||||
|
|
||||||
if override_config:
|
|
||||||
config = override_config
|
|
||||||
else:
|
|
||||||
# First try to load optimal config from the file
|
|
||||||
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
|
||||||
|
|
||||||
if configs:
|
|
||||||
# If an optimal configuration map has been found, look up the
|
|
||||||
# optimal config
|
|
||||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
|
||||||
else:
|
|
||||||
# Else use the default config
|
|
||||||
config = get_default_config(
|
|
||||||
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
|
|
||||||
)
|
|
||||||
|
|
||||||
intermediate_cache1 = torch.empty(
|
|
||||||
(M, topk_ids.shape[1], N),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
intermediate_cache2 = torch.empty(
|
|
||||||
(M * topk_ids.shape[1], N // 2),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
intermediate_cache3 = torch.empty(
|
|
||||||
(M, topk_ids.shape[1], w2.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
|
||||||
topk_ids, config["BLOCK_SIZE_M"], E
|
|
||||||
)
|
|
||||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
|
||||||
hidden_states,
|
|
||||||
w1,
|
|
||||||
intermediate_cache1,
|
|
||||||
a1_scale,
|
|
||||||
w1_scale,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_padded,
|
|
||||||
False,
|
|
||||||
topk_ids.shape[1],
|
|
||||||
config,
|
|
||||||
compute_type=compute_type,
|
|
||||||
use_fp8=use_fp8,
|
|
||||||
)
|
|
||||||
|
|
||||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
|
||||||
intermediate_cache2,
|
|
||||||
w2,
|
|
||||||
intermediate_cache3,
|
|
||||||
a2_scale,
|
|
||||||
w2_scale,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_padded,
|
|
||||||
True,
|
|
||||||
1,
|
|
||||||
config,
|
|
||||||
compute_type=compute_type,
|
|
||||||
use_fp8=use_fp8,
|
|
||||||
)
|
|
||||||
|
|
||||||
if inplace:
|
|
||||||
return torch.sum(
|
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
|
||||||
dim=1,
|
|
||||||
out=hidden_states,
|
|
||||||
)
|
|
||||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user