Updating code.

Nicolas Patry 2023-04-26 14:25:55 +02:00
parent e28b5bf460
commit 9d613d0f9b
5 changed files with 49 additions and 30 deletions

@@ -67,6 +67,12 @@ jobs:
        run: |
          pip install pytest
          HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
+        run: |
+          cargo clippy
      - name: Run Rust tests
        run: |
          cargo test

@@ -20,20 +20,11 @@ pub struct HubModelInfo {
    pub pipeline_tag: Option<String>,
}

-#[derive(Clone, Debug, Serialize, ToSchema)]
-pub struct HealthResponse {}
-
#[derive(Clone, Debug)]
pub struct Health {
    pub client: ShardedClient,
}

-impl Health {
-    pub fn new(client: ShardedClient) -> Self {
-        Self { client }
-    }
-}
-
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
    /// Model info
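
Side note, not part of the commit: with `Health::new` removed and `client` made public, the router builds the extension with a struct literal and hands it to axum. Below is a minimal sketch of that wiring, assuming the axum 0.6 style `Router`/`Extension` setup used in `server.rs`; the `health_router` helper is purely illustrative.

// Illustrative sketch only; `Health`, the `health` handler and `ShardedClient`
// are the items touched by this commit, the helper itself is hypothetical.
use axum::{routing::get, Extension, Router};

fn health_router(client: ShardedClient) -> Router {
    // No constructor needed any more: `client` is a public field.
    let health_ext = Health { client };

    Router::new()
        .route("/health", get(health))
        // The `health` handler reads this back through `Extension<Health>`.
        .layer(Extension(health_ext))
}

In `run()` the client is cloned first because the same `ShardedClient` is also handed to `Infer::new`, as the server.rs hunk further down shows.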

@@ -90,7 +90,7 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
    responses(
        (status = 200, description = "Everything is working fine"),
        (status = 500, description = "Text generation inference is down", body = ErrorResponse,
-        example = json ! ({"error": "unhealthy"})),
+        example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
    )
)]
#[instrument]
@@ -98,11 +98,6 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
async fn health(
    mut health: Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
    health.client.health().await.map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
@@ -589,7 +584,9 @@ pub async fn run(
        max_input_length,
        max_total_tokens,
    );
-    let health_ext = Health::new(client.clone());
+    let health_ext = Health {
+        client: client.clone(),
+    };
    let infer = Infer::new(
        client,
        validation,
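
For orientation, not part of the diff: after dropping the TODO comment the handler presumably just forwards the call to the sharded client and maps any failure onto a 500, mirroring the `{"error": "unhealthy", "error_type": "healthcheck"}` example added to the OpenAPI annotation above. A hedged sketch, assuming `ErrorResponse` exposes `error` and `error_type` string fields.

// Sketch of the post-change /health handler; the ErrorResponse field names are
// assumed from the utoipa example above, not copied from the repository.
async fn health(
    mut health: Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
    // One gRPC health round-trip across the shards; any error becomes an opaque 500.
    health.client.health().await.map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )
    })?;
    Ok(Json(()))
}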

@@ -380,50 +380,75 @@ pub enum ValidationError {
}

#[cfg(test)]
-mod tests{
+mod tests {
    use super::*;
    use std::io::Write;

    #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );

        let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
        }
    }

-    async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
            file.write_all(&content).unwrap();
        }
        Tokenizer::from_file("tokenizer.json").unwrap()
    }

    #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );

        let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
            Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
        }
    }
}
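
Usage note, not part of the commit: the same fixture can also exercise the accepting path. With `max_total_tokens = 5`, a one-token prompt and a small `max_new_tokens` should validate. A hypothetical companion test under that assumption; it mirrors the argument order used above and asserts nothing about the `Ok` payload.

    // Hypothetical companion test, not included in this commit.
    #[tokio::test]
    async fn test_validation_input_ok() {
        let tokenizer = Some(get_tokenizer().await);
        // Same argument order as above: workers, tokenizer, max_best_of,
        // max_stop_sequence, max_input_length, max_total_tokens.
        let validation = Validation::new(1, tokenizer, 2, 3, 4, 5);

        // "Hello" tokenizes to a single GPT-2 token, so 1 input token plus
        // 3 new tokens stays within max_total_tokens = 5.
        match validation
            .validate_input("Hello".to_string(), None, 3)
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected validation error"),
        }
    }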

@@ -31,7 +31,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
-            torch.zeros((2, 2)).to(device=f"cuda:{os.environ['RANK']}")
+            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):