Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
Updating code.

This commit is contained in:
parent e28b5bf460
commit 9d613d0f9b
.github/workflows/tests.yaml (vendored): 6 lines changed
@@ -67,6 +67,12 @@ jobs:
         run: |
           pip install pytest
           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
+        run: |
+          cargo clippy
       - name: Run Rust tests
         run: |
           cargo test
@@ -20,20 +20,11 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }

-#[derive(Clone, Debug, Serialize, ToSchema)]
-pub struct HealthResponse {}
-
 #[derive(Clone, Debug)]
 pub struct Health {
     pub client: ShardedClient,
 }

-impl Health {
-    pub fn new(client: ShardedClient) -> Self {
-        Self { client }
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -90,7 +90,7 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 responses(
 (status = 200, description = "Everything is working fine"),
 (status = 500, description = "Text generation inference is down", body = ErrorResponse,
-example = json ! ({"error": "unhealthy"})),
+example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
 )
 )]
 #[instrument]
@@ -98,11 +98,6 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 async fn health(
     mut health: Extension<Health>,
 ) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
     health.client.health().await.map_err(|_| {
         (
             StatusCode::INTERNAL_SERVER_ERROR,
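Note: the hunk above is cut off inside the map_err closure. A minimal sketch of how the whole handler reads after this change, assuming ErrorResponse carries `error` and `error_type` string fields as the updated OpenAPI example suggests (not the verbatim repository code):

async fn health(
    mut health: Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
    // Probe the shards with a tiny gRPC health request; any failure becomes HTTP 500.
    health.client.health().await.map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )
    })?;
    // Success: an empty JSON body with status 200.
    Ok(Json(()))
}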
@@ -589,7 +584,9 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let health_ext = Health::new(client.clone());
+    let health_ext = Health {
+        client: client.clone(),
+    };
     let infer = Infer::new(
         client,
         validation,
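Note: the route wiring itself is not part of this hunk. A minimal sketch of how health_ext presumably reaches the handler via axum's extension layer; health_routes is a hypothetical helper for illustration, not the repository's actual code:

use axum::{routing::get, Extension, Router};

fn health_routes(health_ext: Health) -> Router {
    // The /health handler pulls `Health` back out of the request extensions.
    Router::new()
        .route("/health", get(health))
        .layer(Extension(health_ext))
}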
@@ -380,50 +380,75 @@ pub enum ValidationError {
 }

 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use std::io::Write;

     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );

         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }

-    async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }

     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );

         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 }
@@ -31,7 +31,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

     async def Health(self, request, context):
         if self.model.device.type == "cuda":
-            torch.zeros((2, 2)).to(device=f"cuda:{os.environ['RANK']}")
+            torch.zeros((2, 2)).cuda()
         return generate_pb2.HealthResponse()

     async def ServiceDiscovery(self, request, context):