diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2a9d3914..538dbb62 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -874,21 +874,15 @@ class FlashCausalLM(Model):
             lm_head_indices = batch.prefill_head_indices
 
         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        graphs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if graphs:
+            cuda_graph = self.cuda_graphs[graphs[0]]
+        else:
+            cuda_graph = None
 
         if (
             cu_seqlen_prefill is not None
             or cuda_graph is None
-            or batch.speculative_ids is not None
         ):
             return self.model.forward(
                 input_ids=input_ids,
@@ -980,6 +974,9 @@ class FlashCausalLM(Model):
             speculative_logits,
         )
 
+        if os.getenv("RANK") == "0":
+            logger.info(f"Accepted ids {accepted_ids}")
+
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 312a4482..c4aeb602 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -412,48 +412,61 @@ class SuperLayer(nn.Module):
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
-        self.linear = FastLinear.load(
+        self.linear = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.linear", weights=weights, bias=True
         )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
+
         self.act = torch.nn.SiLU()
 
     def forward(self, x):
-        return x + self.act(self.linear(x))
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+
+        output = x[:, start:stop] + self.act(self.linear(x))
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+        return world_output
 
 
 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights, lm_head):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )
+        self.lm_head = lm_head
 
     def forward(self, x):
-        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        speculative_hidden_states = torch.stack([head(x) for head in self.heads], dim=1)
+        speculative_logits = self.lm_head(speculative_hidden_states)
         return speculative_logits
 
 
 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
-        self.out = FastLinear.load(
-            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
-        )
 
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-        x = self.out(x)
         return x
 
 
@@ -476,7 +489,7 @@ class SpeculativeHead(nn.Module):
             filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
 
             with open(medusa_config, "r") as f:
-                config = json.load(f)
+                medusa_config = json.load(f)
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
@@ -486,7 +499,7 @@ class SpeculativeHead(nn.Module):
                     )
                     weights.routing[k] = filename
 
-            medusa = MedusaModel(config, weights)
+            medusa = MedusaModel(config, medusa_config, weights, lm_head)
         else:
             medusa = None
         return SpeculativeHead(lm_head, medusa)