Flash Transformers modeling backend support (#2913)

* add transformers_flash

* inits

* switch version to make it work

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* runnable version

* working

* push change

* fix high dim

* init

* default

* latest transformers changes

* revert

* simplify check

* remove flag

* improve type hints + required args

* Update based on transformers PR

* small fix

* Remove Warpers for Processor

* fix compatibility version issue

* raise error if needed

* Simplify with monkey patch

* revert + style + minor improvements

* update comment

* device check

* move the import to avoid device issue

* Update __init__.py

* check for non-native models

* oupsi

---------

Co-authored-by: System administrator <root@ip-10-90-0-159.ec2.internal>
Authored by Cyril Vallez on 2025-01-21 10:01:51 +01:00; committed by GitHub.
Commit b980848abf (parent 447a5b2f87).
4 changed files with 330 additions and 31 deletions

Changed file 1/4: quantization utility (checkpoint saving in `quantize()`).

```diff
@@ -956,15 +956,24 @@ def quantize(
     pack(model, quantizers, bits, groupsize)

     from safetensors.torch import save_file
-    from transformers.modeling_utils import shard_checkpoint
+    from huggingface_hub import split_torch_state_dict_into_shards

     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

     max_shard_size = "10GB"
-    shards, index = shard_checkpoint(
-        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict,
+        filename_pattern="model.safetensors",
+        max_shard_size=max_shard_size,
     )
+    index = None
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+    shards = state_dict_split.filename_to_tensors

     os.makedirs(output_dir, exist_ok=True)
     for shard_file, shard in shards.items():
         save_file(
```
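For context, `split_torch_state_dict_into_shards` is the `huggingface_hub` replacement for the deprecated `transformers.modeling_utils.shard_checkpoint`. Below is a minimal, self-contained sketch of the same save path; the `model{suffix}.safetensors` pattern and the `model.safetensors.index.json` index name follow common Hub conventions and are assumptions here, not taken from this diff:

```python
import json
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded(model: torch.nn.Module, output_dir: str, max_shard_size: str = "10GB"):
    # Move everything to CPU and make tensors contiguous, as in the diff above.
    state_dict = {k: v.cpu().contiguous() for k, v in model.state_dict().items()}
    split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern="model{suffix}.safetensors",  # assumed Hub-style pattern
        max_shard_size=max_shard_size,
    )
    os.makedirs(output_dir, exist_ok=True)
    # Each shard file only receives the tensors assigned to it.
    for filename, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(output_dir, filename), metadata={"format": "pt"})
    # Only sharded checkpoints need an index mapping tensor name -> shard file.
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)
```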

Changed file 2/4: `text_generation_server/models/__init__.py` (model dispatch in `get_model`).

```diff
@@ -16,10 +16,12 @@ from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path

+import transformers
+
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -178,6 +180,14 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)

+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+

 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -381,6 +391,21 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)

+    transformers_causal_lm_class = CausalLM
+
+    # Fast transformers path
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        transformers_causal_lm_class = TransformersFlashCausalLM
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
```
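In other words, the fast Transformers path is only chosen when the architecture ships with Transformers itself and its modeling code advertises support for pluggable attention; remote-code ("non-native") models resolve to `None` and keep the classic `CausalLM` fallback, which matches the "check for non-native models" commit above. A standalone sketch of that check, assuming transformers >= 4.48; `"llama"` and `"bloom"` are illustrative `model_type` values, not taken from this diff:

```python
import transformers
from transformers.models.auto import modeling_auto


def stock_class_supports_custom_attention(model_type: str) -> bool:
    """Mirror of the condition above: resolve the stock *ForCausalLM class for a
    model_type and check the `_supports_flex_attn` flag it exposes."""
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
    model_class = getattr(transformers, class_name, None)  # e.g. LlamaForCausalLM
    return model_class is not None and bool(getattr(model_class, "_supports_flex_attn", False))


# Expected on recent transformers: True for architectures ported to the new
# attention interface, False for older ones that were never ported.
print(stock_class_supports_custom_attention("llama"))
print(stock_class_supports_custom_attention("bloom"))
```

The remaining hunks in this file then route every existing `CausalLM.fallback(...)` call through the selected class: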
```diff
@@ -624,7 +649,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -683,7 +708,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -731,7 +756,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -742,7 +767,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -767,7 +792,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -778,7 +803,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -815,7 +840,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -838,7 +863,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -862,7 +887,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -911,7 +936,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -937,7 +962,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -963,7 +988,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -988,7 +1013,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1016,7 +1041,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1066,7 +1091,7 @@ def get_model(
                config_class=RWConfig,
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1091,7 +1116,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1116,7 +1141,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1143,7 +1168,7 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1168,7 +1193,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1329,7 +1354,7 @@ def get_model(
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
            model_id,
            revision,
            quantize=quantize,
@@ -1350,7 +1375,7 @@ def get_model(
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
```

Changed file 3/4 (new): `text_generation_server/models/transformers_flash_causal_lm.py` (+266 lines).

@@ -0,0 +1,266 @@
```python
import math
from typing import List, Optional

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import transformers.modeling_utils

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.utils import initialize_torch_distributed

from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION

tracer = trace.get_tracer(__name__)


def tgi_flash_attention_forward(
    module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
    kv_cache: List[KVCache],
    kv_head_mapping: torch.Tensor,
    slots: torch.Tensor,
    cu_seqlen_prefill: Optional[torch.Tensor],
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    max_s: int,
    kv_scales: KVScales,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
):
    kv_cache = kv_cache[module.layer_idx]

    query_states = query_states.transpose(1, 2).squeeze(dim=0)
    key_states = key_states.transpose(1, 2).squeeze(dim=0)
    value_states = value_states.transpose(1, 2).squeeze(dim=0)

    # Take care of updating the cache in-place
    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)

    _, num_heads, head_dim = query_states.shape
    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    sliding_window = -1 if sliding_window is None else sliding_window

    if cu_seqlen_prefill is not None:
        attn_output = attention(
            query=query_states,
            key=key_states,
            value=value_states,
            kv_cache=kv_cache,
            kv_scales=kv_scales,
            seqlen=seqlen,
            block_tables=block_tables,
            softmax_scale=softmax_scale,
            window_size_left=sliding_window,
            softcap=softcap,
        )
    else:
        attn_output = paged_attention(
            query_states,
            kv_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            seqlen,
            max_s,
            kv_scales=kv_scales,
            softcap=softcap,
        )

    attn_output = attn_output.view(-1, num_heads * head_dim)

    return attn_output, None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
```
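The single registration line above is what later lets `from_pretrained(..., attn_implementation="tgi")` route every attention call of a supported model through TGI's prefill/paged kernels: on prefill (`cu_seqlen_prefill` set) it calls the flash `attention` kernel, otherwise the `paged_attention` decode kernel. A standalone sketch of the same registration mechanism with a deliberately simple SDPA-based kernel, assuming transformers >= 4.48; the `"sdpa_with_logging"` name and the placeholder checkpoint id are illustrative, not part of this PR:

```python
import torch
import transformers.modeling_utils
from transformers import AutoModelForCausalLM


def sdpa_with_logging(module, query, key, value, attention_mask, **kwargs):
    # Transformers hands q/k/v over as (batch, num_heads, seq_len, head_dim).
    print(f"layer {module.layer_idx}: q={tuple(query.shape)}")
    if key.shape[1] != query.shape[1]:  # grouped-query attention: repeat KV heads
        groups = query.shape[1] // key.shape[1]
        key = key.repeat_interleave(groups, dim=1)
        value = value.repeat_interleave(groups, dim=1)
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=kwargs.get("scaling")
    )
    # The modeling code expects (batch, seq_len, num_heads, head_dim) back, plus
    # attention weights (or None) as the second element.
    return out.transpose(1, 2).contiguous(), None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa_with_logging"] = sdpa_with_logging

# Any model whose class reports `_supports_flex_attn` can now pick the custom
# kernel by name at load time; the checkpoint id below is a placeholder.
model = AutoModelForCausalLM.from_pretrained(
    "<some-llama-style-checkpoint>",
    attn_implementation="sdpa_with_logging",
    torch_dtype=torch.bfloat16,
)
```

TGI's registered kernel does the same, except it ignores `attention_mask` and instead consumes the paged KV cache, slot mapping and sequence-length metadata that the backend threads through the model's `**kwargs`. The new file continues with the backend class itself: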
```python
class TransformersFlashCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        config_class=AutoConfig,
        kv_cache_dtype: Optional[torch.dtype] = None,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()

        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device("cuda:0")
            dtype = torch.float16 if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise ValueError(
                "Flash `Transformers` modeling backend is not available on cpu."
            )

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto",
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
            attn_implementation="tgi",
            tp_plan="auto" if world_size > 1 else None,
        )

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None and isinstance(
                model.config.eos_token_id, int
            ):
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.num_layers = model.config.num_hidden_layers
        self.num_heads = model.config.num_attention_heads // self.process_group.size()
        self.num_kv_heads = model.config.num_key_value_heads
        self.num_kv_heads = (
            self.num_kv_heads // self.process_group.size()
            if self.num_kv_heads > 1
            else self.num_kv_heads
        )
        self.head_size = model.config.hidden_size // model.config.num_attention_heads

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )
            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        self.num_groups = self.num_heads // self.num_kv_heads

        # Those will never change and will be used in the forwards
        self.kv_head_mapping = torch.arange(
            0, self.num_kv_heads, dtype=torch.int32, device=device
        ).repeat_interleave(self.num_groups)
        # This means no scale
        self.kv_scales = KVScales(
            torch.tensor(1.0, device=device),
            torch.tensor(1.0, device=device),
        )

        torch.distributed.barrier(group=self.process_group)
        # Skip FlashCausalLM init.
        super(FlashCausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
        # We first copy the original model.forward because we still need it in the monkey patch
        self.model.original_forward = self.model.forward
        self.model.forward = self._model_forward
```
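The `kv_head_mapping` built above is how grouped-query attention is handled on the decode path: each query head is assigned the index of the KV head it reads from, and `paged_attention` receives that mapping. A tiny worked example with made-up head counts (8 query heads, 2 KV heads; not taken from this PR):

```python
import torch

num_heads, num_kv_heads = 8, 2           # illustrative sizes
num_groups = num_heads // num_kv_heads   # 4 query heads share each KV head
kv_head_mapping = torch.arange(0, num_kv_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```

The rest of the new file, the `fallback` constructor used by `get_model` and the patched forward, follows.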
```python
    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        return cls(
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def _model_forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[KVCache],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor],
        prefill_cache_indices=None,  # not used, but passed to match original signature
        adapter_data=None,  # not supported, but passed to match original signature
    ):
        hidden_states = self.model.model.forward(
            input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
            position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
            use_cache=False,  # we use self.kv_cache instead of transformers cache object
            return_dict=True,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            kv_head_mapping=self.kv_head_mapping,
            kv_scales=self.kv_scales,
        )[0].squeeze(dim=0)

        # And compute logits from the lm_head, slicing correctly the indices
        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
        # To update with full Transformers support asap
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.model.lm_head.forward(hidden_states)

        return logits, None
```
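The shape handling in `_model_forward` is the core of the adapter: TGI's `FlashCausalLM` feeds all sequences of a batch as one flat, ragged token stream, while the stock Transformers trunk expects a leading batch dimension. A toy illustration with made-up sizes (two requests of 3 and 4 tokens; none of these numbers come from this PR):

```python
import torch

hidden_size, total_tokens = 16, 7                    # two requests: 3 + 4 tokens
input_ids = torch.randint(0, 100, (total_tokens,))   # flat TGI layout, shape (7,)
batched = input_ids.unsqueeze(0)                      # (1, 7) -- what Transformers sees
hidden_states = torch.randn(1, total_tokens, hidden_size).squeeze(dim=0)  # back to (7, 16)

# During prefill, only the last position of each request needs logits:
lm_head_indices = torch.tensor([2, 6])                # last token of each request
selected = hidden_states[lm_head_indices]             # (2, 16) fed to the lm_head
```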

Changed file 4/4: logits processing utilities (heterogeneous logits warpers).

```diff
@@ -5,13 +5,12 @@ import torch
 from typing import List, Optional, DefaultDict
 from loguru import logger
-from typing import Dict, Union
+from typing import Dict

 from text_generation_server.pb.generate_pb2 import GrammarType

 from outlines.fsm.guide import RegexGuide
 from transformers import (
-    LogitsWarper,
     LogitsProcessor,
     PreTrainedTokenizerBase,
     TemperatureLogitsWarper,
@@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
         return None


-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         return None


-class HeterogeneousTopKLogitsWarper(LogitsWarper):
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         return None


-class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
     Generation](https://arxiv.org/abs/2202.00666) for more information.
@@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
     r"""
     A wrapper for logit warpers or processors without heterogeneous parameter support.

     Args:
-        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+        processors (`Dict[int, LogitsProcessor]`):
             A mapping of sample indices to logit warpers or processors, to be run sequentially.
     """

     def __init__(
         self,
-        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+        processors: Dict[int, LogitsProcessor],
     ):
         self.processors = processors
```
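These changes correspond to the "Remove Warpers for Processor" commit above: the `LogitsWarper` base class has been deprecated and removed in recent Transformers releases, and `LogitsProcessor` exposes the same `__call__(input_ids, scores)` contract, so only the base class and type hints change. A minimal sketch of that contract with an illustrative processor (not part of this PR):

```python
import torch
from transformers import LogitsProcessor


class TemperatureProcessor(LogitsProcessor):
    """Illustrative example: rescale logits by a fixed temperature."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores / self.temperature


scores = torch.tensor([[1.0, 2.0, 3.0]])
out = TemperatureProcessor(2.0)(torch.empty(1, 0, dtype=torch.long), scores)
print(out)  # tensor([[0.5000, 1.0000, 1.5000]])
```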