add warmup

OlivierDehaene 2023-06-29 15:50:44 +02:00
parent d649cd8e02
commit ddfc02f2a4
14 changed files with 272 additions and 85 deletions

View File

@ -128,7 +128,10 @@ struct Args {
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
/// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute bound, it is useful
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage
@ -143,13 +146,6 @@ struct Args {
/// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
/// or a single query of `1000` tokens.
///
/// So you don't have to control `max_batch_size` or `max_total_tokens` that finely.
/// In fact you could mostly relax them if you want maximum flexibility.
/// However, if your users ask for the full amount of total tokens, they are likely
/// to wait a very long time to get a spot in the batch (since they will be alone in it),
/// so setting `max_batch_size` and `max_total_tokens` can still be useful to prevent
/// those long waiting times.
///
/// Overall this number should be the largest possible amount that fits the
/// remaining memory (after the model is loaded). Since the actual memory overhead
/// depends on other parameters like if you're using quantization, flash attention
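
A quick numeric sketch of the budget described in the example above (`max_batch_total_tokens=1000`); the helper and numbers below are illustrative, not launcher defaults.

# Illustrative sketch of the `max_batch_total_tokens` budget described above.
max_batch_total_tokens = 1000

def fits(total_tokens_per_request: int, batch_size: int) -> bool:
    # A batch fits as long as the sum of every request's total tokens
    # (prompt + generated) stays within the overall budget.
    return total_tokens_per_request * batch_size <= max_batch_total_tokens

assert fits(total_tokens_per_request=100, batch_size=10)      # 10 queries of 100 tokens
assert fits(total_tokens_per_request=1000, batch_size=1)      # or one query of 1000 tokens
assert not fits(total_tokens_per_request=100, batch_size=11)  # over budget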
@ -448,7 +444,7 @@ fn shard_manager(
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
return;

View File

@ -11,6 +11,8 @@ service TextGenerationService {
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
@ -192,3 +194,13 @@ message DecodeResponse {
/// Next batch (cached)
optional CachedBatch batch = 2;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}
/// Empty response
message WarmupResponse {}
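
For reference, the same payload can be assembled from the generated Python protobuf classes; the import path and the tiny single-request batch below are illustrative assumptions (the Rust client in the next file builds the real warmup batch).

# Sketch of building a WarmupRequest with the generated Python protobuf classes.
# The module path and the one-request batch are assumptions for illustration.
from text_generation_server.pb import generate_pb2

request = generate_pb2.Request(
    id=0,
    inputs="test" * 1024,  # truncated server side anyway
    truncate=1024,
    parameters=generate_pb2.NextTokenChooserParameters(
        temperature=0.9, top_k=10, top_p=0.9, typical_p=0.9,
        do_sample=False, seed=0, repetition_penalty=1.2, watermark=True,
    ),
    stopping_parameters=generate_pb2.StoppingCriteriaParameters(
        max_new_tokens=2, stop_sequences=[], ignore_eos_token=False,
    ),
    prefill_logprobs=True,
)
warmup = generate_pb2.WarmupRequest(
    batch=generate_pb2.Batch(id=0, requests=[request], size=1, max_tokens=0),
    max_total_tokens=16000,
)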

View File

@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -94,6 +95,63 @@ impl Client {
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Runs a prefill on a maximum size batch to make sure the configuration fits in memory
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "test".to_string().repeat(max_input_length as usize),
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
stop_sequences: vec![],
ignore_eos_token: false,
}),
prefill_logprobs: true,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
}
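
The sizing logic of the loop above can be checked with a small Python sketch (illustrative values; the dummy text and sampling parameters are omitted): the client keeps appending `max_input_length`-token dummy requests until `max_prefill_tokens` is covered, truncating the last one to the remaining budget.

# Sketch of the warmup batch sizing above with illustrative values.
max_input_length = 1024
max_prefill_tokens = 4196  # deliberately not a multiple of max_input_length

truncates = []
n_tokens = 0
while n_tokens < max_prefill_tokens:
    truncates.append(min(max_input_length, max_prefill_tokens - n_tokens))
    n_tokens += max_input_length

assert truncates == [1024, 1024, 1024, 1024, 100]  # 5 requests
assert sum(truncates) == max_prefill_tokens         # 4196 prefill tokens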
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -87,6 +87,27 @@ impl ShardedClient {
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Runs a prefill on a maximum size batch on every shard to make sure the configuration fits in memory
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -242,6 +242,7 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
@ -288,7 +289,8 @@ async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens - batch_max_tokens;
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
tracing::info!("{token_budget} {batch_max_tokens}");
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
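
The switch to `saturating_sub` above guards against the running batch already using more tokens than `max_batch_total_tokens`; a rough Python equivalent of the new budget computation:

# Rough Python equivalent of max_batch_total_tokens.saturating_sub(batch_max_tokens):
# clamp the remaining token budget at zero instead of underflowing.
def token_budget(max_batch_total_tokens: int, batch_max_tokens: int) -> int:
    return max(0, max_batch_total_tokens - batch_max_tokens)

assert token_budget(32000, 12000) == 20000
assert token_budget(32000, 40000) == 0  # a plain `-` would underflow on u32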

View File

@ -34,7 +34,7 @@ struct Args {
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32,
@ -180,16 +180,23 @@ fn main() -> Result<(), std::io::Error> {
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
// Get info from the shard
let shard_info = sharded_client
.info()
.await
.expect("Unable to get shard info");
// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
tracing::info!("Connected");
// Binds on localhost

View File

@ -19,11 +19,12 @@ class Cache:
def delete(self, batch_id: int):
batch = self.pop(batch_id)
if batch is not None:
batch.cleanup()
batch.free()
del batch
def clear(self):
for k in self.cache.keys():
keys = list(self.cache.keys())
for k in keys:
self.delete(k)
def __len__(self):
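
The `clear` fix above snapshots the keys first because deleting from a dict while iterating over its live key view raises a RuntimeError; a minimal reproduction of the problem and the fix:

# Minimal reproduction of why clear() now copies the keys before deleting.
cache = {1: "batch-1", 2: "batch-2"}
try:
    for k in cache.keys():
        del cache[k]
except RuntimeError:
    pass  # "dictionary changed size during iteration"

cache = {1: "batch-1", 2: "batch-2"}
for k in list(cache.keys()):  # snapshot of the keys, safe to delete while iterating
    del cache[k]
assert len(cache) == 0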

View File

@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,

View File

@ -1,3 +1,4 @@
import math
import itertools
import torch
import torch.distributed
@ -5,6 +6,7 @@ import torch.distributed
import numpy as np
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
@ -21,6 +23,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
@ -35,7 +38,7 @@ class CacheManager:
dtype: torch.dtype,
device: torch.device,
):
self.block_size = 16
self.block_size = BLOCK_SIZE
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
@ -60,26 +63,30 @@ class CacheManager:
0, num_blocks * self.block_size, dtype=torch.int32
).view(num_blocks, self.block_size)
def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]:
# Number of blocks needed to cover all tokens
needed_blocks = (n_tokens // self.block_size) + 1
def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
# Get free block indices by finding values in the mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero()
assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
logger.info(f"Free blocks: {len(free_block_indices)}")
assert (
len(free_block_indices) >= num_blocks
), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
# Allocate the required number of blocks by setting the mask to 0
block_indices = free_block_indices[:needed_blocks]
block_indices = free_block_indices[:num_blocks]
self.free_block_mask[block_indices] = 0
# Get slots for the allocated blocks
slots = self.slots[block_indices].flatten()[:n_tokens]
slots = self.slots[block_indices].flatten()
return block_indices.flatten().tolist(), slots
logger.info(f"allocate {num_blocks} blocks")
def free(self, block_indices: List[int]):
# Reset mask
self.free_block_mask[block_indices] = 1
return block_indices.flatten(), slots
def free(self, block_indices: Optional[List[int]]):
if block_indices is not None:
# Reset mask
logger.info(f"free {len(block_indices)} blocks")
self.free_block_mask[block_indices] = 1
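
A self-contained sketch of the mask-based block/slot bookkeeping in `CacheManager.allocate`/`free` above, using plain Python lists instead of torch tensors; the class name and sizes are illustrative.

# Self-contained sketch of the CacheManager block/slot bookkeeping.
BLOCK_SIZE = 16

class ToyCacheManager:
    def __init__(self, num_blocks: int):
        # 1 = free, 0 = allocated, one entry per block of BLOCK_SIZE KV slots
        self.free_block_mask = [1] * num_blocks
        # Slot i of block b has the global index b * BLOCK_SIZE + i
        self.slots = [list(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) for b in range(num_blocks)]

    def allocate(self, num_blocks: int):
        free = [b for b, is_free in enumerate(self.free_block_mask) if is_free]
        assert len(free) >= num_blocks, (
            f"Out of available cache blocks: asked {num_blocks}, only {len(free)} free blocks"
        )
        blocks = free[:num_blocks]
        for b in blocks:
            self.free_block_mask[b] = 0
        slots = [s for b in blocks for s in self.slots[b]]
        return blocks, slots

    def free(self, block_indices):
        if block_indices is not None:
            for b in block_indices:
                self.free_block_mask[b] = 1

manager = ToyCacheManager(num_blocks=8)
blocks, slots = manager.allocate(2)   # 2 blocks -> 32 KV slots
assert len(slots) == 2 * BLOCK_SIZE
manager.free(blocks)                  # blocks become reusable for the next batch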
@dataclass
@ -97,16 +104,25 @@ class FlashCausalLMBatch(Batch):
start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# Paged Attention values
# Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
# List of tuple of ints representing the number of blocks and slots needed by each sequence
needed_blocks_slots: Optional[List[Tuple[int, int]]]
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables: Optional[List[List[int]]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: Optional[torch.Tensor]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: Optional[torch.Tensor]
max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs
@ -128,16 +144,17 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Number of blocks in this batch
blocks: int
# Maximum number of blocks
max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch:
global CACHE_MANAGER
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=len(self.slots),
max_tokens=self.blocks * BLOCK_SIZE,
)
@classmethod
@ -148,8 +165,6 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
global CACHE_MANAGER
batch_inputs = []
max_truncation = 0
for r in pb.requests:
@ -163,9 +178,8 @@ class FlashCausalLMBatch(Batch):
position_ids = []
start_seq_prefill = []
end_seq_prefill = []
block_tables = []
needed_blocks_slots = []
start_slots = []
slots = []
slot_indices = []
input_lengths = []
@ -188,6 +202,7 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0
prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
@ -228,9 +243,9 @@ class FlashCausalLMBatch(Batch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens)
block_tables.append(request_blocks)
slots.extend(request_slots)
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
@ -264,7 +279,7 @@ class FlashCausalLMBatch(Batch):
cumulative_length += input_length
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks))
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@ -272,15 +287,6 @@ class FlashCausalLMBatch(Batch):
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded block tables
block_tables_tensor = torch.zeros(
(len(pb.requests), max_blocks), dtype=torch.int32
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
@ -312,7 +318,6 @@ class FlashCausalLMBatch(Batch):
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
slots = torch.tensor(slots, dtype=torch.int32, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
@ -339,11 +344,12 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
@ -356,12 +362,12 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
blocks=blocks,
max_blocks=max_blocks,
)
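
Worked numbers for the per-request block accounting above (`needed_blocks`, `blocks`, `max_blocks`, and the `blocks * BLOCK_SIZE` figure reported by `to_pb`); the two requests are illustrative:

import math

BLOCK_SIZE = 16

# Two illustrative requests: (input_length, max_new_tokens)
requests = [(20, 10), (5, 100)]

blocks = 0
max_blocks = 0
needed_blocks_slots = []
for input_length, max_new_tokens in requests:
    # Remove one as the first generated token does not have a past
    total_tokens = input_length + max_new_tokens - 1
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    needed_blocks_slots.append((needed_blocks, total_tokens))
    blocks += needed_blocks
    max_blocks = max(max_blocks, needed_blocks)

# Request 1: 29 tokens -> 2 blocks; request 2: 104 tokens -> 7 blocks
assert needed_blocks_slots == [(2, 29), (7, 104)]
assert (blocks, max_blocks) == (9, 7)
assert blocks * BLOCK_SIZE == 144  # reserved slots reported as max_tokens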
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
@ -396,6 +402,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
blocks = 0
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
@ -425,6 +432,7 @@ class FlashCausalLMBatch(Batch):
)
request_block_table = self.block_tables[idx]
blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
@ -443,6 +451,7 @@ class FlashCausalLMBatch(Batch):
max_blocks = max(max_blocks, len(request_block_table))
global CACHE_MANAGER
# Iterate on all requests
for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch
@ -472,11 +481,12 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
start_seq_prefill=None,
end_seq_prefill=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
@ -489,17 +499,18 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
blocks=blocks,
max_blocks=max_blocks,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
# Batch attributes
requests = []
requests_idx_mapping = {}
blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
@ -508,6 +519,7 @@ class FlashCausalLMBatch(Batch):
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
@ -613,11 +625,12 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
start_seq_prefill=None,
end_seq_prefill=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
@ -630,13 +643,15 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
blocks=blocks,
max_blocks=max_blocks,
)
def cleanup(self):
global CACHE_MANAGER
# Free blocks
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
def free(self):
if self.block_tables is not None:
global CACHE_MANAGER
# Free blocks
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
def __len__(self):
return len(self.requests)
@ -648,22 +663,17 @@ class FlashCausalLM(Model):
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
):
self.num_heads = num_heads
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_size = head_size
global CACHE_MANAGER
torch.cuda.set_per_process_memory_fraction(1.0)
CACHE_MANAGER = CacheManager(
1000, num_layers, num_heads, head_size, dtype, device
)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
@ -678,6 +688,30 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
global CACHE_MANAGER
torch.cuda.empty_cache()
try:
CACHE_MANAGER = CacheManager(
# Adds some wiggle room
math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
_, batch = self.generate_token(batch)
except Exception as e:
logger.error(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
f"prefill tokens. "
f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
)
raise e
batch.free()
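
The `warmup` above sizes the global `CACHE_MANAGER` from `max_total_tokens`; a rough sketch of the sizing math and the resulting KV-cache footprint, assuming one key and one value tensor per layer of shape [num_blocks, block_size, num_kv_heads, head_size] (the exact layout lives in `CacheManager.__init__`, which this diff does not show), with illustrative model numbers:

import math

BLOCK_SIZE = 16

def warmup_num_blocks(max_total_tokens: int) -> int:
    # Mirrors the sizing above: enough blocks for max_total_tokens,
    # plus 10 extra blocks of wiggle room.
    return math.ceil(max_total_tokens / BLOCK_SIZE) + 10

def kv_cache_bytes(num_blocks: int, num_layers: int, num_kv_heads: int,
                   head_size: int, element_size: int = 2) -> int:
    # Rough estimate: key + value per layer, per slot, in fp16 (2 bytes).
    per_slot = 2 * num_layers * num_kv_heads * head_size * element_size
    return num_blocks * BLOCK_SIZE * per_slot

# Roughly Llama-7B-like numbers, ~16k total tokens
blocks = warmup_num_blocks(max_total_tokens=16000)
print(blocks)  # 1010 blocks
print(kv_cache_bytes(blocks, num_layers=32, num_kv_heads=32, head_size=128) / 1e9)  # ~8.5 GB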
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@ -718,6 +752,35 @@ class FlashCausalLM(Model):
prefill = batch.start_seq_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
if batch.needed_blocks_slots:
# Padded block tables
block_tables_tensor = torch.zeros(
(len(batch), batch.max_blocks), dtype=torch.int32
)
# Allocate paged attention blocks
slots = []
block_tables = []
try:
for i, (needed_blocks, needed_slots) in enumerate(
batch.needed_blocks_slots
):
allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
needed_blocks
)
slots.append(allocated_slots[:needed_slots])
block_tables.append(allocated_blocks.tolist())
block_tables_tensor[i, :needed_blocks] = allocated_blocks
except Exception as e:
for blocks in block_tables:
CACHE_MANAGER.free(blocks)
raise e
batch.needed_blocks_slots = None
batch.block_tables = block_tables
batch.block_tables_tensor = block_tables_tensor.to(self.device)
batch.slots = torch.concat(slots).to(self.device)
out = self.forward(
batch.input_ids,
batch.position_ids,
@ -931,7 +994,7 @@ class FlashCausalLM(Model):
batch.all_input_ids[i] = all_input_ids
if stopped:
batch.cleanup()
batch.free()
# No need to return a batch if we know that all requests stopped
return generations, None
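
The block allocation earlier in `generate_token` (the `if batch.needed_blocks_slots:` branch) is deferred to the first forward pass and rolled back if any request cannot get its blocks; a compact, self-contained sketch of that rollback pattern with an illustrative toy allocator:

# Sketch of the deferred-allocation rollback: free everything already claimed
# if any request in the batch fails to get its blocks, so nothing leaks.
free_blocks = set(range(4))  # illustrative: only 4 blocks available

def allocate(num_blocks):
    if len(free_blocks) < num_blocks:
        raise RuntimeError("Out of available cache blocks")
    return [free_blocks.pop() for _ in range(num_blocks)]

def free(blocks):
    free_blocks.update(blocks)

needed_blocks_slots = [(2, 29), (3, 40)]  # the second request cannot fit
block_tables = []
try:
    for needed_blocks, _needed_slots in needed_blocks_slots:
        block_tables.append(allocate(needed_blocks))
except RuntimeError:
    for blocks in block_tables:  # roll back before surfacing the error
        free(blocks)
    block_tables = []

assert len(free_blocks) == 4  # nothing leaked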

View File

@ -68,7 +68,7 @@ class FlashLlama(FlashCausalLM):
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_heads=model.model.num_heads,
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,

View File

@ -22,6 +22,9 @@ class Model(ABC):
rank: int = 0,
world_size: int = 1,
):
if torch.cuda.is_available():
torch.cuda.set_per_process_memory_fraction(1.0)
self.model = model.eval()
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
@ -55,6 +58,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):
self.generate_token(batch)
def decode_token(
self,
all_input_ids: List[int],

View File

@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,

View File

@ -35,7 +35,7 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch":
raise NotImplementedError
def cleanup(self):
def free(self):
pass
@abstractmethod

View File

@ -53,12 +53,24 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
self.model.warmup(batch, request.max_total_tokens)
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token(batch)
try:
generations, next_batch = self.model.generate_token(batch)
except Exception as e:
batch.free()
raise e
self.cache.set(next_batch)
return generate_pb2.PrefillResponse(
@ -81,11 +93,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
raise ValueError("All batches are empty")
if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches)
try:
batch = self.model.batch_type.concatenate(batches)
except Exception as e:
[batch.free() for batch in batches]
raise e
else:
batch = batches[0]
generations, next_batch = self.model.generate_token(batch)
try:
generations, next_batch = self.model.generate_token(batch)
except Exception as e:
batch.free()
raise e
self.cache.set(next_batch)
return generate_pb2.DecodeResponse(