text-generation-inference/router/client/src/lib.rs
drbh cef0553d59
Outlines guided generation (#1539)
This WIP PR starts to add grammar support via outlines, currently this
PR supports very simple regex grammars and does not optimize for
precompiling or caching grammar fsm's.

todo:
- [X] add simple outlines guidance to `NextTokenChooser`
- [X] update protos for grammar
- [X] update generation params API
- [X] constrain simple grammar
- [ ] support parsing more complex grammar into fsm
- [ ] support all outline support grammar types
- [ ] explore optimizations to avoid recompiling grammars

guided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data-raw '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6,
        "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+"
    }
}' | jq
```
response
```json
{
  "generated_text": "david@example.com"
}
```

unguided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6
    }
}' | jq
```
response
```json
{
  "generated_text": "    email = 'david"
}
```
2024-02-15 10:28:10 +01:00

47 lines
1.2 KiB
Rust

//! Text Generation gRPC client library
mod client;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
use tonic::transport;
use tonic::Status;
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
#[error("Server error: {0}")]
Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
pub type Result<T> = std::result::Result<T, ClientError>;