text-generation-inference/router/src/health.rs
drbh 55acb86f42 Outlines guided generation (#1539)
This WIP PR starts to add grammar support via outlines. Currently, this
PR supports very simple regex grammars and does not optimize for
precompiling or caching grammar FSMs.

todo:
- [X] add simple outlines guidance to `NextTokenChooser`
- [X] update protos for grammar
- [X] update generation params API
- [X] constrain simple grammar
- [ ] support parsing more complex grammar into fsm
- [ ] support all grammar types supported by outlines
- [ ] explore optimizations to avoid recompiling grammars

guided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data-raw '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6,
        "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+"
    }
}' | jq
```
response
```json
{
  "generated_text": "david@example.com"
}
```

unguided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6
    }
}' | jq
```
response
```json
{
  "generated_text": "    email = 'david"
}
```
2024-04-24 14:57:37 +03:00

73 lines
2.4 KiB
Rust

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
// Sentinel ids used by the dummy liveness request and batch built in `Health::check`.
// Note: Request ids and batch ids cannot collide, so both sentinels can share
// the same value.
// NOTE(review): `u64::MAX` is presumably chosen to stay clear of ids assigned to
// real requests/batches -- confirm against the id allocation elsewhere in the router.
const LIVENESS_ID: u64 = u64::MAX;
const BATCH_ID: u64 = u64::MAX;
/// Health checker for the model shards.
///
/// Combines a shard client with a shared "generation is healthy" flag; see
/// `Health::check` for how the two are used together.
#[derive(Clone, Debug)]
pub(crate) struct Health {
/// Client used to issue the `health` and `prefill` calls to the shards.
client: ShardedClient,
/// Shared flag read at the start of `check` and overwritten with the outcome
/// of the dummy prefill when the flag was `false`.
/// NOTE(review): presumably also written by the generation path elsewhere -- confirm.
generation_health: Arc<AtomicBool>,
}
impl Health {
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards are answering gRPC calls
self.client.health().await.is_ok()
} else {
// Generation is unhealthy or have not sent any generation request yet
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
// Skips the queue
let value = self.client.prefill(batch).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}
}