Mirror of https://github.com/huggingface/text-generation-inference.git

Commit 10dd0150c0 (parent b24bdb9f8c): Dummy fix for medusa.
@@ -874,21 +874,15 @@ class FlashCausalLM(Model):
         lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        graphs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if graphs:
+            cuda_graph = graphs[0]
+        else:
+            cuda_graph = None

         if (
             cu_seqlen_prefill is not None
             or cuda_graph is None
-            or batch.speculative_ids is not None
         ):
             return self.model.forward(
                 input_ids=input_ids,
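Note: the hunk above replaces the hardcoded batch-size buckets with a lookup over whatever CUDA-graph batch sizes were actually captured, taking the smallest captured size that fits the current batch and falling back to the eager forward when none does. A minimal standalone sketch of that selection (the capture sizes below are made up):

# Sketch of the bucket selection; the capture sizes here are hypothetical.
cuda_graphs = {1: object(), 2: object(), 4: object(), 8: object()}

def pick_bucket(bs: int):
    graphs = sorted(k for k in cuda_graphs if k >= bs)
    # Smallest captured batch size that can hold `bs`, else None -> eager path.
    return graphs[0] if graphs else None

assert pick_bucket(3) == 4     # a bs=3 batch would replay the bs=4 graph
assert pick_bucket(8) == 8
assert pick_bucket(9) is None  # no graph large enough; run the model eagerly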
@@ -980,6 +974,9 @@ class FlashCausalLM(Model):
                 speculative_logits,
             )

+        if os.getenv("RANK") == "0":
+            logger.info(f"Accepted ids {accepted_ids}")
+
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )
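Note: the added logging is gated on the RANK environment variable so only the first shard of a sharded deployment prints. A small sketch of the same pattern, with a stdlib logger standing in for whatever logger the module actually uses; os.getenv returns a string (or None), hence the comparison against "0" rather than 0:

import logging
import os

logger = logging.getLogger(__name__)  # stand-in; the real module's logger may differ

def log_on_rank_zero(msg: str) -> None:
    # RANK is set per process by the launcher; the guard also skips
    # logging entirely when RANK is unset.
    if os.getenv("RANK") == "0":
        logger.info(msg)

log_on_rank_zero("Accepted ids [3, 1, 2]")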
@@ -412,48 +412,61 @@ class SuperLayer(nn.Module):
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
-        self.linear = FastLinear.load(
+        self.linear = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.linear", weights=weights, bias=True
         )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
+
         self.act = torch.nn.SiLU()

     def forward(self, x):
-        return x + self.act(self.linear(x))
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+
+        output = x[:, start: stop] + self.act(self.linear(x))
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+        return world_output


 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights, lm_head):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )
+        self.lm_head = lm_head

     def forward(self, x):
-        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        speculative_hidden_states = torch.stack([head(x) for head in self.heads], dim=1)
+        speculative_logits = self.lm_head(speculative_hidden_states)
         return speculative_logits


 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
-        self.out = FastLinear.load(
-            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
-        )

     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-        x = self.out(x)
         return x

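Note: the rewritten ResBlock shards the linear column-wise, so each rank computes only its slice of the residual output (its slice of x plus its slice of the activation), then the slices are all_gathered and concatenated back to full width. A single-process sketch of that arithmetic with made-up sizes, where a loop over ranks stands in for torch.distributed.all_gather; it checks that the reassembled shards match the unsharded block:

import torch

torch.manual_seed(0)
world_size, size = 4, 16
x = torch.randn(2, size)
weight = torch.randn(size, size)  # full (unsharded) square linear, no bias

act = torch.nn.SiLU()
reference = x + act(x @ weight)  # what a single-rank ResBlock would return

# Same slicing math as the diff: ceil-divide the hidden dim across ranks.
block_size = (size + world_size - 1) // world_size
shards = []
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    # Column-parallel linear: each rank holds a column slice of the weight,
    # so its local output is already the [start:stop] slice of the full output.
    local_out = x[:, start:stop] + act(x @ weight[:, start:stop])
    shards.append(local_out)

# all_gather + cat in the real code; a plain cat here since we looped over ranks.
gathered = torch.cat(shards, dim=-1)
assert torch.allclose(gathered, reference)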
@@ -476,7 +489,7 @@ class SpeculativeHead(nn.Module):
             filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

             with open(medusa_config, "r") as f:
-                config = json.load(f)
+                medusa_config = json.load(f)
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
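Note: this hunk rebinds medusa_config from the path of the config file to the parsed dict, which the MedusaModel call below consumes. A hypothetical minimal example of the JSON shape implied by the two keys the diff reads (real Medusa checkpoints carry more fields):

import json

# Hypothetical minimal medusa config; only the keys the diff actually reads.
raw = '{"medusa_num_heads": 4, "medusa_num_layers": 1}'
medusa_config = json.loads(raw)
assert medusa_config["medusa_num_heads"] == 4
assert medusa_config["medusa_num_layers"] == 1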
@@ -486,7 +499,7 @@ class SpeculativeHead(nn.Module):
                         )
                     weights.routing[k] = filename

-            medusa = MedusaModel(config, weights)
+            medusa = MedusaModel(config, medusa_config, weights, lm_head)
         else:
             medusa = None
         return SpeculativeHead(lm_head, medusa)
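Note: the net effect of the MedusaModel change is that each head now emits hidden states rather than logits; stacking them gives shape (batch, num_heads, hidden), and a single shared lm_head projects that to (batch, num_heads, vocab) instead of each head carrying its own output projection. A toy shape check with made-up sizes and a plain nn.Linear standing in for the real head:

import torch

batch, num_heads, hidden, vocab = 2, 4, 8, 32
lm_head = torch.nn.Linear(hidden, vocab, bias=False)  # stand-in for the shared head

# Each Medusa head emits hidden states; identity modules stand in here.
heads = [torch.nn.Identity() for _ in range(num_heads)]

x = torch.randn(batch, hidden)
speculative_hidden_states = torch.stack([head(x) for head in heads], dim=1)
assert speculative_hidden_states.shape == (batch, num_heads, hidden)

speculative_logits = lm_head(speculative_hidden_states)
assert speculative_logits.shape == (batch, num_heads, vocab)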