text-generation-inference/server/text_generation/utils.py

277 lines
8.7 KiB
Python
Raw Normal View History

import concurrent
2022-10-08 10:30:12 +00:00
import os
2022-12-16 15:03:39 +00:00
import re
2022-10-08 10:30:12 +00:00
import torch
import torch.distributed
2022-10-17 12:59:00 +00:00
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
2023-01-31 17:53:56 +00:00
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
2022-12-12 17:25:22 +00:00
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
2022-12-01 18:31:54 +00:00
from transformers.generation.logits_process import (
2022-10-08 10:30:12 +00:00
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
2022-10-08 10:30:12 +00:00
TemperatureLogitsWarper,
TopPLogitsWarper,
TopKLogitsWarper,
)
2022-12-12 17:25:22 +00:00
from text_generation.pb import generate_pb2
# Optional directory, read from the environment once at import time. When it
# is set, `weight_files` and `download_weights` return the files matching
# `*{extension}` in that directory instead of consulting the Hugging Face hub.
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
2022-10-08 10:30:12 +00:00
class Sampling:
    """Choose next tokens by random sampling from the softmax of the logits.

    Each instance owns a private ``torch.Generator`` seeded at construction,
    so two instances built with the same seed and device draw the same tokens
    for the same sequence of calls.
    """

    def __init__(self, seed: int, device: str = "cpu"):
        # `manual_seed` returns the generator itself, which allows chaining.
        self.generator = torch.Generator(device).manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        """Return one sampled token index per row of `logits`.

        The last dimension of `logits` is normalised with a softmax and a
        single index is drawn per row; the sample dimension (dim 1) is then
        squeezed away.
        """
        distribution = torch.softmax(logits, dim=-1)
        drawn = torch.multinomial(
            distribution, num_samples=1, generator=self.generator
        )
        return drawn.squeeze(1)
class Greedy:
    """Deterministic token chooser: always pick the highest-scoring index."""

    def __call__(self, logits):
        """Return the argmax of `logits` along its last dimension."""
        return torch.argmax(logits, dim=-1)
class NextTokenChooser:
    """Transform raw scores with a chain of logits processors and pick tokens.

    The processors are stored in `self.warpers` and the picking strategy
    (`Sampling` or `Greedy`) in `self.choice`.
    """

    def __init__(
        self,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        """Build the processor chain and select the picking strategy.

        A processor is only added when its parameter differs from its neutral
        value (temperature 1.0, top_k None/0, top_p None or >= 1.0,
        repetition_penalty 1.0). Enabling temperature, top_k or top_p switches
        to sampling even when `do_sample` is false; the repetition penalty
        alone does not.
        """
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        processors = LogitsProcessorList()
        must_sample = do_sample

        if temperature is not None and temperature != 1.0:
            processors.append(TemperatureLogitsWarper(float(temperature)))
            must_sample = True

        if top_k is not None and top_k != 0:
            processors.append(TopKLogitsWarper(top_k=top_k))
            must_sample = True

        if top_p is not None and top_p < 1.0:
            processors.append(TopPLogitsWarper(top_p=top_p))
            must_sample = True

        # The repetition penalty is appended last and never forces sampling.
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(
                RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            )

        self.warpers = processors
        if must_sample:
            self.choice = Sampling(seed, device)
        else:
            self.choice = Greedy()

    def __call__(self, input_ids, scores):
        """Return `(next_ids, logprobs)` for the given scores.

        `logprobs` is the log-softmax of the processed scores, and the tokens
        are chosen from those same processed scores.
        """
        # Warp logits
        warped = self.warpers(input_ids, scores)
        # Compute logprobs
        logprobs = torch.log_softmax(warped, -1)
        # Choose tokens
        next_ids = self.choice(warped)
        return next_ids, logprobs

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
    ) -> "NextTokenChooser":
        """Create a chooser from the fields of a protobuf parameters message."""
        params = dict(
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
        )
        return NextTokenChooser(device=device, **params)
class StopSequenceCriteria:
    """Predicate telling whether a generated text ends with a stop sequence.

    The stop sequence is treated as literal text. It is escaped before being
    embedded in the pattern, so characters such as ``?``, ``.`` or ``(`` match
    themselves instead of acting as regex operators (or raising ``re.error``
    at construction time).
    """

    def __init__(self, stop_sequence: str):
        # Keep the compiled pattern on `self.regex`, as before, for any code
        # that inspects it.
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        """Return True when `output` ends with the stop sequence.

        Note: ``$`` also matches just before a single trailing newline, so an
        output ending in the stop sequence followed by one newline matches.
        """
        # `search` gives the same yes/no answer as `findall` without building
        # a list of matches.
        return self.regex.search(output) is not None
2022-10-08 10:30:12 +00:00
class StoppingCriteria:
    """Stateful check deciding when generation should stop.

    The instance counts the tokens it has been called with and accumulates
    the decoded text, so one instance is meant to follow one generation.
    """

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        # Running state, updated on every call.
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Register one generated token and report whether to stop.

        Returns `(stop, reason)` where `reason` is "length", "eos_token",
        "stop_sequence", or None when generation should continue. The checks
        run in that order, so the length limit wins over an EOS token, and
        `last_output` is only appended to the accumulated text when neither
        of the first two checks fired.
        """
        self.current_tokens += 1

        if self.current_tokens >= self.max_new_tokens:
            return True, "length"

        if last_token == self.eos_token_id:
            return True, "eos_token"

        self.current_output += last_output
        text_so_far = self.current_output
        if any(criteria(text_so_far) for criteria in self.stop_sequence_criterias):
            return True, "stop_sequence"

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Create the criteria from a protobuf message and a tokenizer.

        One `StopSequenceCriteria` is built per entry of `pb.stop_sequences`;
        the EOS id comes from `tokenizer.eos_token_id`.
        """
        sequence_checks = []
        for sequence in pb.stop_sequences:
            sequence_checks.append(StopSequenceCriteria(sequence))
        return StoppingCriteria(
            tokenizer.eos_token_id, sequence_checks, pb.max_new_tokens
        )
2022-10-08 10:30:12 +00:00
def initialize_torch_distributed():
    """Initialise `torch.distributed` from the RANK / WORLD_SIZE environment.

    RANK defaults to 0 and WORLD_SIZE to 1 when unset. With CUDA available the
    "nccl" backend is used and the current CUDA device is set to
    `rank % device_count`; otherwise the "gloo" backend is used.

    Returns:
        A `(process_group, rank, world_size)` tuple, where the process group
        is the default group obtained through a private torch helper.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if not torch.cuda.is_available():
        backend = "gloo"
    else:
        # Each process drives one GPU, so there cannot be more processes
        # than visible devices.
        gpu_count = torch.cuda.device_count()
        assert world_size <= gpu_count, "Each process is one gpu"
        torch.cuda.set_device(rank % gpu_count)
        backend = "nccl"

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
    )

    default_group = torch.distributed.distributed_c10d._get_default_group()
    return default_group, rank, world_size
2023-01-31 17:53:56 +00:00
def weight_hub_files(model_name, revision=None, extension=".safetensors"):
    """List the hub files of a model whose name ends with `extension`.

    The model information is queried through `HfApi.model_info` for the given
    revision, and the `rfilename` of every matching sibling is returned.
    """
    model_info = HfApi().model_info(model_name, revision=revision)

    matching = []
    for sibling in model_info.siblings:
        name = sibling.rfilename
        if name.endswith(extension):
            matching.append(name)
    return matching
2023-01-31 17:53:56 +00:00
def try_to_load_from_cache(model_name, revision, filename):
    """Look up `filename` of a model in the local Hugging Face cache.

    Returns:
        - the path of the cached file, as a string, when it is present;
        - `_CACHED_NO_EXIST` when the cache records the file as non-existent
          for the resolved revision;
        - None when the model, the revision or the file is not cached.
    """
    if revision is None:
        revision = "main"

    folder_name = "models--" + model_name.replace("/", "--")
    repo_dir = Path(HUGGINGFACE_HUB_CACHE) / folder_name
    if not repo_dir.is_dir():
        # No cache for this model
        return None

    # Resolve refs (for instance to convert main to the associated commit sha)
    refs_root = repo_dir / "refs"
    if refs_root.is_dir():
        ref_path = refs_root / revision
        if ref_path.exists():
            revision = ref_path.read_text()

    # Check if file is cached as "no_exist"
    missing_marker = repo_dir / ".no_exist" / revision / filename
    if missing_marker.is_file():
        return _CACHED_NO_EXIST

    # Check if revision folder exists
    snapshots_root = repo_dir / "snapshots"
    if not snapshots_root.exists():
        return None
    if revision not in os.listdir(snapshots_root):
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    candidate = snapshots_root / revision / filename
    if candidate.is_file():
        return str(candidate)
    return None
def weight_files(model_name, revision=None, extension=".safetensors"):
    """Get the local safetensors filenames.

    When `WEIGHTS_CACHE_OVERRIDE` is set, the files matching `*{extension}`
    in that directory are returned (as `Path` objects) and the hub is not
    consulted. Otherwise the weight files listed on the hub are resolved one
    by one in the local cache and their paths are returned as strings.

    Raises:
        LocalEntryNotFoundError: if any weight file is not available in the
            local cache, including when the cache records it as non-existent.
    """
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    files = []
    for filename in weight_hub_files(model_name, revision, extension):
        cache_file = try_to_load_from_cache(
            model_name, revision=revision, filename=filename
        )
        # The cache lookup has two "not available" answers: None and the
        # `_CACHED_NO_EXIST` sentinel. Neither is a usable path, so neither
        # may end up in the returned list.
        if cache_file is None or cache_file is _CACHED_NO_EXIST:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files
2023-01-31 17:53:56 +00:00
def download_weights(model_name, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub.

    When `WEIGHTS_CACHE_OVERRIDE` is set, nothing is downloaded and the files
    matching `*{extension}` in that directory are returned instead.
    Otherwise every weight file listed on the hub is fetched with
    `hf_hub_download` using up to 5 worker threads, with a progress bar.

    Returns:
        The values returned by `hf_hub_download`, in completion order (not in
        hub listing order).
    """
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_name, revision, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
    )

    # The context manager shuts the pool down on exit, so worker threads are
    # released even when one of the downloads raises.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(download_function, filename=filename, revision=revision)
            for filename in filenames
        ]
        files = [
            future.result()
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(futures)
            )
        ]

    return files