mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
[REWRITTEN] Added cleanup based on PR review comments; removed the conditionals from LayerNormParameterized and renamed it MLPSpeculatorLayerNorm; now using modules for tensor parallelism (work in progress; still evaluating whether this is the right approach); fixed an issue with getting the Medusa model; made loading more efficient.
This commit is contained in:
parent 38d6045443
commit 9291d42865
@@ -419,5 +419,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
+            # input_ids = input_ids[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states, input_ids)
         return logits, speculative_logits
@@ -480,5 +480,5 @@ class FlashMistralForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states, input_ids)
         return logits
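The two hunks above change the call contract of self.lm_head: it now receives the flattened input_ids alongside the hidden states, since an MLP speculator conditions its draft heads on the last sampled token ids rather than on the hidden state alone. A minimal sketch of a head with this signature (a hypothetical wrapper written for illustration, not the actual TGI module):

from typing import Optional, Tuple

import torch
import torch.nn as nn


class SpeculativeHead(nn.Module):
    """Sketch of a head matching the new (hidden_states, input_ids) contract."""

    def __init__(self, hidden_size: int, vocab_size: int, speculator: Optional[nn.Module] = None):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.speculator = speculator  # e.g. an MLP speculator; None disables speculation

    def forward(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        speculative_logits = (
            self.speculator(hidden_states, input_ids) if self.speculator is not None else None
        )
        return logits, speculative_logits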
@@ -1101,6 +1101,8 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0

+            logger.info(f"Accepted ids {n_accepted_ids}")
+
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
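The added logger.info traces n_accepted_ids, the number of tokens a request actually advances by in one decode step: the speculative tokens that survived verification plus the one token the base model always produces. A hedged sketch of how such a count can be computed under greedy verification (tensor names and values are illustrative, not the TGI decode loop):

import torch

# Illustrative example: greedy verification of speculator drafts.
# draft_tokens are what the speculator proposed; target_tokens are what the base
# model predicts at the same positions during the verification forward pass.
draft_tokens = torch.tensor([5, 9, 4, 7])
target_tokens = torch.tensor([5, 9, 2, 8])

# Drafts are accepted left to right until the first mismatch.
mismatch = (draft_tokens != target_tokens).nonzero()
accepted_drafts = int(mismatch[0]) if len(mismatch) else len(draft_tokens)

# The base model always contributes one token of its own on top of the accepted drafts.
n_accepted_ids = accepted_drafts + 1
print(n_accepted_ids)  # 3 in this example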
@@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM):
         config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
@@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
@@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
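The four hunks above rename the use_medusa argument to speculator and attach it to the model config, so a single option can point at either a Medusa head or an MLP speculator checkpoint. A sketch of how such a flag might thread from the constructor into the config (the function name and call chain are illustrative, not the actual TGI code path):

from typing import Optional

from transformers import AutoConfig


def load_config_with_speculator(
    model_id: str,
    revision: Optional[str] = None,
    quantize: Optional[str] = None,
    speculator: Optional[str] = None,
    trust_remote_code: bool = False,
):
    # Load the base model config, then attach runtime options the heads read later.
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    config.quantize = quantize
    config.speculator = speculator  # path to a Medusa head or MLP speculator checkpoint
    return config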
@@ -1,5 +1,6 @@
 import json
 import os
+import math
 from pathlib import Path

 import torch
@@ -494,7 +495,7 @@ class MLPSpeculatorModel(torch.nn.Module):
             ]
         )
         self.proj = [
-            TensorParallelColumnLinear.load(
+            FastLinear.load(
                 config,
                 prefix=f"{prefix}.proj.{i}",
                 weights=weights,
@@ -504,9 +505,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         ]
         self.head = nn.ModuleList(
             [
-                TensorParallelRowLinear.load(
-                    config, f"{prefix}.head.{i}", weights, bias=False
-                )
+                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
                 for i in range(self.n_predict)
             ]
         )
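In the two hunks above, the speculator projection and head layers switch from the tensor-parallel loaders (TensorParallelColumnLinear, TensorParallelRowLinear) to FastLinear, so each rank keeps a full, unsharded copy of these relatively small matrices; per the commit message, the tensor-parallel treatment of the speculator is still being worked out. A rough sketch of the difference in plain PyTorch, with made-up shapes (this illustrates the general idea, not the TGI weight-loading code):

import torch
import torch.nn as nn

world_size = 4  # hypothetical number of tensor-parallel ranks
rank = 0
full_weight = torch.randn(4096, 3072)  # hypothetical speculator projection weight

# Sharded (tensor-parallel) loading: each rank keeps 1/world_size of the output rows.
shard = full_weight.chunk(world_size, dim=0)[rank]
tp_proj = nn.Linear(shard.shape[1], shard.shape[0], bias=False)
tp_proj.weight.data.copy_(shard)

# Replicated loading (what FastLinear amounts to here): every rank keeps the full matrix.
proj = nn.Linear(full_weight.shape[1], full_weight.shape[0], bias=False)
proj.weight.data.copy_(full_weight)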
@@ -528,32 +527,36 @@ class MLPSpeculatorModel(torch.nn.Module):
         # TODO
         self.vsize = 128256
         self.inner_dim = 3072
+        self.top_k_tokens_per_head = [1] * self.n_predict
+        self.candidates = 1

     def forward(
         self,
-        state: torch.Tensor,
+        hidden_states: torch.Tensor,
         input_ids: torch.Tensor,
-        top_k_tokens_per_head: Optional[List[int]] = None,
-        num_candidates: int = 1,
     ):
-        # TODO
-        top_k_tokens_per_head = [1, 1, 1, 1]
-        if top_k_tokens_per_head is None:
-            top_k_tokens_per_head = self.config.top_k_tokens_per_head
+        top_k_tokens_per_head = self.top_k_tokens_per_head
+        num_candidates = self.candidates

-        ind = input_ids
+        # if state.shape[0] > 1:
+        #     state = state[:1]

         # k indicates # of candidates
         # h indicates # of generated tokens
+        state = hidden_states
         b = state.size(0)
-        out = torch.empty(b, 1, 0, device=state.device).int()  # b k h
-        log_probs = torch.zeros(b, 1, device=state.device)  # b k
-        all_probs = torch.empty(b, 1, 0, self.vsize, device=state.device)  # b k h v
+        ind = input_ids[-b:].unsqueeze(0)
+        out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
+        log_probs = torch.zeros(1, b, device=state.device)  # b k
+        all_probs = torch.empty(
+            1, b, self.n_predict, self.vsize, device=state.device
+        )  # b k h v
         assert (
             len(top_k_tokens_per_head) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
         for i in range(self.n_predict):
             # Project and predict
+            # print(ind)
             z = self.emb[i](ind)
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
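This hunk fixes the candidate set to a single greedy path (top_k_tokens_per_head = [1] * n_predict, candidates = 1), which lets out, log_probs, and all_probs be preallocated at their final sizes instead of being grown with torch.cat inside the loop (see the next hunk). A small sketch of the preallocate-and-assign pattern, with illustrative shapes:

import torch

b, n_predict, vocab_size = 2, 4, 128256  # illustrative sizes

# Grow-by-concatenation (old style): allocates a new tensor on every step.
probs_cat = torch.empty(b, 0, vocab_size)
for _ in range(n_predict):
    step = torch.rand(b, 1, vocab_size)
    probs_cat = torch.cat([probs_cat, step], dim=1)

# Preallocate-and-assign (new style): one allocation, in-place writes per head.
probs = torch.empty(b, n_predict, vocab_size)
for i in range(n_predict):
    probs[:, i] = torch.rand(b, vocab_size)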
@@ -562,43 +565,32 @@ class MLPSpeculatorModel(torch.nn.Module):
             probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

             # Update candidate set with new predictions
-            out = out.unsqueeze(2).expand(
-                -1, -1, top_k_tokens_per_head[i], -1
-            )  # b k k' h
-            try:
-                out = torch.cat(
-                    [out, preds.unsqueeze(2).unsqueeze(3)], dim=-1
-                )  # b k k' h+1
-            except Exception:
-                import ipdb
-
-                ipdb.set_trace()
-            out = out.view(b, -1, i + 1)  # b kk' h+1
+            out[:, :, i : i + 1] = preds

             # Update distribution set with new logits
-            all_probs = torch.cat(
-                [all_probs, _probs.exp().unsqueeze(2)], dim=-1
-            )  # b k h+1 v
-            all_probs = all_probs.repeat(
-                1, top_k_tokens_per_head[i], 1, 1
-            )  # b kk' h+1 v
+            all_probs[:, :, i] = _probs.exp()

             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
                 -1, -1, top_k_tokens_per_head[i], -1
             )  # b k k' d
-            state = state.reshape(b, -1, state.size(3))  # b kk' d
-            ind = preds.view(b, -1)  # b kk'
+            state = state.reshape(-1, b, state.size(3))  # b kk' d
+            ind = preds.view(-1, b)  # b kk'
             log_probs = log_probs.unsqueeze(2).expand(
-                b, -1, top_k_tokens_per_head[i]
+                -1, b, top_k_tokens_per_head[i]
             )  # b k k'
-            log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'
+            log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'

+        # print("done")
         # Take only top n best guesses
         best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
-        return all_probs.gather(
-            1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
-        )  # b n h v
+        # speculative_logits = all_probs.gather(
+        #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
+        # ).squeeze(0)
+        speculative_logits = all_probs[0]
+        # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
+        # TODO Why is this shift existing, are speculative logits also including the natural next token ?
+        return speculative_logits[:, 1:]


 class MLPSpeculatorHead(nn.Module):
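With one candidate per head, the loop above reduces to running each draft head greedily on the previous head's prediction. A self-contained sketch of that simplified recurrence (module layout, scaling constants, and shapes are illustrative; the real MLPSpeculatorModel also applies a layer norm and activation between the projection and the head):

import math

import torch
import torch.nn as nn


class TinyMLPSpeculator(nn.Module):
    """Greedy (top-1, single-candidate) sketch of an MLP speculator forward pass."""

    def __init__(self, vocab_size: int, hidden_dim: int, inner_dim: int, n_predict: int):
        super().__init__()
        self.n_predict = n_predict
        self.inner_dim = inner_dim
        self.emb = nn.ModuleList(nn.Embedding(vocab_size, inner_dim) for _ in range(n_predict))
        self.proj = nn.ModuleList(
            nn.Linear(hidden_dim if i == 0 else inner_dim, inner_dim, bias=False)
            for i in range(n_predict)
        )
        self.head = nn.ModuleList(nn.Linear(inner_dim, vocab_size, bias=False) for _ in range(n_predict))
        # Weighting between the carried state and the new token embedding (illustrative values).
        self.state_weight = 0.5 ** (0.5 / n_predict)
        self.emb_weight = math.sqrt(1.0 - self.state_weight**2)

    def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # hidden_states: (b, hidden_dim); input_ids: (b,) last accepted token per sequence.
        state, ind = hidden_states, input_ids
        all_logits = []
        for i in range(self.n_predict):
            z = self.emb[i](ind) * (self.emb_weight * math.sqrt(self.inner_dim / 2))
            state = self.proj[i](state) * self.state_weight + z
            logits = self.head[i](state)       # (b, vocab_size)
            ind = logits.argmax(dim=-1)        # greedy: feed the top-1 token to the next head
            all_logits.append(logits)
        return torch.stack(all_logits, dim=1)  # (b, n_predict, vocab_size)

For example, TinyMLPSpeculator(vocab_size=128256, hidden_dim=4096, inner_dim=3072, n_predict=4) returns one (n_predict, vocab_size) block of draft logits per row of hidden_states.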
@@ -692,10 +684,10 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -713,7 +705,7 @@ class MedusaHeadV1(nn.Module):
         return MedusaHeadV1(lm_head, medusa)

     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, _input_ids: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
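MedusaHeadV1.forward (and MedusaHeadV2.forward in a later hunk) now takes an _input_ids argument that it simply ignores, so Medusa heads and the MLP speculator can be invoked through the same (hidden_states, input_ids) interface. A minimal sketch of the pattern:

import torch
import torch.nn as nn


class IgnoreInputIdsHead(nn.Module):
    """Sketch: keep the shared call signature, ignore the argument this head does not need."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, _input_ids: torch.Tensor) -> torch.Tensor:
        # _input_ids is only needed by token-conditioned speculators (e.g. the MLP speculator);
        # a Medusa-style head predicts from hidden states alone.
        return self.lm_head(hidden_states)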
@@ -731,10 +723,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -765,7 +757,7 @@ class MedusaHeadV2(nn.Module):

         self.lm_head = TensorParallelHead.load(config, prefix, weights)

-    def forward(self, x):
+    def forward(self, x, _input_ids):
         # If we have too many tokens, we skip speculative logits
         if x.shape[0] > 128:
             logits = self.lm_head(x)