syntax = "proto3";

package generate.v3;

/// Text generation RPCs: model info, shard discovery, batch cache
/// management, warmup, prefill/decode and health check
service TextGenerationService {
  /// Returns model information
  rpc Info(InfoRequest) returns (InfoResponse);
  /// Service discovery (the response lists the other shards urls)
  rpc ServiceDiscovery(ServiceDiscoveryRequest)
      returns (ServiceDiscoveryResponse);
  /// Empties the batch cache
  rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
  /// Removes requests from a cached batch
  rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
  /// Warms up the model and computes the max cache size
  rpc Warmup(WarmupRequest) returns (WarmupResponse);
  /// Prefills a batch and decodes the first token
  rpc Prefill(PrefillRequest) returns (PrefillResponse);
  /// Token decoding for a list of prefilled batches
  rpc Decode(DecodeRequest) returns (DecodeResponse);
  /// Health check
  rpc Health(HealthRequest) returns (HealthResponse);
}

/// Empty request of the Health RPC
message HealthRequest {}
/// Empty response of the Health RPC
message HealthResponse {}

/// Empty request of the Info RPC
message InfoRequest {}

/// Model information returned by the Info RPC
///
/// NOTE(review): the fields carry no description in this file; the per-field
/// notes below are inferred from the field names only and must be confirmed
/// against the server implementation.
message InfoResponse {
  /// presumably whether the model needs padded inputs - TODO confirm
  bool requires_padding = 1;
  /// presumably the name of the model data type - TODO confirm
  string dtype = 2;
  /// presumably the kind of device the model runs on - TODO confirm
  string device_type = 3;
  /// Declared `optional`, so an unset value is distinguishable from 0;
  /// presumably the attention window size - TODO confirm
  optional uint32 window_size = 4;
  /// presumably the number of speculated tokens - TODO confirm
  uint32 speculate = 5;
  /// presumably whether chunked prefill is supported - TODO confirm
  bool support_chunking = 6;
  /// presumably whether prefix caching is in use - TODO confirm
  bool use_prefix_caching = 7;
  /// presumably the name of the attention implementation - TODO confirm
  string attention_impl = 8;
  /// presumably the Paged attention block size - TODO confirm
  uint32 block_size = 9;
}

/// Empty request of the ServiceDiscovery RPC
message ServiceDiscoveryRequest {}

/// Response of the ServiceDiscovery RPC
message ServiceDiscoveryResponse {
  /// Urls of the other shards
  repeated string urls = 1;
}

/// Request of the ClearCache RPC
message ClearCacheRequest {
  /// Optional batch id (explicit presence: unset is distinguishable from 0).
  /// NOTE(review): presumably the whole cache is emptied when this is
  /// unset - confirm
  optional uint64 id = 1;
}

/// Empty response of the ClearCache RPC
message ClearCacheResponse {}

/// Image payload, usable as one of the InputChunk alternatives
message Image {
  /// Binary image data.
  bytes data = 1;

  /// Image MIME type.
  string mimetype = 2;
}

/// One piece of an input: either text or an image
message InputChunk {
  /// At most one of the members below is set at a time
  oneof chunk {
    /// Plain text data
    string text = 1;
    /// Image data
    Image image = 2;
  }
}

/// An input expressed as a sequence of chunks
message Input {
  /// The chunks (text or image) making up the input
  repeated InputChunk chunks = 1;
}

/// Kind of grammar, used as the type of
/// NextTokenChooserParameters.grammar_type
enum GrammarType {
  /// No grammar; 0 is also the value read when the field is left unset
  GRAMMAR_TYPE_NONE = 0;
  /// JSON grammar.
  /// NOTE(review): presumably the grammar string is a JSON schema - confirm
  GRAMMAR_TYPE_JSON = 1;
  /// Regex grammar.
  /// NOTE(review): presumably the grammar string is a regular expression -
  /// confirm
  GRAMMAR_TYPE_REGEX = 2;
}

/// Parameters controlling how the next token is chosen.
///
/// Field numbers are not in declaration order (9 is declared before 8); they
/// identify the fields on the wire and must not be renumbered.
message NextTokenChooserParameters {
  /// exponential scaling of the output probability distribution
  float temperature = 1;
  /// restricting to the k highest probability elements
  uint32 top_k = 2;
  /// restricting to the top tokens whose probabilities sum to <= top_p
  float top_p = 3;
  /// NOTE(review): presumably the probability mass kept by typical
  /// sampling - confirm the exact semantics
  float typical_p = 4;
  /// apply sampling on the logits
  bool do_sample = 5;
  /// random seed for sampling
  uint64 seed = 6;
  /// repetition penalty
  float repetition_penalty = 7;
  /// frequency penalty
  float frequency_penalty = 9;
  /// token watermarking using "A Watermark for Large Language Models"
  bool watermark = 8;
  /// grammar (applied if not empty)
  string grammar = 10;
  /// type of the grammar carried by `grammar`
  GrammarType grammar_type = 11;
}

/// Parameters deciding when generation stops
message StoppingCriteriaParameters {
  /// Maximum number of generated tokens
  uint32 max_new_tokens = 1;
  /// Optional stopping sequences
  repeated string stop_sequences = 2;
  /// Ignore the end of sequence token
  /// (used for benchmarking)
  bool ignore_eos_token = 3;
}

/// A single generation request.
///
/// Field numbers are not in declaration order (8 is declared before 2); they
/// identify the fields on the wire and must not be renumbered.
message Request {
  /// Request ID
  uint64 id = 1;
  /// The generation context as chunks
  Input input_chunks = 8;
  /// The generation context, stringified input_chunks
  string inputs = 2;
  /// Context truncation
  uint32 truncate = 3;
  /// Next Token Chooser Parameters
  NextTokenChooserParameters parameters = 4;
  /// Stopping Criteria Parameters
  StoppingCriteriaParameters stopping_parameters = 5;
  /// Return prefill logprobs
  bool prefill_logprobs = 6;
  /// Return most likely n tokens
  uint32 top_n_tokens = 7;
  /// Paged attention blocks
  repeated uint32 blocks = 9;
  /// Paged attention slots
  repeated uint32 slots = 10;
  /// LoRA adapter id (a string identifier; explicit presence)
  optional string adapter_id = 11;
  /// Tokens that can be retrieved from the KV cache.
  /// This value is set for the first prefill and never reset
  uint32 cache_len = 12;
  /// NOTE(review): presumably whether special tokens are added when the
  /// input is tokenized - confirm
  bool add_special_tokens = 13;
  /// Chunk of tokens that must be computed for the first prefill.
  /// This value is set for the first prefill and never reset
  optional uint32 chunk_len = 14;
}

/// A batch of requests carrying the full Request messages
message Batch {
  /// Batch ID
  uint64 id = 1;
  /// Individual requests
  repeated Request requests = 2;
  /// Batch size (==len(requests))
  uint32 size = 3;
  /// Maximum number of tokens this batch will grow to
  uint32 max_tokens = 4;
  /// Maximum number of Paged Attention blocks
  uint32 max_blocks = 5;
}

/// A cached batch: refers to its requests by id instead of embedding them
message CachedBatch {
  /// Batch ID
  uint64 id = 1;
  /// Individual requests ids
  repeated uint64 request_ids = 2;
  /// Batch size (==len(request_ids))
  uint32 size = 3;
  /// Maximum number of tokens this batch will grow to
  uint32 max_tokens = 4;
  /// Number of tokens in the next forward
  uint32 current_tokens = 5;
}

/// Why a generation finished.
///
/// NOTE(review): there is no UNSPECIFIED value, so a finish_reason left
/// unset reads as FINISH_REASON_LENGTH (0 is the proto3 default). The values
/// are part of the wire contract and must not be renumbered.
enum FinishReason {
  /// presumably the maximum number of generated tokens was reached -
  /// TODO confirm
  FINISH_REASON_LENGTH = 0;
  /// presumably the end of sequence token was generated - TODO confirm
  FINISH_REASON_EOS_TOKEN = 1;
  /// presumably one of the stop sequences was generated - TODO confirm
  FINISH_REASON_STOP_SEQUENCE = 2;
}

/// Final text of a generation, together with how it ended
message GeneratedText {
  /// Output
  string text = 1;
  /// Number of generated tokens
  uint32 generated_tokens = 2;
  /// Finish reason
  FinishReason finish_reason = 3;
  /// Seed (explicit presence: unset is distinguishable from 0)
  optional uint64 seed = 4;
}

/// A group of tokens described by four parallel lists.
/// NOTE(review): entries at the same index presumably describe the same
/// token, so the lists are expected to have equal lengths - confirm
message Tokens {
  /// Token IDs
  repeated uint32 ids = 1;
  /// Logprobs
  repeated float logprobs = 2;
  /// Token texts
  repeated string texts = 3;
  /// Whether each token is a special token
  repeated bool is_special = 4;
}

/// Generation output for a single request
message Generation {
  /// Request ID
  uint64 request_id = 1;
  /// Prefill tokens (optional)
  Tokens prefill_tokens = 2;
  /// NOTE(review): presumably the tokens generated by this call - confirm
  Tokens tokens = 3;
  /// Complete generated text (explicit presence)
  optional GeneratedText generated_text = 4;
  /// Top tokens
  repeated Tokens top_tokens = 5;
}

/// Request of the FilterBatch RPC
message FilterBatchRequest {
  /// Batch ID
  uint64 batch_id = 1;
  /// Ids of the requests to keep
  repeated uint64 request_ids = 2;
}

/// Response of the FilterBatch RPC
message FilterBatchResponse {
  /// Filtered Batch (cached)
  CachedBatch batch = 1;
}

/// Request of the Prefill RPC
message PrefillRequest {
  /// Batch
  Batch batch = 1;
  /// Optional cached batch
  CachedBatch cached_batch = 2;
}

/// Response of the Prefill RPC: generations plus timing information
message PrefillResponse {
  /// Generations
  repeated Generation generations = 1;
  /// Next batch (cached); explicit presence
  optional CachedBatch batch = 2;
  /// Forward elapsed time in nanoseconds
  uint64 forward_ns = 3;
  /// Decode elapsed time in nanoseconds
  uint64 decode_ns = 4;
  /// Total elapsed time in nanoseconds
  uint64 total_ns = 5;
  /// Concatenate elapsed time in nanoseconds (explicit presence)
  optional uint64 concat_ns = 6;
}

/// Request of the Decode RPC
message DecodeRequest {
  /// Cached batches
  repeated CachedBatch batches = 1;
}

/// Response of the Decode RPC: generations plus timing information
message DecodeResponse {
  /// Decoded generations
  repeated Generation generations = 1;
  /// Next batch (cached); explicit presence
  optional CachedBatch batch = 2;
  /// Forward elapsed time in nanoseconds
  uint64 forward_ns = 3;
  /// Decode elapsed time in nanoseconds
  uint64 decode_ns = 4;
  /// Total elapsed time in nanoseconds
  uint64 total_ns = 5;
  /// Concatenate elapsed time in nanoseconds (explicit presence)
  optional uint64 concat_ns = 6;
}

/// Request of the Warmup RPC
message WarmupRequest {
  /// Batch to warmup on
  Batch batch = 1;
  /// Optional maximum input tokens (explicit presence).
  /// NOTE(review): presumably warmup picks a value itself when this is
  /// unset, as described on WarmupResponse.max_input_tokens - confirm
  optional uint32 max_input_tokens = 2;
  /// NOTE(review): presumably the maximum number of tokens in a prefill -
  /// confirm
  uint32 max_prefill_tokens = 3;
  /// Optional maximum total tokens (explicit presence).
  /// NOTE(review): presumably warmup picks a value itself when this is
  /// unset, as described on WarmupResponse.max_total_tokens - confirm
  optional uint32 max_total_tokens = 4;
}

/// Response of the Warmup RPC
message WarmupResponse {
  /// Maximum number of tokens supported by the model (explicit presence)
  optional uint32 max_supported_total_tokens = 1;
  /// Maximum input tokens for clients. Should be equal to the request value
  /// if it is set; otherwise warmup automatically allocates a value here
  uint32 max_input_tokens = 2;
  /// Maximum total tokens for clients. Should be equal to the request value
  /// if it is set; otherwise warmup automatically allocates a value here
  uint32 max_total_tokens = 3;
}