Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

Commit ddfc02f2a4: add warmup
Parent: d649cd8e02
@@ -128,7 +128,10 @@ struct Args {
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
 
-    #[clap(default_value = "32000", long, env)]
+    /// Limits the number of tokens for the prefill operation.
+    /// Since this operation take the most memory and is compute bound, it is interesting
+    /// to limit the number of requests that can be sent.
+    #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
 
     /// **IMPORTANT** This is one critical control to allow maximum usage
@@ -143,13 +146,6 @@ struct Args {
     /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
     /// or a single query of `1000` tokens.
     ///
-    /// So you don't have to control that finely
-    /// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
-    /// want maximum flexibility. However, for your users if they are asking for the full amount of
-    /// total tokens, they are likely to wait for a very long time to get a spot
-    /// in the batch (since they are going to be alone) so setting `max_batch_size`
-    /// and `max_total_tokens` can still be useful to prevent those long waiting times.
-    ///
     /// Overall this number should be the largest possible amount that fits the
     /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like if you're using quantization, flash attention
@@ -448,7 +444,7 @@ fn shard_manager(
 
         // We received a shutdown signal
         if *shutdown.lock().unwrap() {
-            p.terminate().unwrap();
+            p.kill().unwrap();
            let _ = p.wait_timeout(Duration::from_secs(90));
            tracing::info!("Shard {rank} terminated");
            return;
@@ -11,6 +11,8 @@ service TextGenerationService {
    /// Remove requests from a cached batch
    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+    /// Warmup the model and compute max cache size
+    rpc Warmup (WarmupRequest) returns (WarmupResponse);
    /// Prefill batch and decode first token
    rpc Prefill (PrefillRequest) returns (PrefillResponse);
    /// Decode token for a list of prefilled batches
@@ -192,3 +194,13 @@ message DecodeResponse {
    /// Next batch (cached)
    optional CachedBatch batch = 2;
 }
+
+message WarmupRequest {
+    /// Batch to warmup on
+    Batch batch = 1;
+    /// Maximum number of tokens that the client will send
+    uint32 max_total_tokens = 2;
+}
+
+/// Empty response
+message WarmupResponse {}
@@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
 use crate::pb::generate::v1::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
+use std::cmp::min;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
@@ -94,6 +95,63 @@ impl Client {
        Ok(filtered_batch.batch)
    }
 
+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum amount of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+    ) -> Result<()> {
+        let mut n_tokens = 0;
+        let mut requests = Vec::new();
+
+        // Create requests
+        while n_tokens < max_prefill_tokens {
+            requests.push(Request {
+                id: 0,
+                // We truncate the input on the server side to be sure that it has the correct size
+                inputs: "test".to_string().repeat(max_input_length as usize),
+                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                // Set sampling parameters to also take these ops into account in the max memory
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 0.9,
+                    top_k: 10,
+                    top_p: 0.9,
+                    typical_p: 0.9,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.2,
+                    watermark: true,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 2,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+                prefill_logprobs: true,
+            });
+            n_tokens += max_input_length;
+        }
+
+        let batch = Batch {
+            id: 0,
+            size: requests.len() as u32,
+            requests,
+            max_tokens: 0,
+        };
+
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_total_tokens,
+        })
+        .inject_context();
+        self.stub.warmup(request).await?.into_inner();
+        Ok(())
+    }
+
    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
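Note: the new `Client::warmup` above saturates the prefill budget with dummy requests. It keeps pushing `max_input_length`-sized inputs until `max_prefill_tokens` is covered, truncating the last one so the total never exceeds the budget. A minimal Python sketch of that sizing logic (the helper name is hypothetical, for illustration only):

def plan_warmup_requests(max_input_length: int, max_prefill_tokens: int):
    """Return the `truncate` value of each dummy request the warmup loop would build."""
    truncates = []
    n_tokens = 0
    while n_tokens < max_prefill_tokens:
        # Each request is truncated server side, so only `truncate` matters for sizing
        truncates.append(min(max_input_length, max_prefill_tokens - n_tokens))
        n_tokens += max_input_length
    return truncates

# Example: 4096 prefill tokens with 1024-token inputs -> 4 requests of 1024 tokens each
print(plan_warmup_requests(1024, 4096))  # [1024, 1024, 1024, 1024]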
@@ -87,6 +87,27 @@ impl ShardedClient {
        join_all(futures).await.pop().unwrap()
    }
 
+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum amount of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+    ) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
+            .collect();
+        // all shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
@@ -242,6 +242,7 @@ impl Infer {
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
+#[allow(clippy::too_many_arguments)]
 async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
@@ -288,7 +289,8 @@ async fn batching_task(
                Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
            };
 
-            let token_budget = max_batch_total_tokens - batch_max_tokens;
+            let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+            tracing::info!("{token_budget} {batch_max_tokens}");
 
            // Try to get a new batch
            if let Some((mut new_entries, new_batch, span)) = queue
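Note: plain `u32` subtraction panics on underflow in debug builds (and wraps around in release), so switching to `saturating_sub` clamps the budget at zero whenever the running batch already holds more tokens than `max_batch_total_tokens`. Equivalent logic, sketched in Python for illustration:

def token_budget(max_batch_total_tokens: int, batch_max_tokens: int) -> int:
    # Never goes negative, even if the running batch already exceeds the configured total
    return max(0, max_batch_total_tokens - batch_max_tokens)

print(token_budget(32000, 12000))  # 20000
print(token_budget(32000, 33000))  # 0 (the old unchecked u32 subtraction would underflow here)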
@@ -34,7 +34,7 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
-    #[clap(default_value = "32000", long, env)]
+    #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
     #[clap(default_value = "32000", long, env)]
     max_batch_total_tokens: u32,
@@ -180,16 +180,23 @@ fn main() -> Result<(), std::io::Error> {
    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
        .await
        .expect("Could not connect to server");
-    // Clear the cache; useful if the webserver rebooted
-    sharded_client
-        .clear_cache(None)
-        .await
-        .expect("Unable to clear cache");
    // Get info from the shard
    let shard_info = sharded_client
        .info()
        .await
        .expect("Unable to get shard info");
+
+    // Warmup model
+    tracing::info!("Warming up model");
+    sharded_client
+        .warmup(
+            max_input_length as u32,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+        )
+        .await
+        .expect("Unable to warmup model");
    tracing::info!("Connected");
 
    // Binds on localhost
@@ -19,11 +19,12 @@ class Cache:
     def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
-            batch.cleanup()
+            batch.free()
             del batch
 
     def clear(self):
-        for k in self.cache.keys():
+        keys = list(self.cache.keys())
+        for k in keys:
             self.delete(k)
 
     def __len__(self):
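Note: `delete` pops entries out of `self.cache`, so iterating directly over `self.cache.keys()` while deleting mutates the dict during iteration and raises a `RuntimeError` in CPython; copying the keys first avoids that. A small standalone illustration:

cache = {1: "batch-1", 2: "batch-2"}
try:
    for k in cache.keys():
        del cache[k]          # mutating while iterating
except RuntimeError as e:
    print(e)                  # dictionary changed size during iteration

cache = {1: "batch-1", 2: "batch-2"}
for k in list(cache.keys()):  # snapshot of the keys, as in Cache.clear() above
    del cache[k]
print(cache)                  # {}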
@@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
        return cls(
            batch_id=pb.id,
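Note: without the parentheses, the decode budget was added once for the whole batch instead of once per request, so `max_tokens` under-counted the batch's worst-case footprint. Worked example with illustrative numbers:

n_inputs = 8
max_input_length = 512
max_decode_tokens = 128

old = n_inputs * max_input_length + max_decode_tokens    # 4224: decode budget counted once
new = n_inputs * (max_input_length + max_decode_tokens)  # 5120: counted for every request
print(old, new)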
@@ -1,3 +1,4 @@
+import math
 import itertools
 import torch
 import torch.distributed
@@ -5,6 +6,7 @@ import torch.distributed
 import numpy as np
 
 from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
@@ -21,6 +23,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 
 tracer = trace.get_tracer(__name__)
 
+BLOCK_SIZE = 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
 
@@ -35,7 +38,7 @@ class CacheManager:
        dtype: torch.dtype,
        device: torch.device,
    ):
-        self.block_size = 16
+        self.block_size = BLOCK_SIZE
 
        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size
@@ -60,26 +63,30 @@ class CacheManager:
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)
 
-    def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]:
-        # Number of needed block to cover all tokens
-        needed_blocks = (n_tokens // self.block_size) + 1
+    def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
-        assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
+        logger.info(f"Free blocks: {len(free_block_indices)}")
+        assert (
+            len(free_block_indices) >= num_blocks
+        ), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
 
        # Allocate the required number of blocks by setting the mask to 0
-        block_indices = free_block_indices[:needed_blocks]
+        block_indices = free_block_indices[:num_blocks]
        self.free_block_mask[block_indices] = 0
 
        # Get slots for the allocated blocks
-        slots = self.slots[block_indices].flatten()[:n_tokens]
+        slots = self.slots[block_indices].flatten()
 
-        return block_indices.flatten().tolist(), slots
+        logger.info(f"allocate {num_blocks} blocks")
 
-    def free(self, block_indices: List[int]):
-        # Reset mask
-        self.free_block_mask[block_indices] = 1
+        return block_indices.flatten(), slots
+
+    def free(self, block_indices: Optional[List[int]]):
+        if block_indices is not None:
+            # Reset mask
+            logger.info(f"free {len(block_indices)} blocks")
+            self.free_block_mask[block_indices] = 1
 
 
 @dataclass
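Note: after this change the cache manager deals purely in whole blocks: callers ask for a number of blocks, and freeing simply flips the mask back to 1. A minimal standalone toy of the same mask-based scheme (for illustration only, not the class above):

import torch

BLOCK_SIZE = 16
num_blocks = 8
free_block_mask = torch.ones(num_blocks, dtype=torch.int32)
slots = torch.arange(0, num_blocks * BLOCK_SIZE, dtype=torch.int32).view(num_blocks, BLOCK_SIZE)

def allocate(n: int):
    free = free_block_mask.nonzero()
    assert len(free) >= n, f"Out of available cache blocks: asked {n}, only {len(free)} free"
    block_indices = free[:n]
    free_block_mask[block_indices] = 0          # mark as used
    return block_indices.flatten(), slots[block_indices].flatten()

def free(block_indices):
    if block_indices is not None:
        free_block_mask[block_indices] = 1      # mark as free again

blocks, block_slots = allocate(3)
print(blocks.tolist(), len(block_slots))        # [0, 1, 2] 48
free(blocks.tolist())
print(int(free_block_mask.sum()))               # 8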
@@ -97,16 +104,25 @@ class FlashCausalLMBatch(Batch):
    start_seq_prefill: Optional[torch.Tensor]
    # tensor of length b holding ending offset of each sequence, only used in prefill
    end_seq_prefill: Optional[torch.Tensor]
-    # list of length b of list of length s_i // block_size
-    block_tables: List[List[int]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
-    block_tables_tensor: torch.Tensor
+
+    # Paged Attention values
+
+    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
-    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
+    # List of tuple of ints representing the number of blocks and slots needed by each sequence
+    needed_blocks_slots: Optional[List[Tuple[int, int]]]
+
+    # Set in prefill by the CacheManager
+    # list of length b of list of length s_i // block_size
+    block_tables: Optional[List[List[int]]]
+    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    block_tables_tensor: Optional[torch.Tensor]
+    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
+    slots: Optional[torch.Tensor]
+
    max_seqlen: int
 
    # Prefill metadata tensors to efficiently compute logprobs
@@ -128,16 +144,17 @@ class FlashCausalLMBatch(Batch):
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
 
+    # Number of blocks in this batch
+    blocks: int
    # Maximum number of blocks
    max_blocks: int
 
    def to_pb(self) -> generate_pb2.CachedBatch:
-        global CACHE_MANAGER
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
-            max_tokens=len(self.slots),
+            max_tokens=self.blocks * BLOCK_SIZE,
        )
 
    @classmethod
@@ -148,8 +165,6 @@ class FlashCausalLMBatch(Batch):
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
-        global CACHE_MANAGER
-
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
@@ -163,9 +178,8 @@ class FlashCausalLMBatch(Batch):
        position_ids = []
        start_seq_prefill = []
        end_seq_prefill = []
-        block_tables = []
+        needed_blocks_slots = []
        start_slots = []
-        slots = []
        slot_indices = []
 
        input_lengths = []
@@ -188,6 +202,7 @@ class FlashCausalLMBatch(Batch):
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0
 
+        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0
@@ -228,9 +243,9 @@ class FlashCausalLMBatch(Batch):
            # Paged attention
            # Remove one as the first token des not have a past
            total_tokens = input_length + max_new_tokens - 1
-            request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens)
-            block_tables.append(request_blocks)
-            slots.extend(request_slots)
+            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+            blocks += needed_blocks
+            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)
 
            request_slot_indices = torch.arange(
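Note: during batch construction each request now only records how many BLOCK_SIZE-sized blocks it will eventually need (prompt plus generated tokens, minus one because the first token has no past entry); the actual allocation happens later, at prefill time. Worked example with illustrative numbers:

import math

BLOCK_SIZE = 16
input_length = 20
max_new_tokens = 30

total_tokens = input_length + max_new_tokens - 1      # 49: first token has no past entry
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)  # 4 blocks of 16 slots cover 49 tokens
print(total_tokens, needed_blocks)                    # 49 4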
@@ -264,7 +279,7 @@ class FlashCausalLMBatch(Batch):
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, len(request_blocks))
+            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)
 
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -272,15 +287,6 @@ class FlashCausalLMBatch(Batch):
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
-        # Padded block tables
-        block_tables_tensor = torch.zeros(
-            (len(pb.requests), max_blocks), dtype=torch.int32
-        )
-        for i, request_blocks in enumerate(block_tables):
-            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
-
-        block_tables_tensor = block_tables_tensor.to(device)
-
        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
@@ -312,7 +318,6 @@ class FlashCausalLMBatch(Batch):
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        slots = torch.tensor(slots, dtype=torch.int32, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )
@@ -339,11 +344,12 @@ class FlashCausalLMBatch(Batch):
            position_ids=position_ids,
            start_seq_prefill=start_seq_prefill,
            end_seq_prefill=end_seq_prefill,
-            block_tables=block_tables,
-            block_tables_tensor=block_tables_tensor,
            start_slots=start_slots,
-            slots=slots,
            slot_indices=slot_indices,
+            needed_blocks_slots=needed_blocks_slots,
+            block_tables=None,
+            block_tables_tensor=None,
+            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
@@ -356,12 +362,12 @@ class FlashCausalLMBatch(Batch):
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
+            blocks=blocks,
            max_blocks=max_blocks,
        )
 
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        global CACHE_MANAGER
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
@@ -396,6 +402,7 @@ class FlashCausalLMBatch(Batch):
 
        stopping_criterias = []
 
+        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0
@@ -425,6 +432,7 @@ class FlashCausalLMBatch(Batch):
            )
 
            request_block_table = self.block_tables[idx]
+            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)
 
@@ -443,6 +451,7 @@ class FlashCausalLMBatch(Batch):
 
            max_blocks = max(max_blocks, len(request_block_table))
 
+        global CACHE_MANAGER
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
@@ -472,11 +481,12 @@ class FlashCausalLMBatch(Batch):
            position_ids=position_ids,
            start_seq_prefill=None,
            end_seq_prefill=None,
+            start_slots=start_slots,
+            slot_indices=slot_indices,
+            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
-            start_slots=start_slots,
            slots=slots,
-            slot_indices=slot_indices,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
@@ -489,17 +499,18 @@ class FlashCausalLMBatch(Batch):
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
+            blocks=blocks,
            max_blocks=max_blocks,
        )
 
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
-        global CACHE_MANAGER
        # Batch attributes
        requests = []
        requests_idx_mapping = {}
 
+        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
@@ -508,6 +519,7 @@ class FlashCausalLMBatch(Batch):
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
+            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
@@ -613,11 +625,12 @@ class FlashCausalLMBatch(Batch):
            position_ids=position_ids,
            start_seq_prefill=None,
            end_seq_prefill=None,
+            start_slots=start_slots,
+            slot_indices=slot_indices,
+            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
-            start_slots=start_slots,
            slots=slots,
-            slot_indices=slot_indices,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
@@ -630,13 +643,15 @@ class FlashCausalLMBatch(Batch):
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
+            blocks=blocks,
            max_blocks=max_blocks,
        )
 
-    def cleanup(self):
-        global CACHE_MANAGER
-        # Free blocks
-        CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
+    def free(self):
+        if self.block_tables is not None:
+            global CACHE_MANAGER
+            # Free blocks
+            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
 
    def __len__(self):
        return len(self.requests)
@@ -648,22 +663,17 @@ class FlashCausalLM(Model):
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
-        num_heads: int,
+        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
-        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
 
-        global CACHE_MANAGER
-        torch.cuda.set_per_process_memory_fraction(1.0)
-        CACHE_MANAGER = CacheManager(
-            1000, num_layers, num_heads, head_size, dtype, device
-        )
-
        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
@@ -678,6 +688,30 @@ class FlashCausalLM(Model):
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch
 
+    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
+        global CACHE_MANAGER
+
+        torch.cuda.empty_cache()
+        try:
+            CACHE_MANAGER = CacheManager(
+                # Adds some wiggle room
+                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
+                self.num_layers,
+                self.num_kv_heads,
+                self.head_size,
+                self.dtype,
+                self.device,
+            )
+            _, batch = self.generate_token(batch)
+        except Exception as e:
+            logger.error(
+                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
+                f"prefill tokens. "
+                f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
+            )
+            raise e
+        batch.free()
+
    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
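Note: warmup now sizes the paged-attention cache from `max_total_tokens` (rounded up to whole blocks, plus 10 spare blocks) and immediately runs one prefill so an out-of-memory error surfaces at startup rather than under load. A rough sketch of the resulting KV-cache footprint, assuming the usual key plus value tensor per layer; the exact internal layout of CacheManager is not shown in this diff, so treat the formula as an approximation, not the project's own accounting:

import math

BLOCK_SIZE = 16

def approx_kv_cache_bytes(max_total_tokens, num_layers, num_kv_heads, head_size, dtype_bytes=2):
    num_blocks = math.ceil(max_total_tokens / BLOCK_SIZE) + 10   # same wiggle room as warmup()
    tokens = num_blocks * BLOCK_SIZE
    # 2x for key and value caches (assumed layout, fp16 by default)
    return 2 * num_layers * tokens * num_kv_heads * head_size * dtype_bytes

# Example: a Llama-7B-like shape (32 layers, 32 KV heads, head size 128, fp16)
print(approx_kv_cache_bytes(16000, 32, 32, 128) / 2**30, "GiB")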
@@ -718,6 +752,35 @@ class FlashCausalLM(Model):
        prefill = batch.start_seq_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None
 
+        if batch.needed_blocks_slots:
+            # Padded block tables
+            block_tables_tensor = torch.zeros(
+                (len(batch), batch.max_blocks), dtype=torch.int32
+            )
+
+            # Allocate paged attention blocks
+            slots = []
+            block_tables = []
+            try:
+                for i, (needed_blocks, needed_slots) in enumerate(
+                    batch.needed_blocks_slots
+                ):
+                    allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
+                        needed_blocks
+                    )
+                    slots.append(allocated_slots[:needed_slots])
+                    block_tables.append(allocated_blocks.tolist())
+                    block_tables_tensor[i, :needed_blocks] = allocated_blocks
+            except Exception as e:
+                for blocks in block_tables:
+                    CACHE_MANAGER.free(blocks)
+                raise e
+
+            batch.needed_blocks_slots = None
+            batch.block_tables = block_tables
+            batch.block_tables_tensor = block_tables_tensor.to(self.device)
+            batch.slots = torch.concat(slots).to(self.device)
+
        out = self.forward(
            batch.input_ids,
            batch.position_ids,
@@ -931,7 +994,7 @@ class FlashCausalLM(Model):
            batch.all_input_ids[i] = all_input_ids
 
        if stopped:
-            batch.cleanup()
+            batch.free()
            # No need to return a batch if we know that all requests stopped
            return generations, None
 
@@ -68,7 +68,7 @@ class FlashLlama(FlashCausalLM):
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
-            num_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
@@ -22,6 +22,9 @@ class Model(ABC):
        rank: int = 0,
        world_size: int = 1,
    ):
+        if torch.cuda.is_available():
+            torch.cuda.set_per_process_memory_fraction(1.0)
+
        self.model = model.eval()
        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
@@ -55,6 +58,9 @@ class Model(ABC):
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError
 
+    def warmup(self, batch: B, max_total_tokens: int):
+        self.generate_token(batch)
+
    def decode_token(
        self,
        all_input_ids: List[int],
@@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
        read_offsets.append(1)
        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
        return cls(
            batch_id=pb.id,
@@ -35,7 +35,7 @@ class Batch(ABC):
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        raise NotImplementedError
 
-    def cleanup(self):
+    def free(self):
        pass
 
    @abstractmethod
@@ -53,12 +53,24 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
+    async def Warmup(self, request, context):
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        )
+        self.model.warmup(batch, request.max_total_tokens)
+        return generate_pb2.WarmupResponse()
+
    async def Prefill(self, request, context):
        batch = self.model.batch_type.from_pb(
            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )
 
-        generations, next_batch = self.model.generate_token(batch)
+        try:
+            generations, next_batch = self.model.generate_token(batch)
+        except Exception as e:
+            batch.free()
+            raise e
+
        self.cache.set(next_batch)
 
        return generate_pb2.PrefillResponse(
@@ -81,11 +93,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            raise ValueError("All batches are empty")
 
        if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
+            try:
+                batch = self.model.batch_type.concatenate(batches)
+            except Exception as e:
+                [batch.free() for batch in batches]
+                raise e
        else:
            batch = batches[0]
 
-        generations, next_batch = self.model.generate_token(batch)
+        try:
+            generations, next_batch = self.model.generate_token(batch)
+        except Exception as e:
+            batch.free()
+            raise e
+
        self.cache.set(next_batch)
 
        return generate_pb2.DecodeResponse(