Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-08 19:04:52 +00:00)

[gaudi] Deepseek v2 mla and add ep to unquantized moe (#3287)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 778b61c0da
commit ebb26f0ccd
@@ -118,9 +118,9 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base
 
-ENV HF_HUB_ENABLE_HF_TRANSFER 1
-ENV HABANA_VISIBLE_DEVICES all
-ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE
 
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
@@ -51,10 +51,12 @@ class FP8SparseMoELayer(nn.Module):
         self.rank = weights.process_group.rank()
         self.ep_rank = self.rank
         self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
         if self.use_ep:
-            n_experts = (n_experts + self.world_size - 1) // self.world_size
-            self.ep_offset = self.ep_rank * n_experts
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.ep_rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
         else:
             self.ep_offset = 0
 
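The expert-parallel (EP) split above is ceil-division with a clamp on the last rank, and EP is turned off whenever it would leave fewer than four experts per rank. A standalone sketch of that arithmetic with illustrative numbers (not the TGI code itself):

# Standalone sketch of the EP split used above (illustrative numbers, assumed helper).
def ep_partition(n_experts: int, world_size: int, rank: int):
    n_experts_per_rank = (n_experts + world_size - 1) // world_size  # ceil-division
    if n_experts_per_rank < 4:
        # Too few experts per rank: EP is disabled and every rank keeps all experts.
        return 0, n_experts
    ep_offset = rank * n_experts_per_rank
    # The last rank may own fewer experts when n_experts is not divisible by world_size.
    local_experts = min(n_experts_per_rank, n_experts - ep_offset)
    return ep_offset, local_experts

# 160 experts over 6 ranks: ranks 0-4 own 27 experts each, rank 5 owns the remaining 25.
for rank in range(6):
    print(rank, ep_partition(160, 6, rank))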
@@ -7,6 +7,7 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
 import habana_frameworks.torch as htorch
 import torch.nn.functional as F
+import os
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -39,6 +40,21 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.rank = weights.process_group.rank()
+        self.world_size = weights.process_group.size()
+        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
+        if self.use_ep:
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
+            experts_min = self.ep_offset
+            experts_max = self.ep_offset + n_experts - 1
+        else:
+            self.ep_offset = 0
+            experts_min = 0
+            experts_max = n_experts - 1
 
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -46,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module):
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )
 
         self.down_proj = _load_expert_weights_row(
@@ -53,9 +71,11 @@ class UnquantizedSparseMoELayer(nn.Module):
             n_experts=n_experts,
             name=down_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )
 
-        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
+        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)
         for i in range(n_experts):
             self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
             self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
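With EP enabled, VllmMixtureOfExpertsOp now receives the contiguous global expert-id range owned by this rank (experts_min, experts_max) instead of the previous 0 .. n_experts - 1. A quick standalone check, assuming the same ceil-division split, that the clamped per-rank ranges tile the expert set exactly once:

# Sanity sketch: the per-rank [experts_min, experts_max] ranges cover every expert once.
def rank_range(n_experts: int, world_size: int, rank: int) -> range:
    per_rank = (n_experts + world_size - 1) // world_size
    offset = rank * per_rank
    local = min(per_rank, n_experts - offset)
    return range(offset, offset + local)  # experts_min .. experts_max inclusive

n_experts, world_size = 61, 8  # deliberately non-divisible
covered = [e for r in range(world_size) for e in rank_range(n_experts, world_size, r)]
assert covered == list(range(n_experts))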
@@ -87,12 +107,23 @@ def _load_expert_multi_weights_col(
     gate_proj_name: str,
     up_proj_name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_multi_weights_col(
-            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
-        )
+        if not use_ep:
+            weight = weights.get_multi_weights_col(
+                [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+            )
+        else:
+            weight = weights.get_multi_weights(
+                [
+                    f"{prefix}.{i+ep_offset}.{gate_proj_name}",
+                    f"{prefix}.{i+ep_offset}.{up_proj_name}",
+                ],
+                0,
+            )
 
         assert isinstance(weight, UnquantizedWeight)
 
@@ -116,12 +147,19 @@ def _load_expert_weights_row(
     n_experts: int,
     name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_weights_row(
-            f"{prefix}.{i}.{name}",
-        )
+        if not use_ep:
+            weight = weights.get_weights_row(
+                f"{prefix}.{i}.{name}",
+            )
+        else:
+            weight = weights.get_weights(
+                f"{prefix}.{i+ep_offset}.{name}",
+            )
 
         assert isinstance(weight, UnquantizedWeight)
 
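Under EP the two loaders iterate over the rank-local expert count and address the checkpoint by the global expert id i + ep_offset, using get_multi_weights / get_weights, which presumably return the whole expert tensor rather than a TP-sharded slice, since each expert now lives on a single rank. A sketch of the index remapping only; the prefix below is illustrative, not a real checkpoint path:

# Sketch: rank-local expert index -> global checkpoint prefixes under expert parallelism.
def expert_prefixes(prefix: str, i: int, ep_offset: int, use_ep: bool,
                    gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj"):
    idx = i + ep_offset if use_ep else i
    return [f"{prefix}.{idx}.{gate_proj_name}", f"{prefix}.{idx}.{up_proj_name}"]

# Rank 2 with 16 experts per rank (ep_offset=32): local expert 0 is global expert 32.
print(expert_prefixes("model.layers.3.mlp.experts", 0, ep_offset=32, use_ep=True))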
@@ -28,11 +28,12 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     set_block_mapping,
     HPUPagedAttentionMetadata,
 )
@@ -44,6 +45,18 @@ from text_generation_server.utils.weights import Weights
 import habana_frameworks.torch as htorch
 
 
+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
         self,
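get_and_maybe_dequant_weights recovers a dequantized weight from an Fp8Linear by feeding it an identity matrix: a bias-free linear computes x @ W.T, so layer(I) returns W.T and the final transpose restores the (out_features, in_features) layout. A minimal sketch of the trick with a plain nn.Linear standing in for the quantized layer:

import torch
import torch.nn as nn

# Identity-matrix trick: a bias-free linear computes y = x @ W.T, so layer(I).T == W.
layer = nn.Linear(in_features=8, out_features=4, bias=False)
with torch.no_grad():
    eye = torch.eye(layer.in_features, dtype=layer.weight.dtype)
    recovered = layer(eye).T  # back to (out_features, in_features)
assert torch.allclose(recovered, layer.weight)

# With a real Fp8Linear, layer(eye) runs the dequantizing matmul internally, so the
# same call yields the dequantized weights without reaching into its storage format.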
@@ -246,6 +259,45 @@ class DeepseekV2Attention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
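The precomputed W_UK_T and W_UV let decode attend directly in the kv_lora_rank latent space: since k_nope = kv_c @ W_UK per head, the score q_nope . k_nope equals (q_nope @ W_UK.T) . kv_c, so the key up-projection is absorbed into the query and W_UV is applied once to the attention output afterwards. A small numerical check of that identity on random tensors (shapes are illustrative):

import torch

# Check the MLA "weight absorption" identity on random tensors (illustrative shapes).
B, N, L, P = 3, 4, 16, 8          # batch, heads, kv_lora_rank, qk_nope_head_dim
q_nope = torch.randn(B, N, P)
kv_c = torch.randn(B, L)          # compressed latent per token
W_UK = torch.randn(L, N, P)       # per-head latent -> key_nope up-projection

# Direct path: up-project the latent to per-head keys, then dot with the query.
k_nope = torch.einsum("bl,lnp->bnp", kv_c, W_UK)
scores_direct = (q_nope * k_nope).sum(-1)                   # (B, N)

# Absorbed path: fold W_UK into the query and dot with the latent itself.
W_UK_T = W_UK.permute(1, 2, 0)                              # (N, P, L)
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T)         # (N, B, L)
scores_absorbed = torch.einsum("nbl,bl->bn", ql_nope, kv_c) # (B, N)

assert torch.allclose(scores_direct, scores_absorbed, atol=1e-5)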
@@ -258,14 +310,9 @@ class DeepseekV2Attention(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-        query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
 
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
@@ -273,13 +320,18 @@ class DeepseekV2Attention(torch.nn.Module):
         )
 
         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
 
-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
 
         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -294,33 +346,47 @@ class DeepseekV2Attention(torch.nn.Module):
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)
+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
+        )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
 
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
-        )
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
 
         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )
 
-        # Prefill
         if cu_seqlen_prefill is not None:
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
+
             # flash attention
             attn_output = attention(
                 query=query,
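With MLA the cache keeps one latent vector per token, kv_c_normed (kv_lora_rank values) concatenated with the rotary key part (qk_rope_head_dim values), and no value tensor at all (value=None); full keys and values are only rebuilt on the prefill path. A rough per-token footprint comparison, using DeepSeek-V2-style sizes as assumptions rather than values read from the config:

# Per-token KV-cache footprint: MLA latent cache vs. conventional per-head K/V.
# The dimensions below are illustrative assumptions, not read from the model config.
kv_lora_rank = 512
qk_rope_head_dim = 64
num_heads = 128
qk_head_dim = 192          # qk_nope_head_dim (128) + qk_rope_head_dim (64)
value_head_size = 128

latent_per_token = kv_lora_rank + qk_rope_head_dim           # one vector, no heads
full_kv_per_token = num_heads * (qk_head_dim + value_head_size)

print(latent_per_token, full_kv_per_token, full_kv_per_token / latent_per_token)
# 576 vs. 40960 elements per token, roughly a 71x reduction.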
@@ -331,9 +397,15 @@ class DeepseekV2Attention(torch.nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
@@ -341,14 +413,10 @@ class DeepseekV2Attention(torch.nn.Module):
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+            return attn_output
 
 
 class DeepseekV2MLP(nn.Module):
@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ class Qwen3MoeModel(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
+
         hidden_states = inputs_embeds
 
@@ -1606,7 +1606,7 @@ class FlashCausalLM(Model):
     ):
         self.kv_cache = []
         empty_cache()
-        if self.config.model_type == "deepseek_v3":
+        if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
             self.kv_cache = [
                 KVCompressCache(
                     num_blocks=num_blocks,
@@ -1646,7 +1646,7 @@ class FlashCausalLM(Model):
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        if self.config.model_type == "deepseek_v3":
+        if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
             cache_block_size = BLOCK_SIZE * (
                 self.config.kv_lora_rank + self.config.qk_rope_head_dim
             )
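For deepseek_v2 and deepseek_v3 a cache block therefore holds BLOCK_SIZE tokens of the compressed latent, so its byte size is BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim) * dtype_size and the block budget is free memory divided by that. A back-of-the-envelope sketch with assumed numbers:

# Rough block-count estimate for the compressed (MLA) KV cache; numbers are assumptions.
import torch

BLOCK_SIZE = 128                       # tokens per cache block (assumed)
kv_lora_rank, qk_rope_head_dim = 512, 64
dtype_size = torch.tensor([], dtype=torch.bfloat16).element_size()   # 2 bytes

cache_block_size = BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim)    # elements per block
bytes_per_block = cache_block_size * dtype_size                      # 147,456 bytes

free_memory = 20 * 1024**3             # e.g. 20 GiB left for the cache (assumed)
num_blocks = free_memory // bytes_per_block
print(bytes_per_block, num_blocks)     # ~144 KiB per block -> ~145k blocks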
@@ -1,50 +0,0 @@
-import os
-from pathlib import Path
-from loguru import logger
-from text_generation_server import server
-import argparse
-from text_generation_server.utils.adapter import parse_lora_adapters
-
-
-def main(args):
-    logger.info("TGIService: starting tgi service .... ")
-    logger.info(
-        "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
-            args.model_id,
-            args.revision,
-            args.sharded,
-            args.speculate,
-            args.dtype,
-            args.trust_remote_code,
-            args.uds_path,
-        )
-    )
-    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
-    server.serve(
-        model_id=args.model_id,
-        lora_adapters=lora_adapters,
-        revision=args.revision,
-        sharded=args.sharded,
-        quantize=args.quantize,
-        speculate=args.speculate,
-        dtype=args.dtype,
-        trust_remote_code=args.trust_remote_code,
-        uds_path=args.uds_path,
-        max_input_tokens=args.max_input_tokens,
-        kv_cache_dtype="auto",
-    )
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_id", type=str)
-    parser.add_argument("--revision", type=str)
-    parser.add_argument("--sharded", type=bool)
-    parser.add_argument("--speculate", type=int, default=None)
-    parser.add_argument("--dtype", type=str)
-    parser.add_argument("--trust_remote_code", type=bool)
-    parser.add_argument("--uds_path", type=Path)
-    parser.add_argument("--quantize", type=str)
-    parser.add_argument("--max_input_tokens", type=int)
-    args = parser.parse_args()
-    main(args)
@@ -184,6 +184,7 @@ Text Generation Inference enables serving optimized models on Gaudi hardware. Th
 
 **Large Language Models (LLMs)**
 - [deepseek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)
+- [deepseek-v2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Llama2](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b)
 - [Llama3](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
 - [CodeLlama](https://huggingface.co/codellama/CodeLlama-13b-hf)