This change adds support for prefix caching to the v3 router. It is split out from the backend support to ease reviewing. For now, prefix caching is only enabled with `USE_PREFIX_CACHING=1`; in that case, the router switches to the `RadixAllocator`. This allocator uses a radix trie to keep track of prefills that were seen previously. If a new prefill is a prefix of a previously seen prefill, the router sends a request with `prefix_len>0`, which the backend can use to reuse KV blocks from the cache rather than recomputing them. Even though backend support is not added in this PR, the backend still works with prefix caching enabled; the prefix lengths are simply ignored and not used.
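As a rough, hedged sketch of the idea behind the `RadixAllocator` (not its actual implementation, which also tracks KV-cache blocks, reference counts, and eviction), the snippet below keeps a trie keyed on token IDs and reports how many leading tokens of a new prefill were seen before; that count is what the router would send as `prefix_len`. The names `PrefixTrie`, `insert`, and `longest_prefix` are illustrative only and do not appear in the codebase.

```rust
use std::collections::HashMap;

/// Minimal sketch of a radix-style prefix index over token IDs.
/// Illustrative only: block granularity and eviction are omitted.
#[derive(Default)]
struct PrefixTrie {
    children: HashMap<u32, PrefixTrie>,
}

impl PrefixTrie {
    /// Record a prefill so that later requests can match against it.
    fn insert(&mut self, tokens: &[u32]) {
        let mut node = self;
        for &t in tokens {
            node = node.children.entry(t).or_default();
        }
    }

    /// Length of the longest previously seen prefix of `tokens`;
    /// this is the value that would become `prefix_len` in `Request`.
    fn longest_prefix(&self, tokens: &[u32]) -> u32 {
        let mut node = self;
        let mut len = 0u32;
        for &t in tokens {
            match node.children.get(&t) {
                Some(child) => {
                    node = child;
                    len += 1;
                }
                None => break,
            }
        }
        len
    }
}

fn main() {
    let mut trie = PrefixTrie::default();
    trie.insert(&[1, 2, 3, 4]);
    // A new prefill sharing the first three tokens yields prefix_len == 3.
    assert_eq!(trie.longest_prefix(&[1, 2, 3, 9]), 3);
}
```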
268 lines · 6.3 KiB · Protocol Buffer
syntax = "proto3";

package generate.v3;

service TextGenerationService {
  /// Model Info
  rpc Info(InfoRequest) returns (InfoResponse) {}
  /// Service discovery
  rpc ServiceDiscovery(ServiceDiscoveryRequest)
      returns (ServiceDiscoveryResponse) {}
  /// Empties batch cache
  rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
  /// Remove requests from a cached batch
  rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
  /// Warmup the model and compute max cache size
  rpc Warmup(WarmupRequest) returns (WarmupResponse);
  /// Prefill batch and decode first token
  rpc Prefill(PrefillRequest) returns (PrefillResponse);
  /// Decode token for a list of prefilled batches
  rpc Decode(DecodeRequest) returns (DecodeResponse);
  /// Health check
  rpc Health(HealthRequest) returns (HealthResponse);
}

message HealthRequest {}
message HealthResponse {}

/// Empty request
message InfoRequest {}

message InfoResponse {
  bool requires_padding = 1;
  string dtype = 2;
  string device_type = 3;
  optional uint32 window_size = 4;
  uint32 speculate = 5;
}

/// Empty request
message ServiceDiscoveryRequest {}

message ServiceDiscoveryResponse {
  /// Other shards urls
  repeated string urls = 1;
}

message ClearCacheRequest {
  /// Optional batch id
  optional uint64 id = 1;
}

/// Empty response
message ClearCacheResponse {}

message Image {
  /// Binary image data.
  bytes data = 1;

  /// Image MIME type.
  string mimetype = 2;
}

message InputChunk {
  oneof chunk {
    /// Plain text data
    string text = 1;
    /// Image data
    Image image = 2;
  }
}

message Input { repeated InputChunk chunks = 1; }

enum GrammarType {
  GRAMMAR_TYPE_NONE = 0;
  GRAMMAR_TYPE_JSON = 1;
  GRAMMAR_TYPE_REGEX = 2;
}

message NextTokenChooserParameters {
  /// exponential scaling output probability distribution
  float temperature = 1;
  /// restricting to the k highest probability elements
  uint32 top_k = 2;
  /// restricting to top tokens with cumulative probability <= top_p
  float top_p = 3;
  /// restricting to tokens within the typical_p probability mass
  float typical_p = 4;
  /// apply sampling on the logits
  bool do_sample = 5;
  /// random seed for sampling
  uint64 seed = 6;
  /// repetition penalty
  float repetition_penalty = 7;
  /// frequency penalty
  float frequency_penalty = 9;
  /// token watermarking using "A Watermark for Large Language Models"
  bool watermark = 8;
  /// grammar (applied if not empty)
  string grammar = 10;
  /// grammar type
  GrammarType grammar_type = 11;
}

message StoppingCriteriaParameters {
  /// Maximum number of generated tokens
  uint32 max_new_tokens = 1;
  /// Optional stopping sequences
  repeated string stop_sequences = 2;
  /// Ignore end of sequence token
  /// used for benchmarking
  bool ignore_eos_token = 3;
}

message Request {
  /// Request ID
  uint64 id = 1;
  /// The generation context as chunks
  Input input_chunks = 8;
  /// The generation context, stringified input_chunks
  string inputs = 2;
  /// Context truncation
  uint32 truncate = 3;
  /// Next Token Chooser Parameters
  NextTokenChooserParameters parameters = 4;
  /// Stopping Criteria Parameters
  StoppingCriteriaParameters stopping_parameters = 5;
  /// Return prefill logprobs
  bool prefill_logprobs = 6;
  /// Return most likely n tokens
  uint32 top_n_tokens = 7;
  /// Paged attention blocks
  repeated uint32 blocks = 9;
  /// Paged attention slots
  repeated uint32 slots = 10;
  /// LoRA adapter index
  optional string adapter_id = 11;
  /// Prefix length that can be retrieved from the KV cache.
  uint32 prefix_len = 12;
}

message Batch {
  /// Batch ID
  uint64 id = 1;
  /// Individual requests
  repeated Request requests = 2;
  /// Batch size (==len(requests))
  uint32 size = 3;
  /// Maximum number of tokens this batch will grow to
  uint32 max_tokens = 4;
  /// Maximum number of Paged Attention blocks
  uint32 max_blocks = 5;
}

message CachedBatch {
  /// Batch ID
  uint64 id = 1;
  /// Individual requests ids
  repeated uint64 request_ids = 2;
  /// Batch size (==len(requests))
  uint32 size = 3;
  /// Maximum number of tokens this batch will grow to
  uint32 max_tokens = 4;
}

enum FinishReason {
  FINISH_REASON_LENGTH = 0;
  FINISH_REASON_EOS_TOKEN = 1;
  FINISH_REASON_STOP_SEQUENCE = 2;
}

message GeneratedText {
  /// Output
  string text = 1;
  /// Number of generated tokens
  uint32 generated_tokens = 2;
  /// Finish reason
  FinishReason finish_reason = 3;
  /// Seed
  optional uint64 seed = 4;
}

message Tokens {
  /// Token IDs
  repeated uint32 ids = 1;
  /// Logprobs
  repeated float logprobs = 2;
  /// tokens
  repeated string texts = 3;
  /// special
  repeated bool is_special = 4;
}

message Generation {
  /// Request ID
  uint64 request_id = 1;
  /// Prefill tokens (optional)
  Tokens prefill_tokens = 2;
  Tokens tokens = 3;
  /// Complete generated text
  optional GeneratedText generated_text = 4;
  /// Top tokens
  repeated Tokens top_tokens = 5;
}

message FilterBatchRequest {
  /// Batch ID
  uint64 batch_id = 1;
  /// Requests to keep
  repeated uint64 request_ids = 2;
}

message FilterBatchResponse {
  /// Filtered Batch (cached)
  CachedBatch batch = 1;
}

message PrefillRequest {
  /// Batch
  Batch batch = 1;
}

message PrefillResponse {
  /// Generation
  repeated Generation generations = 1;
  /// Next batch (cached)
  optional CachedBatch batch = 2;
  /// Forward elapsed time in nanoseconds
  uint64 forward_ns = 3;
  /// Decode elapsed time in nanoseconds
  uint64 decode_ns = 4;
  /// Total elapsed time in nanoseconds
  uint64 total_ns = 5;
}

message DecodeRequest {
  /// Cached batches
  repeated CachedBatch batches = 1;
}

message DecodeResponse {
  /// Decodes
  repeated Generation generations = 1;
  /// Next batch (cached)
  optional CachedBatch batch = 2;
  /// Forward elapsed time in nanoseconds
  uint64 forward_ns = 3;
  /// Decode elapsed time in nanoseconds
  uint64 decode_ns = 4;
  /// Total elapsed time in nanoseconds
  uint64 total_ns = 5;
  /// Concatenate elapsed time in nanoseconds
  optional uint64 concat_ns = 6;
}

message WarmupRequest {
  /// Batch to warmup on
  Batch batch = 1;
  uint32 max_input_length = 2;
  uint32 max_prefill_tokens = 3;
  uint32 max_total_tokens = 4;
}

message WarmupResponse {
  /// Maximum number of tokens supported by the model
  optional uint32 max_supported_total_tokens = 1;
}
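Backend support is out of scope for this PR, but for context, a backend that does honor `prefix_len` would only need to run the prefill forward pass on the tokens past the cached prefix. A minimal sketch of that bookkeeping follows, using hypothetical names (`PrefillPlan`, `tokens_to_compute`) that do not exist in the codebase.

```rust
/// Hypothetical per-request view mirroring the `Request` fields relevant
/// to prefix caching; this is not generated prost code.
struct PrefillPlan {
    input_len: u32,  // number of prompt tokens in this prefill
    prefix_len: u32, // leading tokens already present in the KV cache
}

impl PrefillPlan {
    /// Tokens that still need a forward pass; the first `prefix_len`
    /// positions can reuse cached KV blocks instead of being recomputed.
    fn tokens_to_compute(&self) -> u32 {
        self.input_len.saturating_sub(self.prefix_len)
    }
}

fn main() {
    // E.g. a 512-token prompt whose first 384 tokens match a prior prefill.
    let plan = PrefillPlan { input_len: 512, prefix_len: 384 };
    assert_eq!(plan.tokens_to_compute(), 128);
}
```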