syntax = "proto3";

package generate.v3;

/// RPCs for model metadata, shard discovery, batch cache management,
/// warmup, prefill/decode and health checking.
service TextGenerationService {
    /// Returns model information
    rpc Info(InfoRequest) returns (InfoResponse);
    /// Returns the urls of the other shards
    rpc ServiceDiscovery(ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse);
    /// Empties the batch cache
    rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
    /// Removes requests from a cached batch
    rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
    /// Warms up the model and computes the max cache size
    rpc Warmup(WarmupRequest) returns (WarmupResponse);
    /// Prefills a batch and decodes its first token
    rpc Prefill(PrefillRequest) returns (PrefillResponse);
    /// Decodes a token for a list of prefilled batches
    rpc Decode(DecodeRequest) returns (DecodeResponse);
    /// Health check
    rpc Health(HealthRequest) returns (HealthResponse);
}

/// Empty request for the Health RPC
message HealthRequest {}
/// Empty response for the Health RPC
message HealthResponse {}

/// Empty request for the Info RPC
message InfoRequest {}

/// Model information returned by the Info RPC
message InfoResponse {
    /// Whether the model requires padding
    /// NOTE(review): what exactly gets padded is not visible in this file; confirm with the server
    bool requires_padding = 1;
    /// Model dtype, as a string (the set of valid values is not defined in this file)
    string dtype = 2;
    /// Device type, as a string (the set of valid values is not defined in this file)
    string device_type = 3;
    /// Optional window size; explicit presence lets "unset" be told apart from 0
    /// NOTE(review): presumably the attention sliding-window size; confirm
    optional uint32 window_size = 4;
    /// NOTE(review): presumably the number of speculative tokens per step; confirm
    uint32 speculate = 5;
}

/// Empty request for the ServiceDiscovery RPC
message ServiceDiscoveryRequest {}

/// Response of the ServiceDiscovery RPC
message ServiceDiscoveryResponse {
    /// URLs of the other shards
    repeated string urls = 1;
}

/// Request for the ClearCache RPC
message ClearCacheRequest {
    /// Optional batch id (explicit presence: "unset" is distinct from id 0)
    /// NOTE(review): presumably only this batch is dropped when set, and the whole cache otherwise; confirm
    optional uint64 id = 1;
}

/// Empty response for the ClearCache RPC
message ClearCacheResponse {}

/// Image payload, carried as one of the InputChunk alternatives
message Image {
    /// Binary image data.
    bytes data = 1;

    /// Image MIME type.
    string mimetype = 2;
}

/// One piece of an Input: either text or an image
message InputChunk {
    /// At most one alternative is set; setting one clears the other
    oneof chunk {
        /// Plain text data
        string text = 1;
        /// Image data
        Image image = 2;
    }
}

/// Sequence of chunks making up an input
message Input {
    /// Chunks, in order
    repeated InputChunk chunks = 1;
}

/// Kind of grammar carried by NextTokenChooserParameters.grammar
enum GrammarType {
    /// No grammar; this is the zero value, so it is also what an unset field reads as
    GRAMMAR_TYPE_NONE = 0;
    /// NOTE(review): presumably the grammar string is a JSON schema; confirm
    GRAMMAR_TYPE_JSON = 1;
    /// NOTE(review): presumably the grammar string is a regular expression; confirm
    GRAMMAR_TYPE_REGEX = 2;
}

/// Parameters controlling how the next token is chosen
message NextTokenChooserParameters {
    /// temperature used to scale the output probability distribution
    float temperature = 1;
    /// restricting to the k highest probability elements
    uint32 top_k = 2;
    /// restricting to the top tokens whose cumulative probability stays within top_p
    float top_p = 3;
    /// typical-p cutoff
    /// NOTE(review): the previous comment here was a copy of top_p's; confirm the exact semantics with the sampler
    float typical_p = 4;
    /// apply sampling on the logits
    bool do_sample = 5;
    /// random seed for sampling
    uint64 seed = 6;
    /// repetition penalty
    float repetition_penalty = 7;
    /// frequency penalty
    /// (field numbers 9 and 8 are declared out of numeric order; numbers are the wire identity, do not renumber)
    float frequency_penalty = 9;
    /// token watermarking using "A Watermark for Large Language Models"
    bool watermark = 8;
    /// grammar (applied if not empty)
    string grammar = 10;
    /// grammar type, see GrammarType
    GrammarType grammar_type = 11;
}

/// Parameters deciding when generation stops
message StoppingCriteriaParameters {
    /// Maximum number of generated tokens
    uint32 max_new_tokens = 1;
    /// Optional stopping sequences (an empty list is the same as unset)
    repeated string stop_sequences = 2;
    /// Ignore end of sequence token
    /// used for benchmarking
    bool ignore_eos_token = 3;
}

/// A single generation request, as carried inside a Batch
message Request {
    /// Request ID
    uint64 id = 1;
    /// Generation context, as a list of chunks
    /// (field number 8: declared out of numeric order, do not renumber)
    Input input_chunks = 8;
    /// Generation context as a string: the stringified input_chunks
    string inputs = 2;
    /// Context truncation
    uint32 truncate = 3;
    /// Parameters for choosing the next token
    NextTokenChooserParameters parameters = 4;
    /// Parameters for the stopping criteria
    StoppingCriteriaParameters stopping_parameters = 5;
    /// Whether to return prefill logprobs
    bool prefill_logprobs = 6;
    /// Return the n most likely tokens
    uint32 top_n_tokens = 7;
    /// Paged attention blocks
    repeated uint32 blocks = 9;
    /// Paged attention slots
    repeated uint32 slots = 10;
    /// LoRA adapter ID (a string; optional, with explicit presence)
    optional string adapter_id = 11;
}

/// A batch of full requests, as sent to Prefill and Warmup
message Batch {
    /// Batch ID
    uint64 id = 1;
    /// Individual requests
    repeated Request requests = 2;
    /// Batch size (==len(requests))
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
    /// Maximum number of Paged Attention blocks
    uint32 max_blocks = 5;
}

/// Reference to a batch by ID: carries request ids instead of full requests
message CachedBatch {
    /// Batch ID
    uint64 id = 1;
    /// Individual requests ids
    repeated uint64 request_ids = 2;
    /// Batch size (==len(request_ids); this message has no `requests` field)
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
}

/// Why generation finished, as reported in GeneratedText.finish_reason
enum FinishReason {
    /// This is the zero value, so an unset finish_reason also reads as LENGTH
    /// NOTE(review): presumably means max_new_tokens was reached; confirm
    FINISH_REASON_LENGTH = 0;
    /// NOTE(review): presumably means the end-of-sequence token was generated; confirm
    FINISH_REASON_EOS_TOKEN = 1;
    /// NOTE(review): presumably means one of the stop_sequences matched; confirm
    FINISH_REASON_STOP_SEQUENCE = 2;
}

/// Final generated text, attached to a Generation
message GeneratedText {
    /// Output
    string text = 1;
    /// Number of generated tokens
    uint32 generated_tokens = 2;
    /// Finish reason
    FinishReason finish_reason = 3;
    /// Seed (optional, with explicit presence)
    /// NOTE(review): presumably only set when sampling was used; confirm
    optional uint64 seed = 4;
}

/// A list of tokens, stored as one list per attribute
/// NOTE(review): the four lists look like parallel arrays indexed per token; confirm they always have the same length
message Tokens {
    /// Token IDs
    repeated uint32 ids = 1;
    /// Logprobs
    repeated float logprobs = 2;
    /// Token texts
    repeated string texts = 3;
    /// Whether each token is special
    repeated bool is_special = 4;
}

/// Generation output for one request, as returned by Prefill and Decode
message Generation {
    /// Request ID
    uint64 request_id = 1;
    /// Prefill tokens (optional)
    Tokens prefill_tokens = 2;
    /// NOTE(review): undocumented in the original; presumably the token(s) produced by this call; confirm
    Tokens tokens = 3;
    /// Complete generated text
    /// NOTE(review): presumably only set once the request has finished; confirm
    optional GeneratedText generated_text = 4;
    /// Top tokens
    repeated Tokens top_tokens = 5;
}

/// Request for the FilterBatch RPC
message FilterBatchRequest {
    /// Batch ID
    uint64 batch_id = 1;
    /// Requests to keep
    repeated uint64 request_ids = 2;
}

/// Response of the FilterBatch RPC
message FilterBatchResponse {
    /// Filtered Batch (cached)
    CachedBatch batch = 1;
}


/// Request for the Prefill RPC
message PrefillRequest {
    /// Batch to prefill
    Batch batch = 1;
}

/// Response of the Prefill RPC
message PrefillResponse {
    /// Generations
    repeated Generation generations = 1;
    /// Next batch (cached)
    /// NOTE(review): presumably absent when no request of the batch remains to be decoded; confirm
    optional CachedBatch batch = 2;
    /// Forward elapsed time in nanoseconds
    uint64 forward_ns = 3;
    /// Decode elapsed time in nanoseconds
    uint64 decode_ns = 4;
    /// Total elapsed time in nanoseconds
    uint64 total_ns = 5;
}

/// Request for the Decode RPC
message DecodeRequest {
    /// Cached batches
    repeated CachedBatch batches = 1;
}

/// Response of the Decode RPC
message DecodeResponse {
    /// Generations
    repeated Generation generations = 1;
    /// Next batch (cached)
    /// NOTE(review): presumably absent when no request remains to be decoded; confirm
    optional CachedBatch batch = 2;
    /// Forward elapsed time in nanoseconds
    uint64 forward_ns = 3;
    /// Decode elapsed time in nanoseconds
    uint64 decode_ns = 4;
    /// Total elapsed time in nanoseconds
    uint64 total_ns = 5;
    /// Concatenate elapsed time in nanoseconds (optional, with explicit presence)
    /// NOTE(review): presumably the time spent concatenating the request's batches; confirm
    optional uint64 concat_ns = 6;
}

/// Request for the Warmup RPC
message WarmupRequest {
    /// Batch to warmup on
    Batch batch = 1;
    /// NOTE(review): undocumented in the original; presumably the maximum input length, in tokens; confirm
    uint32 max_input_length = 2;
    /// NOTE(review): undocumented in the original; presumably the maximum number of tokens in one prefill; confirm
    uint32 max_prefill_tokens = 3;
    /// NOTE(review): undocumented in the original; presumably the maximum of input plus generated tokens; confirm
    uint32 max_total_tokens = 4;
}

/// Response of the Warmup RPC
message WarmupResponse {
    /// Maximum number of tokens supported by the model (optional, with explicit presence)
    optional uint32 max_supported_total_tokens = 1;
}