Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Commit ce913b874b (parent 7d96b1a103)

More cleanup.
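The hunks below converge most architectures onto two generic entry points. A rough sketch of the resulting call pattern, inferred from the diff rather than taken from upstream documentation (the model id strings are placeholders):

# Custom sharded path: get_model now passes the modeling class explicitly.
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)

bloom = CausalLM(
    model_id="bigscience/bloom-560m",  # placeholder model id
    model_class=BloomForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
)

# Plain transformers path: the old CausalLM.__init__ behaviour now sits behind
# a classmethod, used for legacy and AutoModel-only architectures.
legacy = CausalLM.fallback(
    "gpt2",  # placeholder model id
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
)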
@@ -12,16 +12,26 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

from text_generation_server.utils.import_utils import SYSTEM

@@ -41,9 +51,6 @@ __all__ = [
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]
@@ -457,8 +464,9 @@ def get_model(

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,

@@ -487,9 +495,9 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,

@@ -498,17 +506,19 @@ def get_model(

    if model_type == BLOOM:
        return BLOOMSharded(
            model_id,
            revision,
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == MPT:
        return MPTSharded(
            model_id,
            revision,
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,

@@ -530,7 +540,7 @@ def get_model(
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM(
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,

@@ -541,7 +551,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -562,16 +572,17 @@ def get_model(
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -593,7 +604,7 @@ def get_model(
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -608,9 +619,11 @@ def get_model(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
@@ -632,7 +645,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -655,7 +668,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -678,7 +691,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -702,7 +715,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -727,7 +740,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -767,7 +780,7 @@ def get_model(
                config_class=RWConfig,
            )
        else:
            return RW(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -791,7 +804,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -815,7 +828,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -841,7 +854,7 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -865,7 +878,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -875,9 +888,10 @@ def get_model(
        )

    if model_type == OPT:
        return OPTSharded(
            model_id,
            revision,
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,

@@ -885,13 +899,20 @@ def get_model(
        )

    if model_type == T5:
        return T5Sharded(
            model_id,
            revision,
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:

@@ -967,7 +988,7 @@ def get_model(
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,

@@ -976,7 +997,7 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,

@@ -988,7 +1009,7 @@ def get_model(
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,

@@ -997,7 +1018,7 @@ def get_model(
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -4,22 +4,12 @@ import torch.distributed
from typing import Optional, Type

from transformers import (
    AutoTokenizer,
    AutoConfig,
    PreTrainedTokenizerBase,
)

from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class BloomCausalLMBatch(CausalLMBatch):

@@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):


class BLOOMSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            slow_but_exact=False,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        config.pad_token_id = 3
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            prefix="transformer",
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = BloomForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch
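A hedged reading of the hunks above: the model-specific class keeps only what genuinely differs (its batch type), while construction is delegated to the shared CausalLM.__init__ with an explicit model_class. A minimal sketch under that assumption (the class name is hypothetical; imports are taken from the diff):

from typing import Type

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch


class MinimalBloomSharded(CausalLM):  # hypothetical name, for illustration only
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        # Everything else (tokenizer, config, weights, sharding) comes from
        # CausalLM.__init__; only the batch class is overridden.
        return BloomCausalLMBatch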
@@ -1,11 +1,22 @@
import torch
import time
import torch.distributed

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens

@@ -482,6 +493,67 @@ class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        config_class=AutoConfig,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,

@@ -537,7 +609,7 @@ class CausalLM(Model):
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
        super(CausalLM, cls).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
@@ -1,89 +0,0 @@
import torch
import torch.distributed

from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class GPTNeoxSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = GPTNeoxForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -1,105 +0,0 @@
import torch
import torch.distributed

from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class MPTCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
        batch.keys_head_dim_last = False
        return batch


class MPTSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        # If model_id is a local path, load the file directly
        local_path = Path(model_id, "config.json")
        if local_path.exists():
            filename = str(local_path.resolve())
        else:
            filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
        with open(filename, "r") as f:
            config = json.load(f)
        config = PretrainedConfig(**config)
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        config.quantize = quantize
        model = MPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return MPTCausalLMBatch
@@ -1,69 +0,0 @@
import torch
import torch.distributed

from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple

from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class Phi(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, _rank, _world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        config = PhiConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        tokenizer.bos_token_id = config.bos_token_id
        tokenizer.eos_token_id = config.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

        config.quantize = quantize
        config.speculator = speculator
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        model = PhiForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
@@ -1,84 +0,0 @@
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple

from text_generation_server.models import CausalLM


class RW(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if speculator:
            raise RuntimeError("Medusa decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        # Model Forward
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -1,77 +0,0 @@
import torch
import torch.distributed

from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM

from text_generation_server.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"


class SantaCoder(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    EOD,
                    FIM_PREFIX,
                    FIM_MIDDLE,
                    FIM_SUFFIX,
                    FIM_PAD,
                ],
                "pad_token": EOD,
            }
        )
        with device:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize == "bitsandbytes",
                trust_remote_code=trust_remote_code,
            )

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
@@ -1,11 +1,22 @@
import torch
import torch.distributed
import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model

@@ -531,6 +542,71 @@ class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,

@@ -574,7 +650,7 @@ class Seq2SeqLM(Model):
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
        super(Seq2SeqLM, cls).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
@@ -1,14 +1,17 @@
import torch
import torch.distributed

from typing import Optional
from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM

from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,

@@ -16,15 +19,19 @@ from text_generation_server.utils import (
)


class OPTSharded(CausalLM):
class ShardedSeq2SeqLM(Seq2SeqLM):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():

@@ -34,35 +41,37 @@ class OPTSharded(CausalLM):
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        tokenizer.pad_token_id = config.pad_token_id

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)
        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,

@@ -72,15 +81,3 @@ class OPTSharded(CausalLM):
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -1,115 +0,0 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
)

from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class T5Sharded(Seq2SeqLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )

        model = T5ForConditionalGeneration(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
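The T5Sharded wrapper deleted above maps onto the generic Seq2SeqLM path wired up in the get_model hunk earlier in this diff. A sketch of the equivalent call, with names taken from those hunks and a placeholder model id:

from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

t5 = Seq2SeqLM(
    model_id="google/flan-t5-small",  # placeholder model id
    model_class=T5ForConditionalGeneration,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    aliases={
        # T5 shares one embedding across encoder and decoder, so the weight
        # loader is told about the aliased tensor names (as in the diff above).
        "shared.weight": [
            "encoder.embed_tokens.weight",
            "decoder.embed_tokens.weight",
        ]
    },
)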