Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
remove a tad of cpu bottleneck

commit e69eed8ea3 (parent af1989459c)
@@ -188,10 +188,7 @@ def download_weights(
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
-        try:
-            local_pt_files = utils.weight_files(model_id, revision, extension=".bin")
-        except (FileNotFoundError, utils.EntryNotFoundError):
-            local_pt_files = utils.weight_files(model_id, revision, extension=".pt")
+        local_pt_files = utils.weight_files(model_id, revision, ".bin")

     # No local pytorch weights
     except utils.LocalEntryNotFoundError:
@@ -202,10 +199,7 @@ def download_weights(
         )

         # Try to see if there are pytorch weights on the hub
-        try:
-            pt_filenames = utils.weight_hub_files(model_id, revision, extension=".bin")
-        except utils.EntryNotFoundError:
-            pt_filenames = utils.weight_hub_files(model_id, revision, extension=".pt")
+        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
         # Download pytorch weights
         local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

@@ -282,7 +282,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == "mistral":
+    if model_type in ["mistral", "mixtral"]:
        if MISTRAL:
            return FlashMistral(
                model_id,
@@ -292,7 +292,7 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
-        raise NotImplementedError("Mistral model requires flash attention v2")
+        raise NotImplementedError("Mistral models requires flash attention v2")

    if model_type == "opt":
        return OPTSharded(
@@ -86,7 +86,7 @@ class MixtralConfig(PretrainedConfig):
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.num_hidden_layers = 4
+        self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window

@@ -123,7 +123,7 @@ def load_attention(config, prefix, weights):
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
-            prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
@@ -135,7 +135,7 @@ def _load_gqa(config, prefix: str, weights):
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )
@@ -156,34 +156,35 @@ def _load_gqa(config, prefix: str, weights):
    )


-def _load_experts(config, prefix, weights):
+def _load_experts(config, prefix, mat, weights):
    if config.quantize is not None:
        raise NotImplementedError("Mixtral does not support weight quantization yet.")

-    slice_ = weights._get_slice(prefix)
+    assert mat in ["w1", "w2", "w3"]

    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

-    if world_size == 1:
-        tensor = slice_[:].to(dtype=weights.dtype).to(device=weights.device)
-    else:
-        assert (
-            config.intermediate_size % world_size == 0
-        ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-        assert slice_.get_shape()[0] == config.num_local_experts * config.intermediate_size
-
-        block_size = config.intermediate_size // world_size
-        start = rank * block_size
-        stop = (rank + 1) * block_size
-
-        expert_slices = []
-        for i in range(config.num_local_experts):
-            expert_start = i * config.intermediate_size
-
-            expert_slices.append(slice_[start + expert_start:stop + expert_start])
-
-        tensor = torch.cat(expert_slices, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+    assert (
+        config.intermediate_size % world_size == 0
+    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
+
+    block_size = config.intermediate_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
+                         dtype=weights.dtype,
+                         device=weights.device)
+
+    for i in range(config.num_local_experts):
+        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
+
+        if mat == "w2":
+            expert_slice = slice_[:, start:stop].t().contiguous()
+        else:
+            expert_slice = slice_[start:stop]
+        tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
    return tensor


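Side note (not part of the commit): the rewritten loader above preallocates one merged expert tensor and copies each expert's shard directly into its block, instead of collecting slices in a Python list and concatenating them. A minimal standalone sketch of that pattern, with made-up sizes:

import torch

n_experts, ffn_dim, hidden = 8, 16, 32
experts = [torch.randn(ffn_dim, hidden) for _ in range(n_experts)]

# cat-based merge: builds a temporary list, then pays one extra full copy
merged_cat = torch.cat(experts, dim=0)

# preallocate-and-fill merge: a single allocation, each shard written in place
merged = torch.empty((n_experts * ffn_dim, hidden))
for i, e in enumerate(experts):
    merged[i * ffn_dim:(i + 1) * ffn_dim] = e

assert torch.equal(merged, merged_cat)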
@@ -223,9 +224,9 @@ class MixtralAttention(torch.nn.Module):

        self.query_key_value = load_attention(config, prefix, weights)

-        self.wo = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
            config,
-            prefix=f"{prefix}.wo",
+            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
@@ -299,7 +300,25 @@ class MixtralAttention(torch.nn.Module):
            max_s,
        )

-        return self.wo(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


+@torch.jit.script
+def select_experts(gate_logits: torch.Tensor, top_k: int):
+    # all_probs: (sequence_length, n_experts) and upcast for softmax
+    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+    # weights, selected_experts: (sequence_length, top-k)
+    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
+    weights /= weights.sum(dim=-1, keepdim=True)
+    weights = weights.view(-1)
+    selected_experts = selected_experts.view(-1)
+
+    return selected_experts, weights
+
+
+@torch.jit.script
+def round_up(x: torch.Tensor, value: int):
+    return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
+
+
 class BlockSparseMoE(nn.Module):
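Side note (not part of the commit): moving the router math into a torch.jit.script helper means the whole softmax -> top-k -> renormalize sequence runs as one scripted graph instead of a chain of eager ops dispatched from Python on every forward. A usage sketch of the helper declared above, called on random gate logits:

import torch

@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
    # upcast for a numerically stable softmax over the expert dimension
    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
    weights /= weights.sum(dim=-1, keepdim=True)
    return selected_experts.view(-1), weights.view(-1)

gate_logits = torch.randn(4, 8)        # (sequence_length, n_experts)
experts, weights = select_experts(gate_logits, 2)
print(experts.shape, weights.shape)    # torch.Size([8]) torch.Size([8])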
@@ -339,9 +358,12 @@ class BlockSparseMoE(nn.Module):
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.w3", weights)
+        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
+        self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
+        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()

+        self.offsets = None
+        self.offsets_block_rows = 0
+
        self.process_group = weights.process_group

@@ -361,13 +383,18 @@ class BlockSparseMoE(nn.Module):
        # dimensionality of a single expert.
        block_rows = padded_tokens // self.blocking
        blocks_per_row = self.ffn_dim // self.blocking
-        offsets = torch.arange(
-            0,
-            block_rows * blocks_per_row + 1,
-            blocks_per_row,
-            dtype=torch.int32,
-            device=x.device,
-        )
+        if self.offsets is None or block_rows > self.offsets_block_rows:
+            self.offsets = torch.arange(
+                0,
+                block_rows * blocks_per_row + 1,
+                blocks_per_row,
+                dtype=torch.int32,
+                device=x.device,
+            )
+            self.offsets_block_rows = block_rows
+            offsets = self.offsets
+        else:
+            offsets = self.offsets[:block_rows]

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
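Side note (not part of the commit): the offsets tensor only depends on block_rows and blocks_per_row, so rebuilding it with torch.arange on every forward pays a host-side launch per step. Caching the largest one seen so far and reusing a prefix of it avoids that. A minimal sketch of the same memoization idea, with the exact indexing simplified relative to the commit:

import torch

class OffsetCache:
    def __init__(self, blocks_per_row: int, device="cpu"):
        self.blocks_per_row = blocks_per_row
        self.device = device
        self.offsets = None
        self.offsets_block_rows = 0

    def get(self, block_rows: int) -> torch.Tensor:
        if self.offsets is None or block_rows > self.offsets_block_rows:
            # build once for the largest batch seen so far
            self.offsets = torch.arange(
                0,
                block_rows * self.blocks_per_row + 1,
                self.blocks_per_row,
                dtype=torch.int32,
                device=self.device,
            )
            self.offsets_block_rows = block_rows
            return self.offsets
        # smaller batches only need a prefix of the cached tensor
        return self.offsets[: block_rows + 1]

cache = OffsetCache(blocks_per_row=4)
print(cache.get(8).shape)   # torch.Size([9]) -- built once
print(cache.get(3).shape)   # torch.Size([4]) -- sliced, no new allocation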
@@ -375,7 +402,6 @@ class BlockSparseMoE(nn.Module):
        column_indices = ops.topology(padded_bins, self.blocking, block_rows,
                                      blocks_per_row)

-        # TODO(tgale): This is unused. Remove the need for this in stk.
        # For now, use meta init to save the device memory.
        data = torch.empty(
            column_indices.numel(),
@@ -400,7 +426,7 @@ class BlockSparseMoE(nn.Module):
    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
-        selected_experts = selected_experts.int()
+        # selected_experts = selected_experts.int()

        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
        # and indices == how to sort tokens?
@@ -418,7 +444,7 @@ class BlockSparseMoE(nn.Module):
        # position of each bin.

        # List of size num_experts
-        padded_tokens_per_expert = ops.round_up(tokens_per_expert,
+        padded_tokens_per_expert = round_up(tokens_per_expert,
                                                self.blocking)
        # padded_tokens_per_expert => [128, O, 128, ...]

@@ -446,13 +472,7 @@ class BlockSparseMoE(nn.Module):

        # gate_logits: (sequence_length, n_experts)
        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-        # weights, selected_experts: (sequence_length, top-k)
-        weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
-        weights /= weights.sum(dim=-1, keepdim=True)
-        weights = weights.flatten().to(x.dtype)
-        selected_experts = selected_experts.flatten()
+        selected_experts, weights = select_experts(gate_logits, self.top_k)

        (
            indices,
@@ -465,7 +485,7 @@ class BlockSparseMoE(nn.Module):
        # Permute tokens and pad to prepare expert computation
        # (top_k * sequence_length + padding, model_dim)
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
-                              self.top_k, x.shape[0] + self.num_experts * self.blocking)
+                              self.top_k)

        # Create the sparse matrix topology
        with torch.no_grad():
@@ -476,8 +496,8 @@ class BlockSparseMoE(nn.Module):
        # (top_k * sequence_length + padding, ffn_dim * n_experts)
        x = stk.Matrix(
            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data) *
-            stk.ops.sdd(x, self.w3.t(), topo).data,
+            self.act(stk.ops.sdd(x, self.w1, topo).data) *
+            stk.ops.sdd(x, self.w3, topo).data,
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
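Side note (not part of the commit): together with the .t() added at load time a few hunks above, this removes the per-step transposes of w1 and w3 from the forward pass; the merged weights are stored pre-transposed once instead of being re-viewed on every call. A sketch of hoisting a transpose out of the hot path, with made-up shapes and a dense matmul standing in for the block-sparse sdd kernel:

import torch

hidden, merged_ffn, tokens = 32, 128, 16
w1 = torch.randn(merged_ffn, hidden)   # as loaded: (n_experts * ffn_dim, hidden_dim)
x = torch.randn(tokens, hidden)

# before: transpose issued inside every forward call
y_before = x @ w1.t()

# after: transpose once at load time, reuse the pre-transposed tensor
w1_t = w1.t()                          # done once, outside the hot path
y_after = x @ w1_t

assert torch.allclose(y_before, y_after)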
@@ -512,18 +532,18 @@ class BlockSparseMoE(nn.Module):
 class MixtralLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
-        prefix = f"layers.{layer_id}"
+        prefix = f"model.layers.{layer_id}"

-        self.attention = MixtralAttention(
-            prefix=f"{prefix}.attention", config=config, weights=weights
+        self.self_attn = MixtralAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)

-        self.attention_norm = FastRMSNorm.load(
-            prefix=f"{prefix}.attention_norm", weights=weights, eps=config.rms_norm_eps
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
-        self.ffn_norm = FastRMSNorm.load(
-            prefix=f"{prefix}.ffn_norm",
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
@@ -542,10 +562,10 @@ class MixtralLayer(nn.Module):
        max_s,
        prefill_cache_indices,
    ):
-        normed_hidden_states, res = self.attention_norm(hidden_states, residual)
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
-        attn_output = self.attention(
+        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
@@ -559,21 +579,21 @@ class MixtralLayer(nn.Module):
        )

        # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.ffn_norm(
+        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

-        mlp_output = self.block_sparse_moe(normed_attn_res_output)
+        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)

-        return mlp_output, attn_res
+        return block_sparse_moe_output, attn_res


 class MixtralModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

-        self.tok_embeddings = TensorParallelEmbedding(
-            prefix="tok_embeddings", weights=weights
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
        )

        self.layers = nn.ModuleList(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = FastRMSNorm.load(
|
self.norm = FastRMSNorm.load(
|
||||||
prefix="norm", weights=weights, eps=config.rms_norm_eps
|
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
self.head_size = self.layers[0].attention.head_size
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
self.num_heads = self.layers[0].attention.num_heads
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
self.num_key_value_heads = self.layers[0].attention.num_key_value_heads
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -606,11 +626,11 @@ class MixtralModel(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
-        hidden_states = self.tok_embeddings(input_ids)
+        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
-        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

@@ -642,7 +662,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
        self.model = MixtralModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config,
-            prefix="output",
+            prefix="lm_head",
            weights=weights,
        )
        self.max_past = config.sliding_window
@@ -91,8 +91,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )

+        # from torch.profiler import profile, ProfilerActivity
+        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
        generations, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)
+        # if self.model.rank == 0:
+        #     prefill_prof.export_chrome_trace("prefill.json")

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
@@ -118,8 +122,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        else:
            batch = batches[0]

+        # from torch.profiler import profile, ProfilerActivity
+        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
        generations, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)
+        # if self.model.rank == 0:
+        #     prefill_prof.export_chrome_trace("decode.json")

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
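Side note (not part of the commit): the commented-out lines above hint at how the CPU bottleneck was measured. A standalone sketch of that profiling recipe, where `step` is a stand-in for model.generate_token(batch) and the exported trace can be opened in chrome://tracing or Perfetto:

import torch
from torch.profiler import profile, ProfilerActivity

def step():
    a = torch.randn(1024, 1024)
    return a @ a

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    step()

# write a chrome trace and print the top CPU-time ops
prof.export_chrome_trace("step_trace.json")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))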
@@ -98,10 +98,7 @@ def weight_files(
        if extension != ".safetensors":
            raise e
        # Try to see if there are pytorch weights
-        try:
-            pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
-        except EntryNotFoundError:
-            pt_filenames = weight_hub_files(model_id, revision, extension=".pt")
+        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
        # Change pytorch extension to safetensors extension
        # It is possible that we have safetensors weights locally even though they are not on the
        # hub if we converted weights locally without pushing them
@@ -23,7 +23,7 @@ class Weights:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
-                        logger.warning(
+                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
@@ -116,7 +116,7 @@ class Weights:
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
-        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str):