feat: add safetensors conversion

This commit is contained in:
OlivierDehaene 2023-02-13 16:13:04 +01:00
parent 9af454142a
commit 397a28080c
13 changed files with 466 additions and 329 deletions

View File

@ -60,8 +60,24 @@ def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
convert: bool = False,
):
utils.download_weights(model_id, revision, extension)
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(model_id, revision, filenames)
except utils.EntryNotFoundError as e:
if not convert or not extension == ".safetensors":
raise e
# Try to see if there are pytorch weights
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files)
if __name__ == "__main__":

View File

@ -41,6 +41,15 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
if "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom":
@ -48,27 +57,22 @@ def get_model(
return BLOOMSharded(model_id, revision, quantize=quantize)
else:
return BLOOM(model_id, revision, quantize=quantize)
elif config.model_type == "gpt_neox":
if config.model_type == "gpt_neox":
if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize)
else:
return GPTNeox(model_id, revision, quantize=quantize)
elif config.model_type == "t5":
if config.model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
else:
return Seq2SeqLM(model_id, revision, quantize=quantize)
elif model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
elif "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
else:
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_id, revision, quantize=quantize)
except Exception:
return Seq2SeqLM(model_id, revision, quantize=quantize)
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_id, revision, quantize=quantize)
except Exception:
return Seq2SeqLM(model_id, revision, quantize=quantize)

View File

@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM):
)
config.pad_token_id = 3
# Only download weights for small models
if self.master and model_id == "bigscience/bloom-560m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -26,7 +26,6 @@ from text_generation.utils import (
StoppingCriteria,
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -172,14 +171,8 @@ class GalacticaSharded(Galactica):
)
tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models
if self.master and model_id == "facebook/galactica-125m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -20,7 +20,6 @@ from text_generation.models import CausalLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox):
model_id, revision=revision, tp_parallel=True
)
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed
from typing import Optional, List, Tuple
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM

View File

@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
)
tokenizer.bos_token_id = config.decoder_start_token_id
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)

View File

@ -1,283 +0,0 @@
import concurrent
import os
import re
import torch
import torch.distributed
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
TopKLogitsWarper,
)
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return str(cached_file) if cached_file.is_file() else None
def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
download_function = partial(
hf_hub_download,
repo_id=model_id,
local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files

View File

@ -0,0 +1,30 @@
from text_generation.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
)
from text_generation.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
StoppingCriteria,
)
__all__ = [
"convert_file",
"convert_files",
"initialize_torch_distributed",
"weight_files",
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"LocalEntryNotFoundError",
"Greedy",
"NextTokenChooser",
"Sampling",
"StoppingCriteria",
]

View File

@ -0,0 +1,85 @@
import concurrent
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from pathlib import Path
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from typing import Dict, List
def check_file_size(source_file: Path, target_file: Path):
"""
Check that two files are close in size
"""
source_file_size = source_file.stat().st_size
target_file_size = target_file.stat().st_size
if (source_file_size - target_file_size) / source_file_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {source_file}: {source_file_size}
- {target_file}: {target_file_size}
"""
)
def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
"""
For a Dict of tensors, check if two or more tensors point to the same underlying memory and
remove them
"""
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
# Iterate over all found memory addresses
for ptr, names in ptrs.items():
if len(names) > 1:
# Multiple tensors are point to the same memory
# Only keep the first tensor
for name in names[1:]:
tensors.pop(name)
def convert_file(pt_file: Path, st_file: Path):
"""
Convert a pytorch file to a safetensors file
"""
pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state:
pt_state = pt_state["state_dict"]
remove_shared_pointers(pt_state)
# Tensors need to be contiguous
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
st_file.parent.mkdir(parents=True, exist_ok=True)
save_file(pt_state, str(st_file), metadata={"format": "pt"})
# Check that both files are close in size
check_file_size(pt_file, st_file)
# Load safetensors state
st_state = load_file(str(st_file))
for k in st_state:
pt_tensor = pt_state[k]
st_tensor = st_state[k]
if not torch.equal(pt_tensor, st_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], st_files: List[Path]):
assert len(pt_files) == len(st_files)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
for pt_file, st_file in zip(pt_files, st_files)
]
[
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]

View File

@ -0,0 +1,35 @@
import os
import torch
from datetime import timedelta
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size

View File

@ -0,0 +1,137 @@
import concurrent
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from tqdm import tqdm
from typing import Optional, List
from huggingface_hub import HfApi, _CACHED_NO_EXIST, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
model_id: str, revision: str = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
return filenames
def try_to_load_from_cache(
model_id: str, revision: str, filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return cached_file if cached_file.is_file() else None
def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
"""Get the local files"""
try:
filenames = weight_hub_files(model_id, revision, extension)
except EntryNotFoundError as e:
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them
filenames = [
f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
]
if WEIGHTS_CACHE_OVERRIDE is not None:
files = []
for filename in filenames:
p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
if not p.exists():
raise LocalEntryNotFoundError(
f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
)
files.append(p)
return files
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(model_id: str, revision: str, filenames: List[str]) -> List[Path]:
"""Download the safetensors files from the hub"""
download_function = partial(
hf_hub_download,
repo_id=model_id,
local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return [Path(p) for p in files]

View File

@ -0,0 +1,141 @@
import re
import torch
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, generate_pb2.FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, generate_pb2.FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, generate_pb2.FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)