mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 12:02:08 +00:00
Dummy fix for medusa.
This commit is contained in:
parent
b24bdb9f8c
commit
10dd0150c0
@@ -874,21 +874,15 @@ class FlashCausalLM(Model):
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
graphs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
if graphs:
|
||||
cuda_graph = graphs[0]
|
||||
else:
|
||||
cuda_graph = None
|
||||
|
||||
if (
|
||||
cu_seqlen_prefill is not None
|
||||
or cuda_graph is None
|
||||
or batch.speculative_ids is not None
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
@@ -980,6 +974,9 @@ class FlashCausalLM(Model):
|
||||
speculative_logits,
|
||||
)
|
||||
|
||||
if os.getenv("RANK") == "0":
|
||||
logger.info(f"Accepted ids {accepted_ids}")
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
|
||||
)
|
||||
|
@@ -412,48 +412,61 @@ class SuperLayer(nn.Module):
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.linear = FastLinear.load(
|
||||
self.linear = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
||||
)
|
||||
self.process_group = weights.process_group
|
||||
self.world_size = self.process_group.size()
|
||||
self.rank = self.process_group.rank()
|
||||
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.act(self.linear(x))
|
||||
size = x.shape[-1]
|
||||
block_size = (size + self.world_size - 1) // self.world_size
|
||||
start = self.rank * block_size
|
||||
stop = (self.rank + 1) * block_size
|
||||
|
||||
output = x[:, start: stop] + self.act(self.linear(x))
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
return world_output
|
||||
|
||||
|
||||
class MedusaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, config, medusa_config, weights, lm_head):
|
||||
super().__init__()
|
||||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
||||
for i in range(config["medusa_num_heads"])
|
||||
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||
for i in range(medusa_config["medusa_num_heads"])
|
||||
]
|
||||
)
|
||||
self.lm_head = lm_head
|
||||
|
||||
def forward(self, x):
|
||||
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
||||
speculative_hidden_states = torch.stack([head(x) for head in self.heads], dim=1)
|
||||
speculative_logits = self.lm_head(speculative_hidden_states)
|
||||
return speculative_logits
|
||||
|
||||
|
||||
class MedusaHead(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
def __init__(self, config, medusa_config, prefix, weights):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||
for i in range(config["medusa_num_layers"])
|
||||
for i in range(medusa_config["medusa_num_layers"])
|
||||
]
|
||||
)
|
||||
n = len(self.blocks)
|
||||
self.out = FastLinear.load(
|
||||
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -476,7 +489,7 @@ class SpeculativeHead(nn.Module):
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
@@ -486,7 +499,7 @@ class SpeculativeHead(nn.Module):
|
||||
)
|
||||
weights.routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, weights)
|
||||
medusa = MedusaModel(config, medusa_config, weights, lm_head)
|
||||
else:
|
||||
medusa = None
|
||||
return SpeculativeHead(lm_head, medusa)
|
||||
|
Loading…
Reference in New Issue
Block a user