diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py
index e9c8ea92..17f99d68 100644
--- a/server/text_generation/cli.py
+++ b/server/text_generation/cli.py
@@ -60,8 +60,24 @@ def download_weights(
     model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
+    convert: bool = False,
 ):
-    utils.download_weights(model_id, revision, extension)
+    try:
+        filenames = utils.weight_hub_files(model_id, revision, extension)
+        utils.download_weights(model_id, revision, filenames)
+    except utils.EntryNotFoundError as e:
+        if not convert or not extension == ".safetensors":
+            raise e
+        # Try to see if there are pytorch weights
+        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+        # Download pytorch weights
+        local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
+        local_st_files = [
+            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
+            for p in local_pt_files
+        ]
+        # Convert pytorch weights to safetensors
+        utils.convert_files(local_pt_files, local_st_files)
 
 
 if __name__ == "__main__":
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 7445b427..908b144c 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -41,6 +41,15 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
+    if model_id.startswith("facebook/galactica"):
+        if sharded:
+            return GalacticaSharded(model_id, revision, quantize=quantize)
+        else:
+            return Galactica(model_id, revision, quantize=quantize)
+
+    if "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
+
     config = AutoConfig.from_pretrained(model_id, revision=revision)
 
     if config.model_type == "bloom":
@@ -48,27 +57,22 @@ def get_model(
             return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
             return BLOOM(model_id, revision, quantize=quantize)
-    elif config.model_type == "gpt_neox":
+
+    if config.model_type == "gpt_neox":
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
             return GPTNeox(model_id, revision, quantize=quantize)
-    elif config.model_type == "t5":
+
+    if config.model_type == "t5":
         if sharded:
             return T5Sharded(model_id, revision, quantize=quantize)
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)
-    elif model_id.startswith("facebook/galactica"):
-        if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
-        else:
-            return Galactica(model_id, revision, quantize=quantize)
-    elif "santacoder" in model_id:
-        return SantaCoder(model_id, revision, quantize)
-    else:
-        if sharded:
-            raise ValueError("sharded is not supported for AutoModel")
-        try:
-            return CausalLM(model_id, revision, quantize=quantize)
-        except Exception:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+
+    if sharded:
+        raise ValueError("sharded is not supported for AutoModel")
+    try:
+        return CausalLM(model_id, revision, quantize=quantize)
+    except Exception:
+        return Seq2SeqLM(model_id, revision, quantize=quantize)
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 992d7b5b..08c3ac94 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM):
         )
         config.pad_token_id = 3
 
-        # Only download weights for small models
-        if self.master and model_id == "bigscience/bloom-560m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index f1dc8a30..780a94f1 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -26,7 +26,6 @@ from text_generation.utils import (
     StoppingCriteria,
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -172,14 +171,8 @@ class GalacticaSharded(Galactica):
         )
         tokenizer.pad_token_id = config.pad_token_id
 
-        # Only download weights for small models
-        if self.master and model_id == "facebook/galactica-125m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py
index 2d467f4c..0197f976 100644
--- a/server/text_generation/models/gpt_neox.py
+++ b/server/text_generation/models/gpt_neox.py
@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox):
             model_id, revision=revision, tp_parallel=True
         )
 
-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py
index fb496197..5d271c85 100644
--- a/server/text_generation/models/santacoder.py
+++ b/server/text_generation/models/santacoder.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import Optional, List, Tuple
+from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation.models import CausalLM
diff --git a/server/text_generation/models/t5.py b/server/text_generation/models/t5.py
index d7241c81..536ebda3 100644
--- a/server/text_generation/models/t5.py
+++ b/server/text_generation/models/t5.py
@@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
         )
         tokenizer.bos_token_id = config.decoder_start_token_id
 
-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForSeq2SeqLM.from_config(config)
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
deleted file mode 100644
index 3b3f08c7..00000000
--- a/server/text_generation/utils.py
+++ /dev/null
@@ -1,283 +0,0 @@
-import concurrent
-import os
-import re
-import torch
-import torch.distributed
-
-from datetime import timedelta
-
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from pathlib import Path
-from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
-from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from huggingface_hub.utils import LocalEntryNotFoundError
-from tqdm import tqdm
-from typing import List, Optional, Tuple
-from transformers import PreTrainedTokenizerBase
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopPLogitsWarper,
-    TopKLogitsWarper,
-)
-
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
-
-WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
-
-
-class Sampling:
-    def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator(device)
-        self.generator.manual_seed(seed)
-        self.seed = seed
-
-    def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
-        return next_tokens
-
-
-class Greedy:
-    def __call__(self, logits):
-        return logits.argmax()
-
-
-class NextTokenChooser:
-    def __init__(
-        self,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        top_k=None,
-        top_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
-    ):
-        warpers = LogitsProcessorList()
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        sampling = do_sample
-        if temperature is not None and temperature != 1.0:
-            temperature = float(temperature)
-            warpers.append(TemperatureLogitsWarper(temperature))
-            sampling = True
-        if top_k is not None and top_k != 0:
-            warpers.append(TopKLogitsWarper(top_k=top_k))
-            sampling = True
-        if top_p is not None and top_p < 1.0:
-            warpers.append(TopPLogitsWarper(top_p=top_p))
-            sampling = True
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
-
-        self.warpers = warpers
-        self.choice = Sampling(seed, device) if sampling else Greedy()
-
-    def __call__(self, input_ids, scores):
-        # Warp logits
-        scores = self.warpers(input_ids, scores)
-
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, -1)
-
-        # Choose tokens
-        next_id = self.choice(scores[-1])
-
-        return next_id.view(1, 1), logprobs
-
-    @classmethod
-    def from_pb(
-        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
-    ) -> "NextTokenChooser":
-        return NextTokenChooser(
-            temperature=pb.temperature,
-            repetition_penalty=pb.repetition_penalty,
-            top_k=pb.top_k,
-            top_p=pb.top_p,
-            do_sample=pb.do_sample,
-            seed=pb.seed,
-            device=device,
-        )
-
-
-class StopSequenceCriteria:
-    def __init__(self, stop_sequence: str):
-        self.regex = re.compile(f".*{stop_sequence}$")
-
-    def __call__(self, output: str) -> bool:
-        if self.regex.findall(output):
-            return True
-        return False
-
-
-class StoppingCriteria:
-    def __init__(
-        self,
-        eos_token_id: int,
-        stop_sequence_criterias: List[StopSequenceCriteria],
-        max_new_tokens=20,
-    ):
-        self.eos_token_id = eos_token_id
-        self.stop_sequence_criterias = stop_sequence_criterias
-        self.max_new_tokens = max_new_tokens
-        self.current_tokens = 0
-        self.current_output = ""
-
-    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
-        self.current_tokens += 1
-        if self.current_tokens >= self.max_new_tokens:
-            return True, FinishReason.FINISH_REASON_LENGTH
-
-        if last_token == self.eos_token_id:
-            return True, FinishReason.FINISH_REASON_EOS_TOKEN
-
-        self.current_output += last_output
-        for stop_sequence_criteria in self.stop_sequence_criterias:
-            if stop_sequence_criteria(self.current_output):
-                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
-
-        return False, None
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.StoppingCriteriaParameters,
-        tokenizer: PreTrainedTokenizerBase,
-    ) -> "StoppingCriteria":
-        stop_sequence_criterias = [
-            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
-        ]
-        return StoppingCriteria(
-            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
-        )
-
-
-def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
-    if torch.cuda.is_available():
-        from torch.distributed import ProcessGroupNCCL
-
-        # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
-        torch.cuda.set_device(device)
-        backend = "nccl"
-        options = ProcessGroupNCCL.Options()
-        options.is_high_priority_stream = True
-        options._timeout = timedelta(seconds=60)
-    else:
-        backend = "gloo"
-        options = None
-
-    # Call the init process.
-    torch.distributed.init_process_group(
-        backend=backend,
-        world_size=world_size,
-        rank=rank,
-        timeout=timedelta(seconds=60),
-        pg_options=options,
-    )
-
-    return torch.distributed.group.WORLD, rank, world_size
-
-
-def weight_hub_files(model_id, revision=None, extension=".safetensors"):
-    """Get the safetensors filenames on the hub"""
-    api = HfApi()
-    info = api.model_info(model_id, revision=revision)
-    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
-    return filenames
-
-
-def try_to_load_from_cache(model_id, revision, filename):
-    """Try to load a file from the Hugging Face cache"""
-    if revision is None:
-        revision = "main"
-
-    object_id = model_id.replace("/", "--")
-    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
-
-    if not repo_cache.is_dir():
-        # No cache for this model
-        return None
-
-    refs_dir = repo_cache / "refs"
-    snapshots_dir = repo_cache / "snapshots"
-    no_exist_dir = repo_cache / ".no_exist"
-
-    # Resolve refs (for instance to convert main to the associated commit sha)
-    if refs_dir.is_dir():
-        revision_file = refs_dir / revision
-        if revision_file.exists():
-            with revision_file.open() as f:
-                revision = f.read()
-
-    # Check if file is cached as "no_exist"
-    if (no_exist_dir / revision / filename).is_file():
-        return _CACHED_NO_EXIST
-
-    # Check if revision folder exists
-    if not snapshots_dir.exists():
-        return None
-    cached_shas = os.listdir(snapshots_dir)
-    if revision not in cached_shas:
-        # No cache for this revision and we won't try to return a random revision
-        return None
-
-    # Check if file exists in cache
-    cached_file = snapshots_dir / revision / filename
-    return str(cached_file) if cached_file.is_file() else None
-
-
-def weight_files(model_id, revision=None, extension=".safetensors"):
-    """Get the local safetensors filenames"""
-    if WEIGHTS_CACHE_OVERRIDE is not None:
-        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
-
-    filenames = weight_hub_files(model_id, revision, extension)
-    files = []
-    for filename in filenames:
-        cache_file = try_to_load_from_cache(
-            model_id, revision=revision, filename=filename
-        )
-        if cache_file is None:
-            raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_id} not found in "
-                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_id}` first."
-            )
-        files.append(cache_file)
-
-    return files
-
-
-def download_weights(model_id, revision=None, extension=".safetensors"):
-    """Download the safetensors files from the hub"""
-    if WEIGHTS_CACHE_OVERRIDE is not None:
-        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
-
-    filenames = weight_hub_files(model_id, revision, extension)
-
-    download_function = partial(
-        hf_hub_download,
-        repo_id=model_id,
-        local_files_only=False,
-    )
-
-    executor = ThreadPoolExecutor(max_workers=5)
-    futures = [
-        executor.submit(download_function, filename=filename, revision=revision)
-        for filename in filenames
-    ]
-    files = [
-        future.result()
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
-    ]
-
-    return files
diff --git a/server/text_generation/utils/__init__.py b/server/text_generation/utils/__init__.py
new file mode 100644
index 00000000..b7521b92
--- /dev/null
+++ b/server/text_generation/utils/__init__.py
@@ -0,0 +1,30 @@
+from text_generation.utils.convert import convert_file, convert_files
+from text_generation.utils.dist import initialize_torch_distributed
+from text_generation.utils.hub import (
+    weight_files,
+    weight_hub_files,
+    download_weights,
+    EntryNotFoundError,
+    LocalEntryNotFoundError,
+)
+from text_generation.utils.tokens import (
+    Greedy,
+    NextTokenChooser,
+    Sampling,
+    StoppingCriteria,
+)
+
+__all__ = [
+    "convert_file",
+    "convert_files",
+    "initialize_torch_distributed",
+    "weight_files",
+    "weight_hub_files",
+    "download_weights",
+    "EntryNotFoundError",
+    "LocalEntryNotFoundError",
+    "Greedy",
+    "NextTokenChooser",
+    "Sampling",
+    "StoppingCriteria",
+]
diff --git a/server/text_generation/utils/convert.py b/server/text_generation/utils/convert.py
new file mode 100644
index 00000000..3d429efa
--- /dev/null
+++ b/server/text_generation/utils/convert.py
@@ -0,0 +1,85 @@
+import concurrent
+import torch
+
+from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+from pathlib import Path
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from typing import Dict, List
+
+
+def check_file_size(source_file: Path, target_file: Path):
+    """
+    Check that two files are close in size
+    """
+    source_file_size = source_file.stat().st_size
+    target_file_size = target_file.stat().st_size
+
+    if (source_file_size - target_file_size) / source_file_size > 0.01:
+        raise RuntimeError(
+            f"""The file size different is more than 1%:
+         - {source_file}: {source_file_size}
+         - {target_file}: {target_file_size}
+         """
+        )
+
+
+def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
+    """
+    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
+    remove them
+    """
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+
+    # Iterate over all found memory addresses
+    for ptr, names in ptrs.items():
+        if len(names) > 1:
+            # Multiple tensors are point to the same memory
+            # Only keep the first tensor
+            for name in names[1:]:
+                tensors.pop(name)
+
+
+def convert_file(pt_file: Path, st_file: Path):
+    """
+    Convert a pytorch file to a safetensors file
+    """
+    pt_state = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in pt_state:
+        pt_state = pt_state["state_dict"]
+
+    remove_shared_pointers(pt_state)
+
+    # Tensors need to be contiguous
+    pt_state = {k: v.contiguous() for k, v in pt_state.items()}
+
+    st_file.parent.mkdir(parents=True, exist_ok=True)
+    save_file(pt_state, str(st_file), metadata={"format": "pt"})
+
+    # Check that both files are close in size
+    check_file_size(pt_file, st_file)
+
+    # Load safetensors state
+    st_state = load_file(str(st_file))
+    for k in st_state:
+        pt_tensor = pt_state[k]
+        st_tensor = st_state[k]
+        if not torch.equal(pt_tensor, st_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_files(pt_files: List[Path], st_files: List[Path]):
+    assert len(pt_files) == len(st_files)
+
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [
+        executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
+        for pt_file, st_file in zip(pt_files, st_files)
+    ]
+    [
+        future.result()
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+    ]
diff --git a/server/text_generation/utils/dist.py b/server/text_generation/utils/dist.py
new file mode 100644
index 00000000..9785493e
--- /dev/null
+++ b/server/text_generation/utils/dist.py
@@ -0,0 +1,35 @@
+import os
+import torch
+
+from datetime import timedelta
+
+
+def initialize_torch_distributed():
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+
+    if torch.cuda.is_available():
+        from torch.distributed import ProcessGroupNCCL
+
+        # Set the device id.
+        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
+        device = rank % torch.cuda.device_count()
+        torch.cuda.set_device(device)
+        backend = "nccl"
+        options = ProcessGroupNCCL.Options()
+        options.is_high_priority_stream = True
+        options._timeout = timedelta(seconds=60)
+    else:
+        backend = "gloo"
+        options = None
+
+    # Call the init process.
+    torch.distributed.init_process_group(
+        backend=backend,
+        world_size=world_size,
+        rank=rank,
+        timeout=timedelta(seconds=60),
+        pg_options=options,
+    )
+
+    return torch.distributed.group.WORLD, rank, world_size
diff --git a/server/text_generation/utils/hub.py b/server/text_generation/utils/hub.py
new file mode 100644
index 00000000..e7f1e518
--- /dev/null
+++ b/server/text_generation/utils/hub.py
@@ -0,0 +1,137 @@
+import concurrent
+import os
+
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from pathlib import Path
+from tqdm import tqdm
+from typing import Optional, List
+
+from huggingface_hub import HfApi, _CACHED_NO_EXIST, hf_hub_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError
+
+WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
+
+
+def weight_hub_files(
+    model_id: str, revision: str = None, extension: str = ".safetensors"
+) -> List[str]:
+    """Get the weights filenames on the hub"""
+    api = HfApi()
+    info = api.model_info(model_id, revision=revision)
+    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
+
+    if not filenames:
+        raise EntryNotFoundError(
+            f"No {extension} weights found for model {model_id} and revision {revision}.",
+            None,
+        )
+
+    return filenames
+
+
+def try_to_load_from_cache(
+    model_id: str, revision: str, filename: str
+) -> Optional[Path]:
+    """Try to load a file from the Hugging Face cache"""
+    if revision is None:
+        revision = "main"
+
+    object_id = model_id.replace("/", "--")
+    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
+
+    if not repo_cache.is_dir():
+        # No cache for this model
+        return None
+
+    refs_dir = repo_cache / "refs"
+    snapshots_dir = repo_cache / "snapshots"
+    no_exist_dir = repo_cache / ".no_exist"
+
+    # Resolve refs (for instance to convert main to the associated commit sha)
+    if refs_dir.is_dir():
+        revision_file = refs_dir / revision
+        if revision_file.exists():
+            with revision_file.open() as f:
+                revision = f.read()
+
+    # Check if file is cached as "no_exist"
+    if (no_exist_dir / revision / filename).is_file():
+        return _CACHED_NO_EXIST
+
+    # Check if revision folder exists
+    if not snapshots_dir.exists():
+        return None
+    cached_shas = os.listdir(snapshots_dir)
+    if revision not in cached_shas:
+        # No cache for this revision and we won't try to return a random revision
+        return None
+
+    # Check if file exists in cache
+    cached_file = snapshots_dir / revision / filename
+    return cached_file if cached_file.is_file() else None
+
+
+def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
+    """Get the local files"""
+    try:
+        filenames = weight_hub_files(model_id, revision, extension)
+    except EntryNotFoundError as e:
+        if extension != ".safetensors":
+            raise e
+        # Try to see if there are pytorch weights
+        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
+        # Change pytorch extension to safetensors extension
+        # It is possible that we have safetensors weights locally even though they are not on the
+        # hub if we converted weights locally without pushing them
+        filenames = [
+            f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
+        ]
+
+    if WEIGHTS_CACHE_OVERRIDE is not None:
+        files = []
+        for filename in filenames:
+            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
+            if not p.exists():
+                raise LocalEntryNotFoundError(
+                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
+                )
+            files.append(p)
+        return files
+
+    files = []
+    for filename in filenames:
+        cache_file = try_to_load_from_cache(
+            model_id, revision=revision, filename=filename
+        )
+        if cache_file is None:
+            raise LocalEntryNotFoundError(
+                f"File {filename} of model {model_id} not found in "
+                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+                f"Please run `text-generation-server download-weights {model_id}` first."
+            )
+        files.append(cache_file)
+
+    return files
+
+
+def download_weights(model_id: str, revision: str, filenames: List[str]) -> List[Path]:
+    """Download the safetensors files from the hub"""
+    download_function = partial(
+        hf_hub_download,
+        repo_id=model_id,
+        local_files_only=False,
+    )
+
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [
+        executor.submit(download_function, filename=filename, revision=revision)
+        for filename in filenames
+    ]
+    files = [
+        future.result()
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+    ]
+
+    return [Path(p) for p in files]
diff --git a/server/text_generation/utils/tokens.py b/server/text_generation/utils/tokens.py
new file mode 100644
index 00000000..ef71ab81
--- /dev/null
+++ b/server/text_generation/utils/tokens.py
@@ -0,0 +1,141 @@
+import re
+import torch
+
+from transformers import (
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    RepetitionPenaltyLogitsProcessor,
+    PreTrainedTokenizerBase,
+)
+from typing import List, Tuple, Optional
+
+from text_generation.pb import generate_pb2
+
+
+class Sampling:
+    def __init__(self, seed: int, device: str = "cpu"):
+        self.generator = torch.Generator(device)
+        self.generator.manual_seed(seed)
+        self.seed = seed
+
+    def __call__(self, logits):
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
+        return next_tokens
+
+
+class Greedy:
+    def __call__(self, logits):
+        return logits.argmax()
+
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        temperature=1.0,
+        repetition_penalty=1.0,
+        top_k=None,
+        top_p=None,
+        do_sample=False,
+        seed=0,
+        device="cpu",
+    ):
+        warpers = LogitsProcessorList()
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        sampling = do_sample
+        if temperature is not None and temperature != 1.0:
+            temperature = float(temperature)
+            warpers.append(TemperatureLogitsWarper(temperature))
+            sampling = True
+        if top_k is not None and top_k != 0:
+            warpers.append(TopKLogitsWarper(top_k=top_k))
+            sampling = True
+        if top_p is not None and top_p < 1.0:
+            warpers.append(TopPLogitsWarper(top_p=top_p))
+            sampling = True
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+
+        self.warpers = warpers
+        self.choice = Sampling(seed, device) if sampling else Greedy()
+
+    def __call__(self, input_ids, scores):
+        # Warp logits
+        scores = self.warpers(input_ids, scores)
+
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, -1)
+
+        # Choose tokens
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs
+
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+    ) -> "NextTokenChooser":
+        return NextTokenChooser(
+            temperature=pb.temperature,
+            repetition_penalty=pb.repetition_penalty,
+            top_k=pb.top_k,
+            top_p=pb.top_p,
+            do_sample=pb.do_sample,
+            seed=pb.seed,
+            device=device,
+        )
+
+
+class StopSequenceCriteria:
+    def __init__(self, stop_sequence: str):
+        self.regex = re.compile(f".*{stop_sequence}$")
+
+    def __call__(self, output: str) -> bool:
+        if self.regex.findall(output):
+            return True
+        return False
+
+
+class StoppingCriteria:
+    def __init__(
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
+    ):
+        self.eos_token_id = eos_token_id
+        self.stop_sequence_criterias = stop_sequence_criterias
+        self.max_new_tokens = max_new_tokens
+        self.current_tokens = 0
+        self.current_output = ""
+
+    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
+        self.current_tokens += 1
+        if self.current_tokens >= self.max_new_tokens:
+            return True, generate_pb2.FinishReason.FINISH_REASON_LENGTH
+
+        if last_token == self.eos_token_id:
+            return True, generate_pb2.FinishReason.FINISH_REASON_EOS_TOKEN
+
+        self.current_output += last_output
+        for stop_sequence_criteria in self.stop_sequence_criterias:
+            if stop_sequence_criteria(self.current_output):
+                return True, generate_pb2.FinishReason.FINISH_REASON_STOP_SEQUENCE
+
+        return False, None
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> "StoppingCriteria":
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        return StoppingCriteria(
+            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
+        )
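
Note (reviewer sketch, not part of the patch): the snippet below is a minimal end-to-end exercise of the new download-then-convert fallback, mirroring the convert=True branch that cli.download_weights gains above. It only uses functions introduced or re-exported by this diff (weight_hub_files, download_weights, convert_files, weight_files, EntryNotFoundError); the model id is just an example and assumes network access to the Hub.

# sketch: fetch .safetensors weights if published, otherwise download the
# pytorch .bin shards and convert them locally with the new utils.convert
from text_generation import utils

model_id = "bigscience/bloom-560m"  # example id only
revision = None

try:
    # fast path: .safetensors weights already exist on the hub
    filenames = utils.weight_hub_files(model_id, revision, ".safetensors")
    utils.download_weights(model_id, revision, filenames)
except utils.EntryNotFoundError:
    # fallback: download pytorch shards, then convert them shard by shard
    pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
    local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
    local_st_files = [
        p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
        for p in local_pt_files
    ]
    utils.convert_files(local_pt_files, local_st_files)

# either way, weight_files() should now resolve local .safetensors paths
print(utils.weight_files(model_id, revision, ".safetensors"))

The conversion path is guarded twice in convert_file: check_file_size rejects outputs whose size drifts more than 1% from the source shard, and a torch.equal pass over the reloaded safetensors state catches any tensor that did not round-trip.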