Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit 87b5f03958 ("format")
Parent: 0e648a71f9
@@ -114,7 +114,9 @@ def get_model(
         santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
         return santacoder_cls(model_id, revision, quantize=quantize)

-    config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=True)
+    config = AutoConfig.from_pretrained(
+        model_id, revision=revision, trust_remote_code=True
+    )
     model_type = config.model_type

     if model_type == "bloom":
@@ -18,7 +18,9 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria
-from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser
+from text_generation_server.utils.tokens_heterogeneous import (
+    HeterogeneousNextTokenChooser,
+)

 tracer = trace.get_tracer(__name__)

@@ -32,7 +34,9 @@ class VectorizedCausalLMBatch(Batch):
     # Decoder values
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
-    past_key_values: Optional[List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]]
+    past_key_values: Optional[
+        List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]
+    ]

     # All tokens
     input_ids: torch.Tensor
@@ -52,11 +56,11 @@ class VectorizedCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int

-    kv_cache_seq_dim:int=2
+    kv_cache_seq_dim: int = 2

     # TODO: Get from requests (should these be lists?)
-    details:bool=os.environ.get("RETURN_DETAILS") is not None
-    generate_stream:bool=os.environ.get("GENERATE_STREAM") is not None
+    details: bool = os.environ.get("RETURN_DETAILS") is not None
+    generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None

     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
@@ -74,15 +78,22 @@ class VectorizedCausalLMBatch(Batch):
         device: torch.device,
     ) -> "VectorizedCausalLMBatch":
         inputs = [r.inputs for r in pb.requests]
-        offsets = [None]*len(inputs)
-        token_offsets = [None]*len(inputs)
-        requests_idx_mapping = {r.id:i for i, r in enumerate(pb.requests)}
+        offsets = [None] * len(inputs)
+        token_offsets = [None] * len(inputs)
+        requests_idx_mapping = {r.id: i for i, r in enumerate(pb.requests)}

         # Parse batch
-        stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
-        max_new_tokens=(stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias)
+        stopping_criterias = [
+            StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            for r in pb.requests
+        ]
+        max_new_tokens = (
+            stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias
+        )

-        next_token_chooser= HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            [r.parameters for r in pb.requests], device
+        )

         tokenized_inputs = tokenizer(
             inputs,
@@ -96,7 +107,7 @@ class VectorizedCausalLMBatch(Batch):
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max().item()

-        input_shape=(pb.size, max_input_length + max(max_new_tokens))
+        input_shape = (pb.size, max_input_length + max(max_new_tokens))

         # Allocate maximum attention_mask
         attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
@@ -112,7 +123,10 @@ class VectorizedCausalLMBatch(Batch):

         max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)

-        generate_stream=cls.generate_stream or any(stopping_criteria.stop_sequence_criterias for stopping_criteria in stopping_criterias)
+        generate_stream = cls.generate_stream or any(
+            stopping_criteria.stop_sequence_criterias
+            for stopping_criteria in stopping_criterias
+        )

         return cls(
             batch_id=pb.id,
@@ -133,7 +147,9 @@ class VectorizedCausalLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]:
+    def filter(
+        self, requests: List[generate_pb2.Request]
+    ) -> Optional["VectorizedCausalLMBatch"]:
         if len(requests) == 0:
             raise ValueError("Batch must have at least one request")
         if len(requests) == len(self):
@@ -143,70 +159,108 @@ class VectorizedCausalLMBatch(Batch):
         keep_indices = [self.requests_idx_mapping[r.id] for r in self.requests]

         # New values after filtering
-        self.requests_idx_mapping={r.id:i for i, r in enumerate(self.requests)}
-        self.input_lengths=[self.input_lengths[i] for i in keep_indices]
+        self.requests_idx_mapping = {r.id: i for i, r in enumerate(self.requests)}
+        self.input_lengths = [self.input_lengths[i] for i in keep_indices]
         self.offsets = [self.offsets[i] for i in keep_indices]
         self.token_offsets = [self.token_offsets[i] for i in keep_indices]

-        self.next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in self.requests], self.input_ids.device)
+        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            [r.parameters for r in self.requests], self.input_ids.device
+        )

         self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
-        remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
+        remaining_decode_tokens = [
+            stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            for stopping_criteria in self.stopping_criterias
+        ]

         # Select the remaining indices and remove unnecessary padding
-        max_input_length=max(self.input_lengths)
-        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+max(remaining_decode_tokens))
-        self.max_input_length=max_input_length
-        self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
+        max_input_length = max(self.input_lengths)
+        sequence_slice = slice(
+            self.max_input_length - max_input_length,
+            self.max_input_length + max(remaining_decode_tokens),
+        )
+        self.max_input_length = max_input_length
+        self.max_tokens = len(self.requests) * self.max_input_length + sum(
+            remaining_decode_tokens
+        )

-        self.input_ids = self.input_ids[keep_indices,sequence_slice]
-        self.position_ids = self.position_ids[keep_indices,sequence_slice]
-        self.attention_mask = self.attention_mask[keep_indices,sequence_slice]
+        self.input_ids = self.input_ids[keep_indices, sequence_slice]
+        self.position_ids = self.position_ids[keep_indices, sequence_slice]
+        self.attention_mask = self.attention_mask[keep_indices, sequence_slice]

         tensors_to_update = []
         if self.past_key_values is not None:
-            if not isinstance(self.past_key_values,(list, tuple)):
-                raise NotImplementedError(f"Unsupported kv cache type: {type(self.past_key_values)}")
+            if not isinstance(self.past_key_values, (list, tuple)):
+                raise NotImplementedError(
+                    f"Unsupported kv cache type: {type(self.past_key_values)}"
+                )
             for layer_kv in self.past_key_values:
                 if isinstance(layer_kv, torch.Tensor):
                     tensors_to_update.append(layer_kv)
-                elif isinstance(layer_kv,(list, tuple)):
+                elif isinstance(layer_kv, (list, tuple)):
                     tensors_to_update.extend(layer_kv)
                 else:
-                    raise NotImplementedError(f"Unsupported layer kv cache type: {type(layer_kv)}")
+                    raise NotImplementedError(
+                        f"Unsupported layer kv cache type: {type(layer_kv)}"
+                    )

-        kv_cache_slice=[keep_indices, *(slice(None) for _ in range(1, self.kv_cache_seq_dim)), sequence_slice]
+        kv_cache_slice = [
+            keep_indices,
+            *(slice(None) for _ in range(1, self.kv_cache_seq_dim)),
+            sequence_slice,
+        ]
         for tensor in tensors_to_update:
             # Update tensors in-place to allow incremental garbage collection
-            tensors_to_update.data=tensor[kv_cache_slice]
+            tensors_to_update.data = tensor[kv_cache_slice]

         return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
-        if len(batches)==0:
+    def concatenate(
+        cls, batches: List["VectorizedCausalLMBatch"]
+    ) -> "VectorizedCausalLMBatch":
+        if len(batches) == 0:
             raise ValueError("Cannot concatenate empty list.")
-        requests=[request for batch in batches for request in batch.requests]
-        batch_sizes=[len(batch.requests) for batch in batches]
-        batch_size=sum(batch_sizes)
+        requests = [request for batch in batches for request in batch.requests]
+        batch_sizes = [len(batch.requests) for batch in batches]
+        batch_size = sum(batch_sizes)

-        end_indices=torch.tensor(batch_sizes).cumsum(0).tolist()
-        start_indices=[0]+end_indices[:-1]
+        end_indices = torch.tensor(batch_sizes).cumsum(0).tolist()
+        start_indices = [0] + end_indices[:-1]

         input_lengths = [length for batch in batches for length in batch.input_lengths]
         offsets = [offset for batch in batches for offset in batch.offsets]
-        token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
-        next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in requests], batches[0].input_ids.device)
-        stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
+        token_offsets = [
+            token_offset for batch in batches for token_offset in batch.token_offsets
+        ]
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            [r.parameters for r in requests], batches[0].input_ids.device
+        )
+        stopping_criterias = [
+            stopping_criteria
+            for batch in batches
+            for stopping_criteria in batch.stopping_criterias
+        ]

-        requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
+        requests_idx_mapping = {
+            k: v + start_index
+            for batch, start_index in zip(batches, start_indices)
+            for k, v in batch.requests_idx_mapping.items()
+        }

-        max_input_length=max(input_lengths)
-        left_indices=[max_input_length-batch.max_input_length for batch in batches]
+        max_input_length = max(input_lengths)
+        left_indices = [max_input_length - batch.max_input_length for batch in batches]

-        input_shape=(batch_size, max_input_length + max(batch.input_ids.size(1)-batch.max_input_length for batch in batches))
-        device=batches[0].input_ids.device
+        input_shape = (
+            batch_size,
+            max_input_length
+            + max(
+                batch.input_ids.size(1) - batch.max_input_length for batch in batches
+            ),
+        )
+        device = batches[0].input_ids.device

         # Allocate maximum attention_mask
         attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
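For context, and not part of the commit shown above: a minimal sketch of the slice-and-index pattern that filter() relies on, with made-up tensor sizes, keep_indices, and remaining token counts.

import torch

# Toy stand-ins for the batch state; every size here is invented for the example.
attention_mask = torch.ones(4, 12, dtype=torch.bool)  # 4 requests, 12 allocated positions
input_lengths = [3, 7, 5, 6]                          # per-request prompt lengths
max_input_length = 7                                  # longest prompt before filtering

keep_indices = [0, 2]                                 # requests that survive the filter
new_max_input_length = max(input_lengths[i] for i in keep_indices)  # 5
remaining_decode_tokens = [4, 2]                      # tokens the kept requests may still generate

# Same pattern as VectorizedCausalLMBatch.filter: drop rows, then trim the sequence
# dimension to the window the remaining requests can still use.
sequence_slice = slice(
    max_input_length - new_max_input_length,
    max_input_length + max(remaining_decode_tokens),
)
filtered = attention_mask[keep_indices, sequence_slice]
print(filtered.shape)  # torch.Size([2, 9])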
@@ -217,56 +271,84 @@ class VectorizedCausalLMBatch(Batch):
         # TODO : only needed for prefill
         input_ids[:, :max_input_length].fill_(0)

-        for batch,start_index, end_index, left_index in zip(batches, start_indices, end_indices, left_indices):
-            attention_mask[start_index:end_index, left_index:max_input_length].copy_(batch.attention_mask[:, :batch.max_input_length])
-            input_ids[start_index:end_index, left_index:max_input_length].copy_(batch.input_ids[:, :batch.max_input_length])
+        for batch, start_index, end_index, left_index in zip(
+            batches, start_indices, end_indices, left_indices
+        ):
+            attention_mask[start_index:end_index, left_index:max_input_length].copy_(
+                batch.attention_mask[:, : batch.max_input_length]
+            )
+            input_ids[start_index:end_index, left_index:max_input_length].copy_(
+                batch.input_ids[:, : batch.max_input_length]
+            )

         position_ids = attention_mask.cumsum(-1).sub_(1)
         position_ids[:, :max_input_length].relu_()

-        max_tokens = sum(batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches)
+        max_tokens = sum(
+            batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
+            for batch in batches
+        )

-        kv_formats=None
+        kv_formats = None
         for batch in batches:
             if batch.past_key_values is None:
                 raise ValueError("Only concatenate prefilled batches")
             if not isinstance(batch.past_key_values, (list, tuple)):
-                raise NotImplementedError(f"Unsupported kv cache type: {type(batch.past_key_values)}")
+                raise NotImplementedError(
+                    f"Unsupported kv cache type: {type(batch.past_key_values)}"
+                )
             if kv_formats is None:
-                num_layers=len(batch.past_key_values)
-                if num_layers==0:
+                num_layers = len(batch.past_key_values)
+                if num_layers == 0:
                     raise ValueError("Empty KV cache")
-                kv_formats = [0]*num_layers
-            elif len(batch.past_key_values)!=len(kv_formats):
+                kv_formats = [0] * num_layers
+            elif len(batch.past_key_values) != len(kv_formats):
                 raise ValueError("Num layers is not constant")
             for i, layer_kv in enumerate(batch.past_key_values):
                 if isinstance(layer_kv, (list, tuple)):
                     kv_format = len(layer_kv)
                 else:
-                    kv_format=None
-                if kv_formats[i]==0:
-                    if kv_format==0:
+                    kv_format = None
+                if kv_formats[i] == 0:
+                    if kv_format == 0:
                         raise ValueError("Empty KV cache")
-                    kv_formats[i]=kv_format
-                elif kv_formats[i]!=kv_format:
+                    kv_formats[i] = kv_format
+                elif kv_formats[i] != kv_format:
                     raise ValueError("Incompatible KV cache format.")

-        kv_cache_seq_dim=batches[0].kv_cache_seq_dim
-        past_key_values=[]
+        kv_cache_seq_dim = batches[0].kv_cache_seq_dim
+        past_key_values = []
         for i, kv_format in enumerate(kv_formats):
             for j in range(1 if kv_format is None else kv_format):
-                tensors_to_merge=[batch.past_key_values[i] for batch in batches]
+                tensors_to_merge = [batch.past_key_values[i] for batch in batches]
                 # Generally `max_input_length`, unless the model allocates more than needed.
-                right_indices=[left_index+tensor.size(kv_cache_seq_dim) for tensor, left_index in zip(tensors_to_merge, left_indices)]
-                combined_shape=[batch_size]+list(tensors_to_merge[0].shape[1:])
-                combined_shape[kv_cache_seq_dim]=max(right_indices)
+                right_indices = [
+                    left_index + tensor.size(kv_cache_seq_dim)
+                    for tensor, left_index in zip(tensors_to_merge, left_indices)
+                ]
+                combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:])
+                combined_shape[kv_cache_seq_dim] = max(right_indices)
                 # Set to zero to avoid propagating nans in padded values.
-                kv_cache = torch.zeros(combined_shape, dtype=tensors_to_merge[0].dtype, device=device)
-                for tensor, start_index, end_index, left_index, right_index in zip(tensors_to_merge, start_indices, end_indices, left_indices, right_indices):
-                    kv_cache[[slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), slice(left_index,right_index)]].copy_(tensor)
+                kv_cache = torch.zeros(
+                    combined_shape, dtype=tensors_to_merge[0].dtype, device=device
+                )
+                for tensor, start_index, end_index, left_index, right_index in zip(
+                    tensors_to_merge,
+                    start_indices,
+                    end_indices,
+                    left_indices,
+                    right_indices,
+                ):
+                    kv_cache[
+                        [
+                            slice(start_index, end_index),
+                            *(slice(None) for _ in range(1, kv_cache_seq_dim)),
+                            slice(left_index, right_index),
+                        ]
+                    ].copy_(tensor)
                 if kv_format is None:
                     past_key_values.append(kv_cache)
-                elif j==0:
+                elif j == 0:
                     past_key_values.append([kv_cache])
                 else:
                     past_key_values[-1].append(kv_cache)
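For context, and not part of the commit shown above: a small sketch of the zero-initialise-then-copy pattern that concatenate() uses to merge KV caches, padding shorter caches on the left so the newest positions stay aligned. Shapes and batch sizes are invented.

import torch

kv_cache_seq_dim = 2                    # assumed layout: [batch, heads, seq, head_dim]
kv_a = torch.randn(2, 4, 5, 8)          # batch of 2 with 5 cached positions
kv_b = torch.randn(3, 4, 7, 8)          # batch of 3 with 7 cached positions

start_indices = [0, 2]
end_indices = [2, 5]
max_seq = max(kv_a.size(kv_cache_seq_dim), kv_b.size(kv_cache_seq_dim))
left_indices = [max_seq - kv_a.size(kv_cache_seq_dim), max_seq - kv_b.size(kv_cache_seq_dim)]

# Zeros avoid propagating NaNs through the padded positions.
merged = torch.zeros(5, 4, max_seq, 8, dtype=kv_a.dtype)
for tensor, start, end, left in zip([kv_a, kv_b], start_indices, end_indices, left_indices):
    right = left + tensor.size(kv_cache_seq_dim)
    merged[start:end, :, left:right].copy_(tensor)

print(merged.shape)  # torch.Size([5, 4, 7, 8])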
@@ -350,58 +432,75 @@ class VectorizedCausalLM(Model):
     def generate_token(
         self, batch: VectorizedCausalLMBatch
     ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
-        key_length=batch.max_input_length
-        if key_length>batch.input_ids.size(1):
+        key_length = batch.max_input_length
+        if key_length > batch.input_ids.size(1):
             raise RuntimeError("Cannot generate more than `max_tokens`.")

-        query_length=key_length if batch.past_key_values is None else 1
-        input_ids=batch.input_ids[:, key_length-query_length: key_length]
+        query_length = key_length if batch.past_key_values is None else 1
+        input_ids = batch.input_ids[:, key_length - query_length : key_length]

         outputs = self.model.forward(
             input_ids=input_ids,
-            attention_mask=batch.attention_mask[:, : key_length],
-            position_ids=batch.position_ids[:, key_length-query_length: key_length],
+            attention_mask=batch.attention_mask[:, :key_length],
+            position_ids=batch.position_ids[:, key_length - query_length : key_length],
             past_key_values=batch.past_key_values,
         )
         # TODO: Post-processing
-        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
+        next_token_ids, logprobs = batch.next_token_chooser(
+            input_ids, outputs.logits, batch.details
+        )

         if batch.generate_stream:
             # TODO: self.decode_token, offsets?
-            next_token_texts=self.tokenizer.batch_decode(next_token_ids.tolist())
+            next_token_texts = self.tokenizer.batch_decode(next_token_ids.tolist())

         if batch.details:
-            token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
-            if query_length>1:
-                prefill_token_ids=batch.input_ids[:, :key_length].tolist()
-                prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
-                prefill_tokens=[]
-                for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
-                    prefill_token_ids_=prefill_token_ids_[-input_length:]
+            token_logprobs = (
+                logprobs[:, -1, :]
+                .gather(1, next_token_ids.unsqueeze(1))
+                .squeeze(1)
+                .tolist()
+            )
+            if query_length > 1:
+                prefill_token_ids = batch.input_ids[:, :key_length].tolist()
+                prefill_logprobs = (
+                    logprobs.gather(2, batch.input_ids[:, 1:key_length, None])
+                    .squeeze(2)
+                    .tolist()
+                )
+                prefill_tokens = []
+                for prefill_token_ids_, prefill_logprobs_, input_length in zip(
+                    prefill_token_ids, prefill_logprobs, batch.input_lengths
+                ):
+                    prefill_token_ids_ = prefill_token_ids_[-input_length:]
                     prefill_texts = self.tokenizer.batch_decode(
                         prefill_token_ids_,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
-                    prefill_tokens.append(PrefillTokens(
-                        prefill_token_ids_, [math.nan, *prefill_logprobs_], prefill_texts
-                    ))
+                    prefill_tokens.append(
+                        PrefillTokens(
+                            prefill_token_ids_,
+                            [math.nan, *prefill_logprobs_],
+                            prefill_texts,
+                        )
+                    )

         # Update batch
         # TODO: Why do we need all input ids?
         batch.input_ids[:, key_length].copy_(next_token_ids)
-        batch.past_key_values=outputs.past_key_values
-        batch.input_lengths=[length+1 for length in batch.input_lengths]
-        batch.max_input_length+=1
+        batch.past_key_values = outputs.past_key_values
+        batch.input_lengths = [length + 1 for length in batch.input_lengths]
+        batch.max_input_length += 1

         # TODO: Vectorize some of this?

         generations: List[Generation] = []
-        next_batch=None
+        next_batch = None

         for i, next_token_id in enumerate(next_token_ids):
-            next_token_text=next_token_texts[i] if batch.generate_stream else ""
-            stopping_criterias=batch.stopping_criterias[i]
+            next_token_text = next_token_texts[i] if batch.generate_stream else ""
+            stopping_criterias = batch.stopping_criterias[i]
             stop, reason = stopping_criterias(
                 next_token_id,
                 next_token_text,
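For context, and not part of the commit shown above: a minimal sketch of the prefill-logprob gather used in generate_token(), with made-up shapes and vocabulary size.

import torch

batch, key_length, vocab = 2, 4, 11
input_ids = torch.randint(vocab, (batch, key_length))
logits = torch.randn(batch, key_length, vocab)
logprobs = torch.log_softmax(logits, dim=-1)

# Position t predicts token t + 1, so the logprobs for tokens 1..key_length-1 come from
# positions 0..key_length-2; the first prompt token gets no logprob (hence math.nan upstream).
prefill_logprobs = (
    logprobs.gather(2, input_ids[:, 1:key_length, None]).squeeze(2).tolist()
)
print(len(prefill_logprobs[0]))  # key_length - 1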
@@ -421,10 +520,9 @@ class VectorizedCausalLM(Model):
                 generated_text = None
                 next_batch = batch

-
             generation = Generation(
                 batch.requests[i].id,
-                prefill_tokens[i] if batch.details and query_length>1 else None,
+                prefill_tokens[i] if batch.details and query_length > 1 else None,
                 next_token_id,
                 token_logprobs[i] if batch.details else 0.0,
                 next_token_text,
@@ -435,4 +533,3 @@ class VectorizedCausalLM(Model):
             generations.append(generation)

         return generations, next_batch
-
@@ -24,8 +24,10 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     """

-    def __init__(self, penalty: List[float], device:torch.device):
-        self.penalty = torch.tensor(penalty, dtype=torch.float32, device=device).unsqueeze(1)
+    def __init__(self, penalty: List[float], device: torch.device):
+        self.penalty = torch.tensor(
+            penalty, dtype=torch.float32, device=device
+        ).unsqueeze(1)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         score = torch.gather(scores, 1, input_ids)
@@ -36,6 +38,7 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
         scores.scatter_(1, input_ids, score)
         return scores

+
 class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
     r"""
     [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
@@ -47,13 +50,16 @@ class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
            The value used to module the logits distribution.
     """

-    def __init__(self, temperature: List[float], device:torch.device):
-        self.temperature = torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
+    def __init__(self, temperature: List[float], device: torch.device):
+        self.temperature = torch.tensor(
+            temperature, dtype=torch.float32, device=device
+        ).unsqueeze(1)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         scores.div_(self.temperature)
         return scores

+
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
@@ -70,8 +76,16 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
            Minimum number of tokens that cannot be filtered.
     """

-    def __init__(self, top_p: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
-        self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(1)
+    def __init__(
+        self,
+        top_p: List[float],
+        device: torch.device,
+        filter_value: float = -math.inf,
+        min_tokens_to_keep: int = 1,
+    ):
+        self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(
+            1
+        )
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep

@@ -86,10 +100,13 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
             sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

         # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
         scores.masked_fill_(indices_to_remove, self.filter_value)
         return scores

+
 class HeterogeneousTopKLogitsWarper(LogitsWarper):
     r"""
     [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
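For context, and not part of the commit shown above: a standalone sketch of per-request ("heterogeneous") top-p filtering in the same spirit as HeterogeneousTopPLogitsWarper. This is a simplified illustration, not the class's exact implementation.

import math
from typing import List

import torch


def heterogeneous_top_p(
    scores: torch.Tensor, top_p: List[float], filter_value: float = -math.inf
) -> torch.Tensor:
    # One nucleus threshold per row of the batch.
    p = torch.tensor(top_p, dtype=torch.float32).unsqueeze(1)
    sorted_scores, sorted_indices = torch.sort(scores, descending=True, dim=-1)
    probs = sorted_scores.softmax(dim=-1)
    cumulative_probs = probs.cumsum(dim=-1)
    # Drop a token once the probability mass before it already exceeds the row's
    # threshold; the top-1 token is always kept.
    sorted_indices_to_remove = (cumulative_probs - probs) > p
    sorted_indices_to_remove[..., 0] = False
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    return scores.masked_fill(indices_to_remove, filter_value)


logits = torch.randn(2, 10)
filtered = heterogeneous_top_p(logits, top_p=[0.9, 1.0])  # the second row keeps every token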
@@ -105,10 +122,20 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
            Minimum number of tokens that cannot be filtered.
     """

-    def __init__(self, top_k: List[int], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
+    def __init__(
+        self,
+        top_k: List[int],
+        device: torch.device,
+        filter_value: float = -math.inf,
+        min_tokens_to_keep: int = 1,
+    ):
         self.max_top_k = max(top_k)
-        self.top_k = torch.tensor([max(x - 1, min_tokens_to_keep-1) for x in top_k], dtype=torch.int64,device=device).unsqueeze(1)
-        zeros=[x == 0 for x in top_k]
+        self.top_k = torch.tensor(
+            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
+            dtype=torch.int64,
+            device=device,
+        ).unsqueeze(1)
+        zeros = [x == 0 for x in top_k]
         if any(zeros):
             self.top_k_mask = torch.tensor(zeros, dtype=torch.bool, device=device)
         else:
@@ -116,13 +143,13 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         self.filter_value = filter_value

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        if scores.size(-1)>self.max_top_k: # Safety check
-            max_top_k=scores.size(-1)
-            top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed.
+        if scores.size(-1) > self.max_top_k:  # Safety check
+            max_top_k = scores.size(-1)
+            top_k = torch.clamp_max(self.top_k, max_top_k)  # Run only if needed.
         else:
-            max_top_k=self.max_top_k
-            top_k=self.top_k
-        kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+            max_top_k = self.max_top_k
+            top_k = self.top_k
+        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
         if self.top_k_mask is not None:
             kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
         # Remove all tokens with a probability less than the last token of the top-k
@@ -147,7 +174,13 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
            Minimum number of tokens that cannot be filtered.
     """

-    def __init__(self, mass: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
+    def __init__(
+        self,
+        mass: List[float],
+        device: torch.device,
+        filter_value: float = -math.inf,
+        min_tokens_to_keep: int = 1,
+    ):
         self.filter_value = filter_value
         self.mass = torch.tensor(mass, dtype=torch.float32, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
@@ -167,11 +200,15 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         # Remove tokens with cumulative mass above the threshold
         last_ind = (cumulative_probs < self.mass).sum(dim=1)
         last_ind[last_ind < 0] = 0
-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+            1, last_ind.view(-1, 1)
+        )
         if self.min_tokens_to_keep > 1:
             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
             sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )

         scores = scores.masked_fill_(indices_to_remove, self.filter_value)
         return scores
@@ -181,103 +218,113 @@ class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
     """
-    def __init__(self, do_sample:List[bool], seeds: List[int], device: torch.device):
+
+    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
         self.seeds = seeds
-        self.greedy=Greedy()
+        self.greedy = Greedy()
         # TODO: Most seeds are ignored
-        self.sampling=Sampling(seeds[0], device)
-        self.do_sample=torch.tensor(do_sample, dtype=torch.bool, device=device)
+        self.sampling = Sampling(seeds[0], device)
+        self.do_sample = torch.tensor(do_sample, dtype=torch.bool, device=device)

     def __call__(self, logits):
         return torch.where(self.do_sample, self.sampling(logits), self.greedy(logits))

+
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
         *,
-        batch_size:int,
-        device:torch.device,
-        watermark:Optional[Union[bool,List[Optional[bool]]]]=None,
-        temperature:Optional[Union[float,List[Optional[float]]]]=None,
-        repetition_penalty:Optional[Union[float,List[Optional[float]]]]=None,
-        top_k:Optional[Union[int,List[Optional[int]]]]=None,
-        top_p:Optional[Union[float,List[Optional[float]]]]=None,
-        typical_p:Optional[Union[float,List[Optional[float]]]]=None,
-        do_sample:Optional[Union[bool,List[Optional[bool]]]]=None,
-        seeds:Optional[Union[int,List[Optional[int]]]]=None,
+        batch_size: int,
+        device: torch.device,
+        watermark: Optional[Union[bool, List[Optional[bool]]]] = None,
+        temperature: Optional[Union[float, List[Optional[float]]]] = None,
+        repetition_penalty: Optional[Union[float, List[Optional[float]]]] = None,
+        top_k: Optional[Union[int, List[Optional[int]]]] = None,
+        top_p: Optional[Union[float, List[Optional[float]]]] = None,
+        typical_p: Optional[Union[float, List[Optional[float]]]] = None,
+        do_sample: Optional[Union[bool, List[Optional[bool]]]] = None,
+        seeds: Optional[Union[int, List[Optional[int]]]] = None,
     ):
         # TODO: Most seeds are ignored
-        seeds=self._standardize(seeds, batch_size, 0)
-        do_sample=self._standardize(do_sample, batch_size, False)
+        seeds = self._standardize(seeds, batch_size, 0)
+        do_sample = self._standardize(do_sample, batch_size, False)

         warpers = LogitsProcessorList()

-        watermark=self._standardize(watermark, batch_size, False)
+        watermark = self._standardize(watermark, batch_size, False)
         if any(watermark):
             raise NotImplementedError("Watermarking not implemented")

-        repetition_penalty=self._standardize(repetition_penalty, batch_size, 1.0)
-        if any([x!=1.0 for x in repetition_penalty]):
-            warpers.append(HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, device))
+        repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0)
+        if any([x != 1.0 for x in repetition_penalty]):
+            warpers.append(
+                HeterogeneousRepetitionPenaltyLogitsProcessor(
+                    repetition_penalty, device
+                )
+            )

-        temperature=self._standardize(temperature, batch_size, 1.0)
-        if any([x!=1.0 for x in temperature]):
-            do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
+        temperature = self._standardize(temperature, batch_size, 1.0)
+        if any([x != 1.0 for x in temperature]):
+            do_sample = [
+                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+            ]
             warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, device))

-        top_k=self._standardize(top_k, batch_size, 0)
-        n_top_k=sum([x!=0 for x in top_k])
-        if n_top_k>0:
-            do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
+        top_k = self._standardize(top_k, batch_size, 0)
+        n_top_k = sum([x != 0 for x in top_k])
+        if n_top_k > 0:
+            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

-        top_p=self._standardize(top_p, batch_size, 1.0)
-        if any([x<1.0 for x in top_p]):
-            do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)]
+        top_p = self._standardize(top_p, batch_size, 1.0)
+        if any([x < 1.0 for x in top_p]):
+            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
             warpers.append(HeterogeneousTopPLogitsWarper(top_p, device))

-        typical_p=self._standardize(typical_p, batch_size, 1.0)
-        if any([x<1.0 for x in typical_p]):
-            do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
+        typical_p = self._standardize(typical_p, batch_size, 1.0)
+        if any([x < 1.0 for x in typical_p]):
+            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
             warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, device))

-        self.warpers=warpers
+        self.warpers = warpers

-        num_do_sample=sum(do_sample)
-        if num_do_sample==0:
-            self.choice=Greedy()
-        elif num_do_sample<batch_size:
-            self.choice=HeterogeneousSampling(do_sample, seeds, device)
+        num_do_sample = sum(do_sample)
+        if num_do_sample == 0:
+            self.choice = Greedy()
+        elif num_do_sample < batch_size:
+            self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             # TODO: Most seeds are ignored
-            self.choice=Sampling(seeds[0], device)
+            self.choice = Sampling(seeds[0], device)

     @staticmethod
     def _standardize(values, batch_size, default):
         if isinstance(values, list):
-            values=values.copy()
+            values = values.copy()
         else:
-            values=[values]*batch_size
-        assert len(values)==batch_size
+            values = [values] * batch_size
+        assert len(values) == batch_size
         for i, v in enumerate(values):
             if v is None:
-                values[i]=default
+                values[i] = default
         return values

-    def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
-        last_token_scores=self.warpers(input_ids, scores[:, -1, :])
-        next_token_ids=self.choice(last_token_scores)
+    def __call__(
+        self, input_ids: torch.Tensor, scores: torch.Tensor, return_logprobs: bool
+    ):
+        last_token_scores = self.warpers(input_ids, scores[:, -1, :])
+        next_token_ids = self.choice(last_token_scores)

         if return_logprobs:
             # Compute logprobs
-            if scores.size(1)==1:
-                scores=last_token_scores.unsqueeze(1)
+            if scores.size(1) == 1:
+                scores = last_token_scores.unsqueeze(1)
             else:
                 # TODO: Post-process all the tokens?
-                scores[:, -1, :]=last_token_scores
+                scores[:, -1, :] = last_token_scores
             logprobs = torch.log_softmax(scores, dim=-1)
         else:
-            logprobs=None
+            logprobs = None

         return next_token_ids, logprobs

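For context, and not part of the commit shown above: a usage sketch of the HeterogeneousNextTokenChooser touched by this commit, assuming the import path shown in the diff; all parameter values are invented for illustration.

import torch
from text_generation_server.utils.tokens_heterogeneous import (
    HeterogeneousNextTokenChooser,
)

# Two requests with different settings; unset options (top_k, top_p, typical_p, ...)
# fall back to their defaults and can likewise be passed as per-request lists.
chooser = HeterogeneousNextTokenChooser(
    batch_size=2,
    device=torch.device("cpu"),
    temperature=[0.7, 1.0],
    do_sample=[True, False],
    seeds=[0, 0],
)

input_ids = torch.zeros(2, 5, dtype=torch.int64)  # [batch, query_length]
logits = torch.randn(2, 5, 32000)                 # [batch, query_length, vocab]
next_token_ids, logprobs = chooser(input_ids, logits, return_logprobs=True)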