# text-generation-inference/server/text_generation_server/models/flash_causal_lm.py

import torch
import torch.distributed
import numpy as np
from torch.nn import functional as F
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
)
tracer = trace.get_tracer(__name__)
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
# cumulative sequence lengths
cu_seqlens: torch.Tensor
# cumulative query sequence lengths, only used in decode
cu_seqlens_q: Optional[torch.Tensor]
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
max_seqlen: int
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: List[torch.Tensor]
# Current length of each sequence in the batch
input_lengths: List[int]
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to
max_tokens: int
def to_pb(self) -> generate_pb2.Batch:
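"""Convert this batch to its protobuf representation."""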
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "FlashCausalLMBatch":
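"""Tokenize every request of the protobuf batch and flatten the sequences into
single input_ids / position_ids tensors, with cu_seqlens marking the sequence
boundaries expected by the flash-attention kernels."""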
position_ids = []
cu_seqlens = [0]
max_seqlen = 0
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
requests_idx_mapping = {}
next_token_choosers = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
max_tokens = 0
# Parse batch
for i, r in enumerate(pb.requests):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenizer(
r.inputs, truncation=True, max_length=r.truncate
)["input_ids"]
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
prefix_offsets.append(0)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
# Position ids
position_ids.append(np.arange(0, input_length))
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=None,
max_seqlen=max_seqlen,
past_key_values=None,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=[],
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
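"""Keep only the given requests: per-request lists are re-indexed and the
flattened past_key_values is sliced (or re-padded for a single request) to
match the filtered batch."""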
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(requests) == len(self):
return self
single_request = len(requests) == 1
# Cumulative length
cumulative_length = 0
# New values after filtering
requests_idx_mapping = {}
input_ids = self.input_ids.new_empty(len(requests))
position_ids = self.position_ids.new_empty(len(requests))
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
cu_seqlens_q = torch.arange(
0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
)
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
# Get length
request_input_length = self.input_lengths[idx]
# Copy tensors (GPU)
input_ids[i] = self.input_ids[idx]
position_ids[i] = self.position_ids[idx]
# Copy to tensor (CPU)
cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
# Slice from past
past_key_values.append(
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
)
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
cumulative_length += request_input_length
max_tokens += request_input_length + (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
if single_request:
# Preallocate tensor for bs = 1 case
past_key_values = F.pad(
past_key_values[0],
(
0,
0,
0,
0,
0,
0,
0,
stopping_criterias[0].max_new_tokens
- stopping_criterias[0].current_tokens,
),
)
else:
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
# Move to GPU now that we have the whole tensor
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
return FlashCausalLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
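"""Merge several batches into a single one: tensors are copied into freshly
allocated buffers, cu_seqlens entries are offset by the cumulative length of
the preceding batches, and per-request lists are concatenated."""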
# Batch attributes
requests = []
requests_idx_mapping = {}
total_batch_size = sum([len(b) for b in batches])
device = batches[0].input_ids.device
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
cu_seqlens = [0]
cu_seqlens_q = torch.arange(
0, total_batch_size + 1, device=device, dtype=torch.int32
)
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
if len(batch) != 1:
past_key_values.append(batch.past_key_values)
else:
# past was pre-allocated for this batch
# We need to slice to remove the padding
past_key_values.append(
batch.past_key_values[:, : batch.input_lengths[0]]
)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
# Cat past
past_key_values = torch.cat(past_key_values, dim=1)
# Create final tensor on GPU
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
def __len__(self):
return len(self.requests)
class FlashCausalLM(Model):
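"""Causal language model served with custom flash-attention kernels.
Batches are flattened (no padding) and the key/value cache is kept as a single
contiguous tensor indexed through cu_seqlens. GPU only, loaded in float16."""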
def __init__(
self,
model_cls: Type[PreTrainedModel],
model_id: str,
revision: Optional[str] = None,
feat(server): GPTQ quantization (step1) (#277) Changes only the type from `bool` to `Option<Enum>` pretty much everywhere. - Use `Optional[str]` in Python (easier to manage than importing type everywhere). Except for the cli to get proper validation - Updated all models to handle gracefully new values. (Error out if unknown value, or gptq since not implemented). <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
quantize: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
).to(device)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
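"""Decode a sequence of generated token ids into text, skipping special tokens."""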
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
max_s: int,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
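"""Single model forward pass. Returns the output logits and the present
key/value cache."""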
# Model Forward
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_s=max_s,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
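"""Run one prefill or decode step for the whole batch: compute the logits,
choose the next token of every request, and return the generations together
with the updated batch (None if every request has stopped)."""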
prefill = batch.past_key_values is None
if prefill and len(batch) == 1:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
pre_allocate_past_size = (
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
)
else:
pre_allocate_past_size = None
out, present = self.forward(
batch.input_ids,
batch.position_ids,
batch.cu_seqlens,
batch.cu_seqlens_q,
batch.max_seqlen,
batch.past_key_values,
pre_allocate_past_size,
)
if prefill:
if len(batch) > 1:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
# Create batch.cu_seqlens_q for decode
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
)
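# In decode, every request contributes exactly one query token, hence the arange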
next_input_ids = batch.input_ids.new_empty(len(batch))
next_position_ids = batch.position_ids.new_empty(len(batch))
else:
prefill_logprobs = None
next_input_ids = batch.input_ids
next_position_ids = batch.position_ids
next_token_logprobs = out.new_empty(len(batch))
# Prepare past for next decode
if len(batch) > 1:
# Used to slice next batch past
past_indices = torch.empty(
present.shape[1], dtype=torch.int64, device=self.device
)
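# Allocate one extra slot per request so that the key/value of the next
# generated token can be written into the cache without re-allocating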
batch.past_key_values = present.new_empty(
(
present.shape[0],
present.shape[1] + len(batch.requests),
*present.shape[2:],
)
)
# It is actually faster to do this copy in a separate for loop: the copy from present to past is fairly slow
# and runs asynchronously while we execute the next for loop
cumulative_length = 0
for i, input_length in enumerate(batch.input_lengths):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
# Indices to copy present at the correct place in past_key_values
torch.arange(
start_index + i,
end_index + i,
dtype=torch.int64,
device=self.device,
out=past_indices[start_index:end_index],
)
cumulative_length += input_length
# Copy from present to past_key_values
batch.past_key_values[:, past_indices] = present
# Initialize past_key_values in prefill for len(batch) == 1
elif prefill:
# present is already pre-padded
batch.past_key_values = present
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# We do two for loops: the first one can run completely asynchronously with respect to the GPU,
# while the second one needs a GPU <-> CPU sync first
# It is faster if we delay this sync for as long as possible
# For each member of the batch
for i, (
input_length,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
# only take last token logit
logits = out[end_index - 1 : end_index]
# Create all_input_ids_tensor that will be used by the logits processors and warpers (for example, RepetitionPenalty)
all_input_ids_tensor = batch.input_ids.new_empty(
input_length + stopping_criteria.max_new_tokens
)
# Copy from batch.input_ids to all_input_ids_tensor
all_input_ids_tensor[:input_length] = batch.input_ids[
start_index:end_index
]
batch.all_input_ids_tensor.append(all_input_ids_tensor)
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if len(batch) > 1:
prefill_tokens_indices[
start_index : end_index - 1
] = batch.input_ids[start_index + 1 : end_index]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
logits = out[i].view(1, -1)
all_input_ids_tensor = batch.all_input_ids_tensor[i]
# Select next token
next_token_id, logprob = next_token_chooser(
all_input_ids_tensor[None, :input_length], logits
)
# Add to all_input_ids_tensor
next_token_id_squeezed = next_token_id.view(1)
all_input_ids_tensor[input_length] = next_token_id_squeezed
# Set values
next_input_ids[i] = next_token_id_squeezed
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
cumulative_length += input_length
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
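# Every request grew by exactly one token, so shifting cu_seqlens by
# cu_seqlens_q (an arange over requests) gives the new sequence boundaries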
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
if prefill:
# Get prefill logprobs
prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
)
# GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.view(-1).tolist()
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist()
cumulative_length = 0
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_input_ids_tensor,
next_token_ids,
next_token_logprobs,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
all_input_ids_tensor,
next_token_id,
next_token_logprob,
) in enumerate(iterator):
start_index = cumulative_length
end_index = cumulative_length + input_length
# Append next token to all tokens
all_input_ids.append(next_token_id)
# Generated token
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if prefill:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
)
generations.append(generation)
new_input_length = input_length + 1
# Update values
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None