Joel Lamy-Poirier 2023-05-05 18:48:57 -04:00
parent 0e648a71f9
commit 87b5f03958
3 changed files with 315 additions and 169 deletions


@@ -114,7 +114,9 @@ def get_model(
        santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
        return santacoder_cls(model_id, revision, quantize=quantize)

    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=True
    )
    model_type = config.model_type

    if model_type == "bloom":
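The hunk above only reformats the config-based dispatch in get_model, but the pattern it wraps is worth seeing on its own. A minimal sketch, assuming the transformers AutoConfig API; the model-type branches and return values below are illustrative placeholders, not the repository's actual model classes:

from transformers import AutoConfig


def pick_model_family(model_id: str, revision: str = "main") -> str:
    # trust_remote_code lets custom architectures ship their own config class.
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=True
    )
    model_type = config.model_type

    if model_type == "bloom":
        return "bloom"  # a BLOOM-specific implementation would be picked here
    if model_type == "gpt_bigcode":
        return "santacoder"  # illustrative: santacoder-style models
    return "causal_lm"  # generic fallback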


@@ -18,7 +18,9 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria
from text_generation_server.utils.tokens_heterogeneous import (
    HeterogeneousNextTokenChooser,
)

tracer = trace.get_tracer(__name__)
@@ -32,7 +34,9 @@ class VectorizedCausalLMBatch(Batch):
    # Decoder values
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[
        List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]
    ]

    # All tokens
    input_ids: torch.Tensor
@@ -52,11 +56,11 @@ class VectorizedCausalLMBatch(Batch):
    # Maximum number of tokens this batch will grow to
    max_tokens: int

    kv_cache_seq_dim: int = 2

    # TODO: Get from requests (should these be lists?)
    details: bool = os.environ.get("RETURN_DETAILS") is not None
    generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
@@ -74,15 +78,22 @@ class VectorizedCausalLMBatch(Batch):
        device: torch.device,
    ) -> "VectorizedCausalLMBatch":
        inputs = [r.inputs for r in pb.requests]
        offsets = [None] * len(inputs)
        token_offsets = [None] * len(inputs)
        requests_idx_mapping = {r.id: i for i, r in enumerate(pb.requests)}

        # Parse batch
        stopping_criterias = [
            StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            for r in pb.requests
        ]
        max_new_tokens = (
            stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias
        )
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            [r.parameters for r in pb.requests], device
        )

        tokenized_inputs = tokenizer(
            inputs,
@@ -96,7 +107,7 @@ class VectorizedCausalLMBatch(Batch):
        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max().item()

        input_shape = (pb.size, max_input_length + max(max_new_tokens))

        # Allocate maximum attention_mask
        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
@@ -112,7 +123,10 @@ class VectorizedCausalLMBatch(Batch):
        max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)

        generate_stream = cls.generate_stream or any(
            stopping_criteria.stop_sequence_criterias
            for stopping_criteria in stopping_criterias
        )

        return cls(
            batch_id=pb.id,
@@ -133,7 +147,9 @@ class VectorizedCausalLMBatch(Batch):
        )

    @tracer.start_as_current_span("filter")
    def filter(
        self, requests: List[generate_pb2.Request]
    ) -> Optional["VectorizedCausalLMBatch"]:
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
@@ -143,70 +159,108 @@ class VectorizedCausalLMBatch(Batch):
        keep_indices = [self.requests_idx_mapping[r.id] for r in self.requests]

        # New values after filtering
        self.requests_idx_mapping = {r.id: i for i, r in enumerate(self.requests)}
        self.input_lengths = [self.input_lengths[i] for i in keep_indices]
        self.offsets = [self.offsets[i] for i in keep_indices]
        self.token_offsets = [self.token_offsets[i] for i in keep_indices]
        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            [r.parameters for r in self.requests], self.input_ids.device
        )
        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
        remaining_decode_tokens = [
            stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            for stopping_criteria in self.stopping_criterias
        ]

        # Select the remaining indices and remove unnecessary padding
        max_input_length = max(self.input_lengths)
        sequence_slice = slice(
            self.max_input_length - max_input_length,
            self.max_input_length + max(remaining_decode_tokens),
        )
        self.max_input_length = max_input_length
        self.max_tokens = len(self.requests) * self.max_input_length + sum(
            remaining_decode_tokens
        )

        self.input_ids = self.input_ids[keep_indices, sequence_slice]
        self.position_ids = self.position_ids[keep_indices, sequence_slice]
        self.attention_mask = self.attention_mask[keep_indices, sequence_slice]

        tensors_to_update = []
        if self.past_key_values is not None:
            if not isinstance(self.past_key_values, (list, tuple)):
                raise NotImplementedError(
                    f"Unsupported kv cache type: {type(self.past_key_values)}"
                )
            for layer_kv in self.past_key_values:
                if isinstance(layer_kv, torch.Tensor):
                    tensors_to_update.append(layer_kv)
                elif isinstance(layer_kv, (list, tuple)):
                    tensors_to_update.extend(layer_kv)
                else:
                    raise NotImplementedError(
                        f"Unsupported layer kv cache type: {type(layer_kv)}"
                    )

            kv_cache_slice = [
                keep_indices,
                *(slice(None) for _ in range(1, self.kv_cache_seq_dim)),
                sequence_slice,
            ]
            for tensor in tensors_to_update:
                # Update each tensor in-place to allow incremental garbage collection
                tensor.data = tensor[kv_cache_slice]

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(
        cls, batches: List["VectorizedCausalLMBatch"]
    ) -> "VectorizedCausalLMBatch":
        if len(batches) == 0:
            raise ValueError("Cannot concatenate empty list.")

        requests = [request for batch in batches for request in batch.requests]
        batch_sizes = [len(batch.requests) for batch in batches]
        batch_size = sum(batch_sizes)

        end_indices = torch.tensor(batch_sizes).cumsum(0).tolist()
        start_indices = [0] + end_indices[:-1]

        input_lengths = [length for batch in batches for length in batch.input_lengths]
        offsets = [offset for batch in batches for offset in batch.offsets]
        token_offsets = [
            token_offset for batch in batches for token_offset in batch.token_offsets
        ]
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            [r.parameters for r in requests], batches[0].input_ids.device
        )
        stopping_criterias = [
            stopping_criteria
            for batch in batches
            for stopping_criteria in batch.stopping_criterias
        ]

        requests_idx_mapping = {
            k: v + start_index
            for batch, start_index in zip(batches, start_indices)
            for k, v in batch.requests_idx_mapping.items()
        }

        max_input_length = max(input_lengths)
        left_indices = [max_input_length - batch.max_input_length for batch in batches]

        input_shape = (
            batch_size,
            max_input_length
            + max(
                batch.input_ids.size(1) - batch.max_input_length for batch in batches
            ),
        )
        device = batches[0].input_ids.device

        # Allocate maximum attention_mask
        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
@@ -217,56 +271,84 @@ class VectorizedCausalLMBatch(Batch):
        # TODO : only needed for prefill
        input_ids[:, :max_input_length].fill_(0)

        for batch, start_index, end_index, left_index in zip(
            batches, start_indices, end_indices, left_indices
        ):
            attention_mask[start_index:end_index, left_index:max_input_length].copy_(
                batch.attention_mask[:, : batch.max_input_length]
            )
            input_ids[start_index:end_index, left_index:max_input_length].copy_(
                batch.input_ids[:, : batch.max_input_length]
            )

        position_ids = attention_mask.cumsum(-1).sub_(1)
        position_ids[:, :max_input_length].relu_()

        max_tokens = sum(
            batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
            for batch in batches
        )

        kv_formats = None
        for batch in batches:
            if batch.past_key_values is None:
                raise ValueError("Only concatenate prefilled batches")
            if not isinstance(batch.past_key_values, (list, tuple)):
                raise NotImplementedError(
                    f"Unsupported kv cache type: {type(batch.past_key_values)}"
                )
            if kv_formats is None:
                num_layers = len(batch.past_key_values)
                if num_layers == 0:
                    raise ValueError("Empty KV cache")
                kv_formats = [0] * num_layers
            elif len(batch.past_key_values) != len(kv_formats):
                raise ValueError("Num layers is not constant")
            for i, layer_kv in enumerate(batch.past_key_values):
                if isinstance(layer_kv, (list, tuple)):
                    kv_format = len(layer_kv)
                else:
                    kv_format = None
                if kv_formats[i] == 0:
                    if kv_format == 0:
                        raise ValueError("Empty KV cache")
                    kv_formats[i] = kv_format
                elif kv_formats[i] != kv_format:
                    raise ValueError("Incompatible KV cache format.")

        kv_cache_seq_dim = batches[0].kv_cache_seq_dim
        past_key_values = []
        for i, kv_format in enumerate(kv_formats):
            for j in range(1 if kv_format is None else kv_format):
                tensors_to_merge = [batch.past_key_values[i] for batch in batches]
                # Generally `max_input_length`, unless the model allocates more than needed.
                right_indices = [
                    left_index + tensor.size(kv_cache_seq_dim)
                    for tensor, left_index in zip(tensors_to_merge, left_indices)
                ]
                combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:])
                combined_shape[kv_cache_seq_dim] = max(right_indices)
                # Set to zero to avoid propagating nans in padded values.
                kv_cache = torch.zeros(
                    combined_shape, dtype=tensors_to_merge[0].dtype, device=device
                )
                for tensor, start_index, end_index, left_index, right_index in zip(
                    tensors_to_merge,
                    start_indices,
                    end_indices,
                    left_indices,
                    right_indices,
                ):
                    kv_cache[
                        [
                            slice(start_index, end_index),
                            *(slice(None) for _ in range(1, kv_cache_seq_dim)),
                            slice(left_index, right_index),
                        ]
                    ].copy_(tensor)
                if kv_format is None:
                    past_key_values.append(kv_cache)
                elif j == 0:
                    past_key_values.append([kv_cache])
                else:
                    past_key_values[-1].append(kv_cache)
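Both filter and concatenate above index the KV cache with a list of the form [batch_indices, *(slice(None), ...), sequence_slice], so the same code works regardless of which axis kv_cache_seq_dim points at. A self-contained sketch of that pattern on a dummy tensor; the shapes, indices, and slice bounds are made up for illustration and do not come from the diff:

import torch

kv_cache_seq_dim = 2                  # e.g. [batch, heads, seq, head_dim]
cache = torch.randn(4, 8, 16, 64)     # dummy per-layer KV cache for 4 sequences
keep_indices = [0, 2]                 # sequences kept after filtering
sequence_slice = slice(6, 16)         # drop padding to the left of the kept prompts

kv_cache_slice = [
    keep_indices,
    *(slice(None) for _ in range(1, kv_cache_seq_dim)),
    sequence_slice,
]
# Converting to a tuple makes the mixed list/slice index unambiguous.
filtered = cache[tuple(kv_cache_slice)]
print(filtered.shape)  # torch.Size([2, 8, 10, 64])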
@@ -350,58 +432,75 @@ class VectorizedCausalLM(Model):
    def generate_token(
        self, batch: VectorizedCausalLMBatch
    ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
        key_length = batch.max_input_length
        if key_length > batch.input_ids.size(1):
            raise RuntimeError("Cannot generate more than `max_tokens`.")
        query_length = key_length if batch.past_key_values is None else 1
        input_ids = batch.input_ids[:, key_length - query_length : key_length]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=batch.attention_mask[:, :key_length],
            position_ids=batch.position_ids[:, key_length - query_length : key_length],
            past_key_values=batch.past_key_values,
        )

        # TODO: Post-processing
        next_token_ids, logprobs = batch.next_token_chooser(
            input_ids, outputs.logits, batch.details
        )

        if batch.generate_stream:
            # TODO: self.decode_token, offsets?
            next_token_texts = self.tokenizer.batch_decode(next_token_ids.tolist())

        if batch.details:
            token_logprobs = (
                logprobs[:, -1, :]
                .gather(1, next_token_ids.unsqueeze(1))
                .squeeze(1)
                .tolist()
            )
            if query_length > 1:
                prefill_token_ids = batch.input_ids[:, :key_length].tolist()
                prefill_logprobs = (
                    logprobs.gather(2, batch.input_ids[:, 1:key_length, None])
                    .squeeze(2)
                    .tolist()
                )
                prefill_tokens = []
                for prefill_token_ids_, prefill_logprobs_, input_length in zip(
                    prefill_token_ids, prefill_logprobs, batch.input_lengths
                ):
                    prefill_token_ids_ = prefill_token_ids_[-input_length:]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids_,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens.append(
                        PrefillTokens(
                            prefill_token_ids_,
                            [math.nan, *prefill_logprobs_],
                            prefill_texts,
                        )
                    )

        # Update batch
        # TODO: Why do we need all input ids?
        batch.input_ids[:, key_length].copy_(next_token_ids)
        batch.past_key_values = outputs.past_key_values
        batch.input_lengths = [length + 1 for length in batch.input_lengths]
        batch.max_input_length += 1

        # TODO: Vectorize some of this?
        generations: List[Generation] = []
        next_batch = None
        for i, next_token_id in enumerate(next_token_ids):
            next_token_text = next_token_texts[i] if batch.generate_stream else ""
            stopping_criterias = batch.stopping_criterias[i]
            stop, reason = stopping_criterias(
                next_token_id,
                next_token_text,
@@ -421,10 +520,9 @@ class VectorizedCausalLM(Model):
                generated_text = None
                next_batch = batch

            generation = Generation(
                batch.requests[i].id,
                prefill_tokens[i] if batch.details and query_length > 1 else None,
                next_token_id,
                token_logprobs[i] if batch.details else 0.0,
                next_token_text,
@@ -435,4 +533,3 @@ class VectorizedCausalLM(Model):
            generations.append(generation)

        return generations, next_batch
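The generate_token hunks above run one decode step against a pre-allocated input_ids buffer: the query window covers the whole prompt at prefill and a single token afterwards, and the chosen token is written back into the buffer at position key_length. A toy, model-free sketch of that bookkeeping, with random logits standing in for the model and all sizes invented for illustration:

import torch

batch_size, prompt_len, max_new_tokens, vocab = 2, 5, 3, 50
max_len = prompt_len + max_new_tokens

# Pre-allocated buffer holding both the prompt and the tokens still to come.
input_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
input_ids[:, :prompt_len] = torch.randint(1, vocab, (batch_size, prompt_len))
max_input_length = prompt_len
past_key_values = None

for _ in range(max_new_tokens):
    key_length = max_input_length
    query_length = key_length if past_key_values is None else 1
    # What would be fed to the model: the full prompt at prefill, one token after.
    step_input_ids = input_ids[:, key_length - query_length : key_length]

    # Stand-ins for the model forward pass and its KV cache.
    logits = torch.randn(batch_size, query_length, vocab)
    past_key_values = object()

    next_token_ids = logits[:, -1, :].argmax(-1)  # greedy stand-in for the chooser
    input_ids[:, key_length] = next_token_ids     # write into the buffer
    max_input_length += 1

print(input_ids)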


@@ -24,8 +24,10 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: List[float], device: torch.device):
        self.penalty = torch.tensor(
            penalty, dtype=torch.float32, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)
@@ -36,6 +38,7 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
        scores.scatter_(1, input_ids, score)
        return scores


class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
@@ -47,13 +50,16 @@ class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
        The value used to module the logits distribution.
    """

    def __init__(self, temperature: List[float], device: torch.device):
        self.temperature = torch.tensor(
            temperature, dtype=torch.float32, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.div_(self.temperature)
        return scores


class HeterogeneousTopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
@@ -70,8 +76,16 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
        Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_p: List[float],
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(
            1
        )
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep
@@ -86,10 +100,13 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores.masked_fill_(indices_to_remove, self.filter_value)
        return scores


class HeterogeneousTopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
@@ -105,10 +122,20 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
        Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_k: List[int],
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.max_top_k = max(top_k)
        self.top_k = torch.tensor(
            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
            dtype=torch.int64,
            device=device,
        ).unsqueeze(1)
        zeros = [x == 0 for x in top_k]
        if any(zeros):
            self.top_k_mask = torch.tensor(zeros, dtype=torch.bool, device=device)
        else:
@@ -116,13 +143,13 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
        self.filter_value = filter_value

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        if scores.size(-1) > self.max_top_k:  # Safety check
            max_top_k = scores.size(-1)
            top_k = torch.clamp_max(self.top_k, max_top_k)  # Run only if needed.
        else:
            max_top_k = self.max_top_k
            top_k = self.top_k
        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
        if self.top_k_mask is not None:
            kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
        # Remove all tokens with a probability less than the last token of the top-k
@@ -147,7 +174,13 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
        Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        mass: List[float],
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.filter_value = filter_value
        self.mass = torch.tensor(mass, dtype=torch.float32, device=device).unsqueeze(1)
        self.min_tokens_to_keep = min_tokens_to_keep
@@ -167,11 +200,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
        # Remove tokens with cumulative mass above the threshold
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind[last_ind < 0] = 0
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return scores
@@ -181,103 +218,113 @@ class HeterogeneousSampling:
    r"""
    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
    """

    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
        self.seeds = seeds

        self.greedy = Greedy()
        # TODO: Most seeds are ignored
        self.sampling = Sampling(seeds[0], device)
        self.do_sample = torch.tensor(do_sample, dtype=torch.bool, device=device)

    def __call__(self, logits):
        return torch.where(self.do_sample, self.sampling(logits), self.greedy(logits))


class HeterogeneousNextTokenChooser:
    def __init__(
        self,
        *,
        batch_size: int,
        device: torch.device,
        watermark: Optional[Union[bool, List[Optional[bool]]]] = None,
        temperature: Optional[Union[float, List[Optional[float]]]] = None,
        repetition_penalty: Optional[Union[float, List[Optional[float]]]] = None,
        top_k: Optional[Union[int, List[Optional[int]]]] = None,
        top_p: Optional[Union[float, List[Optional[float]]]] = None,
        typical_p: Optional[Union[float, List[Optional[float]]]] = None,
        do_sample: Optional[Union[bool, List[Optional[bool]]]] = None,
        seeds: Optional[Union[int, List[Optional[int]]]] = None,
    ):
        # TODO: Most seeds are ignored
        seeds = self._standardize(seeds, batch_size, 0)
        do_sample = self._standardize(do_sample, batch_size, False)

        warpers = LogitsProcessorList()

        watermark = self._standardize(watermark, batch_size, False)
        if any(watermark):
            raise NotImplementedError("Watermarking not implemented")

        repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0)
        if any([x != 1.0 for x in repetition_penalty]):
            warpers.append(
                HeterogeneousRepetitionPenaltyLogitsProcessor(
                    repetition_penalty, device
                )
            )

        temperature = self._standardize(temperature, batch_size, 1.0)
        if any([x != 1.0 for x in temperature]):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, device))

        top_k = self._standardize(top_k, batch_size, 0)
        n_top_k = sum([x != 0 for x in top_k])
        if n_top_k > 0:
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        top_p = self._standardize(top_p, batch_size, 1.0)
        if any([x < 1.0 for x in top_p]):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, device))

        typical_p = self._standardize(typical_p, batch_size, 1.0)
        if any([x < 1.0 for x in typical_p]):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, device))

        self.warpers = warpers

        num_do_sample = sum(do_sample)
        if num_do_sample == 0:
            self.choice = Greedy()
        elif num_do_sample < batch_size:
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            # TODO: Most seeds are ignored
            self.choice = Sampling(seeds[0], device)

    @staticmethod
    def _standardize(values, batch_size, default):
        if isinstance(values, list):
            values = values.copy()
        else:
            values = [values] * batch_size
        assert len(values) == batch_size
        for i, v in enumerate(values):
            if v is None:
                values[i] = default
        return values

    def __call__(
        self, input_ids: torch.Tensor, scores: torch.Tensor, return_logprobs: bool
    ):
        last_token_scores = self.warpers(input_ids, scores[:, -1, :])
        next_token_ids = self.choice(last_token_scores)

        if return_logprobs:
            # Compute logprobs
            if scores.size(1) == 1:
                scores = last_token_scores.unsqueeze(1)
            else:
                # TODO: Post-process all the tokens?
                scores[:, -1, :] = last_token_scores
            logprobs = torch.log_softmax(scores, dim=-1)
        else:
            logprobs = None

        return next_token_ids, logprobs
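HeterogeneousSampling above picks, row by row, between a greedy argmax and a sampled token using torch.where. A self-contained sketch of that per-request mixing on dummy logits; the argmax and multinomial calls below are plain PyTorch stand-ins, not the repository's Greedy and Sampling helpers, and the shapes are invented:

import torch

torch.manual_seed(0)
logits = torch.randn(4, 10)                      # 4 requests, vocabulary of 10
do_sample = torch.tensor([False, True, False, True])

greedy_ids = logits.argmax(-1)
probs = torch.softmax(logits, dim=-1)
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(1)

# Each row takes the sampled id only if its request asked for sampling.
next_ids = torch.where(do_sample, sampled_ids, greedy_ids)
print(next_ids)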