Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
Updating code.
Commit: 9d613d0f9b
Parent: e28b5bf460
.github/workflows/tests.yaml (vendored): 6 changed lines
@@ -67,6 +67,12 @@ jobs:
         run: |
           pip install pytest
           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
+        run: |
+          cargo clippy
       - name: Run Rust tests
         run: |
           cargo test
@@ -20,20 +20,11 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Debug, Serialize, ToSchema)]
-pub struct HealthResponse {}
-
 #[derive(Clone, Debug)]
 pub struct Health {
     pub client: ShardedClient,
 }
 
-impl Health {
-    pub fn new(client: ShardedClient) -> Self {
-        Self { client }
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
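Note: together with the change in run() further down, this hunk turns Health into a plain data holder. The unused HealthResponse type and the Health::new constructor are dropped, and the value is now built with a struct literal. A minimal sketch of the resulting type, assuming the import path used elsewhere in the router (the path is not visible in this diff):

use text_generation_client::ShardedClient; // assumed import, not shown in the hunk

#[derive(Clone, Debug)]
pub struct Health {
    // The field is public so callers can construct Health { client } directly.
    pub client: ShardedClient,
}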
@@ -90,7 +90,7 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     responses(
         (status = 200, description = "Everything is working fine"),
         (status = 500, description = "Text generation inference is down", body = ErrorResponse,
-            example = json ! ({"error": "unhealthy"})),
+            example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
 #[instrument]
@@ -98,11 +98,6 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 async fn health(
     mut health: Extension<Health>,
 ) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
     health.client.health().await.map_err(|_| {
         (
             StatusCode::INTERNAL_SERVER_ERROR,
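For context, after this hunk the handler simply forwards a failed ShardedClient::health call as a 500 with the payload shown in the utoipa example above. A hedged sketch of how the function plausibly reads after the commit; only the signature, the map_err, and the error/error_type keys are confirmed by the hunks, the ErrorResponse field names and the Ok branch are inferred:

// Sketch, not verbatim source.
async fn health(
    mut health: Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
    health.client.health().await.map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),        // matches the documented example
                error_type: "healthcheck".to_string(), // assumed field, mirrors the example JSON
            }),
        )
    })?;
    Ok(Json(()))
}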
@@ -589,7 +584,9 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let health_ext = Health::new(client.clone());
+    let health_ext = Health {
+        client: client.clone(),
+    };
     let infer = Infer::new(
         client,
         validation,
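Worth noting in this hunk: client.clone() is still required because client itself is moved into Infer::new on the following lines, so the struct-literal form keeps the same ownership behaviour as the removed Health::new(client.clone()). Minimal illustration (names as in the hunk, remaining Infer::new arguments elided):

// health_ext keeps its own handle; `client` is then moved into Infer::new.
let health_ext = Health {
    client: client.clone(),
};
let infer = Infer::new(client, validation, /* ...remaining arguments as before... */);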
@@ -392,18 +392,33 @@ mod tests{
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     async fn get_tokenizer() -> Tokenizer {
         if !std::path::Path::new("tokenizer.json").exists() {
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
             let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
@@ -418,12 +433,22 @@ mod tests{
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 }
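Both test hunks above are formatting-only: the Validation::new call and the validate_input match are wrapped across lines and given trailing commas, with no behavioural change, which lines up with the cargo fmt --check step added to the workflow at the top of this commit. A sketch of the resulting test shape; the function name, the tokenizer/max_best_of/max_stop_sequence values and the #[tokio::test] attribute are assumed for illustration, everything else is taken from the hunks:

#[tokio::test]
async fn test_validation_max_new_tokens() { // hypothetical name
    let tokenizer = None;      // assumed
    let max_best_of = 2;       // assumed
    let max_stop_sequence = 3; // assumed
    let max_input_length = 4;
    let max_total_tokens = 5;
    let workers = 1;
    let validation = Validation::new(
        workers,
        tokenizer,
        max_best_of,
        max_stop_sequence,
        max_input_length,
        max_total_tokens,
    );

    let max_new_tokens = 10;
    match validation
        .validate_input("Hello".to_string(), None, max_new_tokens)
        .await
    {
        Err(ValidationError::MaxNewTokens(1, 10)) => (),
        _ => panic!("Unexpected not max new tokens"),
    }
}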
@@ -31,7 +31,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def Health(self, request, context):
         if self.model.device.type == "cuda":
-            torch.zeros((2, 2)).to(device=f"cuda:{os.environ['RANK']}")
+            torch.zeros((2, 2)).cuda()
         return generate_pb2.HealthResponse()
 
     async def ServiceDiscovery(self, request, context):