mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Flash Transformers modeling backend support (#2913)
* add transformers_flash * inits * switch version to make it work * Update Makefile-flash-att-v2 * Update Makefile-flash-att-v2 * Update Makefile-flash-att-v2 * Update Makefile-flash-att-v2 * Update Makefile-flash-att-v2 * Update Makefile-flash-att-v2 * runnable version * working * push change * fix high dim * init * default * latest transformers changes * revert * simplify check * remove flag * improve type hints + required args * Update based on transformers PR * small fix * Remove Warpers for Processor * fix compatibility version issue * raise error if needed * Simplify with monkey patch * revert + style + minor improvements * update comment * device check * move the import to avoid device issue * Update __init__.py * check for non-native models * oupsi --------- Co-authored-by: System administrator <root@ip-10-90-0-159.ec2.internal>
This commit is contained in:
parent
447a5b2f87
commit
b980848abf
@ -956,15 +956,24 @@ def quantize(
|
||||
|
||||
pack(model, quantizers, bits, groupsize)
|
||||
from safetensors.torch import save_file
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||
|
||||
max_shard_size = "10GB"
|
||||
shards, index = shard_checkpoint(
|
||||
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict,
|
||||
filename_pattern="model.safetensors",
|
||||
max_shard_size=max_shard_size,
|
||||
)
|
||||
index = None
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
shards = state_dict_split.filename_to_tensors
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
for shard_file, shard in shards.items():
|
||||
save_file(
|
||||
|
@ -16,10 +16,12 @@ from transformers.models.auto import modeling_auto
|
||||
from huggingface_hub import hf_hub_download, HfApi
|
||||
from typing import Optional, List, Dict
|
||||
from pathlib import Path
|
||||
import transformers
|
||||
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
|
||||
|
||||
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||
MPTForCausalLM,
|
||||
@ -178,6 +180,14 @@ except ImportError as e:
|
||||
if MAMBA_AVAILABLE:
|
||||
__all__.append(Mamba)
|
||||
|
||||
FLASH_TRANSFORMERS_BACKEND = True
|
||||
try:
|
||||
from text_generation_server.models.transformers_flash_causal_lm import (
|
||||
TransformersFlashCausalLM,
|
||||
)
|
||||
except ImportError:
|
||||
FLASH_TRANSFORMERS_BACKEND = False
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
DEEPSEEK_V2 = {
|
||||
@ -381,6 +391,21 @@ def get_model(
|
||||
)
|
||||
model_type = config_dict.get("model_type", None)
|
||||
|
||||
transformers_causal_lm_class = CausalLM
|
||||
|
||||
# Fast transformers path
|
||||
transformers_model_class = getattr(
|
||||
transformers,
|
||||
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
|
||||
None,
|
||||
)
|
||||
if (
|
||||
FLASH_TRANSFORMERS_BACKEND
|
||||
and transformers_model_class is not None
|
||||
and transformers_model_class._supports_flex_attn
|
||||
):
|
||||
transformers_causal_lm_class = TransformersFlashCausalLM
|
||||
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
if quantization_config is None:
|
||||
quantization_config = config_dict.get("compression_config", None)
|
||||
@ -624,7 +649,7 @@ def get_model(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -683,7 +708,7 @@ def get_model(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
@ -731,7 +756,7 @@ def get_model(
|
||||
except RuntimeError as e:
|
||||
# Lots of legacy models with various weight names.
|
||||
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -742,7 +767,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -767,7 +792,7 @@ def get_model(
|
||||
except RuntimeError as e:
|
||||
# Lots of legacy models with various weight names.
|
||||
log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -778,7 +803,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -815,7 +840,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -838,7 +863,7 @@ def get_model(
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -862,7 +887,7 @@ def get_model(
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -911,7 +936,7 @@ def get_model(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -937,7 +962,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -963,7 +988,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -988,7 +1013,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1016,7 +1041,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1066,7 +1091,7 @@ def get_model(
|
||||
config_class=RWConfig,
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1091,7 +1116,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1116,7 +1141,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1143,7 +1168,7 @@ def get_model(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1168,7 +1193,7 @@ def get_model(
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1329,7 +1354,7 @@ def get_model(
|
||||
elif quantize == "exl2":
|
||||
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1350,7 +1375,7 @@ def get_model(
|
||||
auto_map = config_dict.get("auto_map", None)
|
||||
if trust_remote_code and auto_map is not None:
|
||||
if "AutoModelForCausalLM" in auto_map.keys():
|
||||
return CausalLM.fallback(
|
||||
return transformers_causal_lm_class.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -0,0 +1,266 @@
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
import transformers.modeling_utils
|
||||
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.utils import initialize_torch_distributed
|
||||
|
||||
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
|
||||
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
def tgi_flash_attention_forward(
|
||||
module,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers
|
||||
kv_cache: List[KVCache],
|
||||
kv_head_mapping: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
max_s: int,
|
||||
kv_scales: KVScales,
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
softcap: Optional[float] = None,
|
||||
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
|
||||
):
|
||||
|
||||
kv_cache = kv_cache[module.layer_idx]
|
||||
|
||||
query_states = query_states.transpose(1, 2).squeeze(dim=0)
|
||||
key_states = key_states.transpose(1, 2).squeeze(dim=0)
|
||||
value_states = value_states.transpose(1, 2).squeeze(dim=0)
|
||||
|
||||
# Take care of updating the cache in-place
|
||||
kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
|
||||
|
||||
_, num_heads, head_dim = query_states.shape
|
||||
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
|
||||
sliding_window = -1 if sliding_window is None else sliding_window
|
||||
|
||||
if cu_seqlen_prefill is not None:
|
||||
attn_output = attention(
|
||||
query=query_states,
|
||||
key=key_states,
|
||||
value=value_states,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=softmax_scale,
|
||||
window_size_left=sliding_window,
|
||||
softcap=softcap,
|
||||
)
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query_states,
|
||||
kv_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=kv_scales,
|
||||
softcap=softcap,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, num_heads * head_dim)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
|
||||
|
||||
|
||||
class TransformersFlashCausalLM(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
default_dtype=torch.float16,
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
config_class=AutoConfig,
|
||||
kv_cache_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.quantize = quantize
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
|
||||
if speculator:
|
||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device("xpu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"Flash `Transformers` modeling backend is not available on cpu."
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto",
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
attn_implementation="tgi",
|
||||
tp_plan="auto" if world_size > 1 else None,
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None and isinstance(
|
||||
model.config.eos_token_id, int
|
||||
):
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
self.num_layers = model.config.num_hidden_layers
|
||||
self.num_heads = model.config.num_attention_heads // self.process_group.size()
|
||||
self.num_kv_heads = model.config.num_key_value_heads
|
||||
self.num_kv_heads = (
|
||||
self.num_kv_heads // self.process_group.size()
|
||||
if self.num_kv_heads > 1
|
||||
else self.num_kv_heads
|
||||
)
|
||||
self.head_size = model.config.hidden_size // model.config.num_attention_heads
|
||||
|
||||
self.cuda_graphs = {}
|
||||
self.kv_cache = []
|
||||
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
create_prefill_state,
|
||||
create_decode_state,
|
||||
create_prefill_with_paged_kv_state,
|
||||
)
|
||||
|
||||
self.prefill_state = create_prefill_state(device=device)
|
||||
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
|
||||
device=device
|
||||
)
|
||||
|
||||
self.decode_state = create_decode_state(
|
||||
device=device,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
# Those will never change and will be used in the forwards
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_kv_heads, dtype=torch.int32, device=device
|
||||
).repeat_interleave(self.num_groups)
|
||||
# This means no scale
|
||||
self.kv_scales = KVScales(
|
||||
torch.tensor(1.0, device=device),
|
||||
torch.tensor(1.0, device=device),
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
# Skip FlashCausalLM init.
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
|
||||
# We first copy the original model.forward because we still need it in the monkey patch
|
||||
self.model.original_forward = self.model.forward
|
||||
self.model.forward = self._model_forward
|
||||
|
||||
@classmethod
|
||||
def fallback(
|
||||
cls,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
return cls(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def _model_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[KVCache],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
prefill_cache_indices=None, # not used, but passed to match original signature
|
||||
adapter_data=None, # not supported, but passed to match original signature
|
||||
):
|
||||
hidden_states = self.model.model.forward(
|
||||
input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers
|
||||
position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers
|
||||
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||
return_dict=True,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
kv_head_mapping=self.kv_head_mapping,
|
||||
kv_scales=self.kv_scales,
|
||||
)[0].squeeze(dim=0)
|
||||
|
||||
# And compute logits from the lm_head, slicing correctly the indices
|
||||
# NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
|
||||
# To update with full Transformers support asap
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.model.lm_head.forward(hidden_states)
|
||||
|
||||
return logits, None
|
@ -5,13 +5,12 @@ import torch
|
||||
from typing import List, Optional, DefaultDict
|
||||
|
||||
from loguru import logger
|
||||
from typing import Dict, Union
|
||||
from typing import Dict
|
||||
from text_generation_server.pb.generate_pb2 import GrammarType
|
||||
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
|
||||
from transformers import (
|
||||
LogitsWarper,
|
||||
LogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
TemperatureLogitsWarper,
|
||||
@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
|
||||
"""
|
||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||
@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
|
||||
r"""
|
||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||
Args:
|
||||
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
||||
processors (`Dict[int, LogitsProcessor]`):
|
||||
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
||||
processors: Dict[int, LogitsProcessor],
|
||||
):
|
||||
self.processors = processors
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user