feat(router): ask hf.co for pipeline to make informed decision on compat_return_full_text

OlivierDehaene 2023-02-27 18:11:49 +01:00
parent 21340f24ba
commit 9579bf165a
4 changed files with 50 additions and 3 deletions

Cargo.lock (generated)

@@ -2268,6 +2268,7 @@ dependencies = [
 "opentelemetry-otlp",
 "parking_lot",
 "rand",
+"reqwest",
 "serde",
 "serde_json",
 "text-generation-client",

router/Cargo.toml

@@ -26,6 +26,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 parking_lot = "0.12.1"
 rand = "0.8.5"
+reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"

router/src/main.rs

@@ -87,7 +87,7 @@ fn main() -> Result<(), std::io::Error> {
     // This will only be used to validate payloads
     //
     // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
@@ -97,6 +97,36 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
+            // Get pipeline tag
+            let model_info = reqwest::get(format!(
+                "https://api-inference.huggingface.co/models/{tokenizer_name}"
+            ))
+            .await
+            .expect("Could not connect to hf.co")
+            .text()
+            .await
+            .expect("error when retrieving model info from hf.co");
+            let model_info: serde_json::Value =
+                serde_json::from_str(&model_info).expect("unable to parse model info");
+
+            // if pipeline-tag == text-generation we return prompt + generated_text from the / route
+            let compat_return_full_text = match model_info["pipeline_tag"].as_str() {
+                None => {
+                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+                    tracing::warn!("returning only generated_text from the compat route");
+                    false
+                }
+                Some(pipeline_tag) => {
+                    if pipeline_tag == "text-generation" {
+                        tracing::info!("returning prompt + generated_text from the compat route");
+                        true
+                    } else {
+                        tracing::info!("returning only generated_text from the compat route");
+                        false
+                    }
+                }
+            };
+
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
@@ -113,6 +143,7 @@ fn main() -> Result<(), std::io::Error> {
 
             // Run server
             server::run(
+                compat_return_full_text,
                 max_concurrent_requests,
                 max_stop_sequences,
                 max_input_length,
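
The new startup logic above reduces to a single check on the `pipeline_tag` field of the model info returned by hf.co. A minimal, self-contained sketch of that check, assuming only the `serde_json` dependency already listed in router/Cargo.toml (the payload below is hand-written for illustration; the real response carries many more fields):

```rust
// Sketch only: the pipeline_tag check from the hunk above, run on a canned payload.
fn main() {
    // Illustrative body; in the router it comes from
    // https://api-inference.huggingface.co/models/{tokenizer_name}.
    let body = r#"{"pipeline_tag": "text-generation"}"#;
    let model_info: serde_json::Value =
        serde_json::from_str(body).expect("unable to parse model info");

    // A missing or non-string pipeline_tag yields None, which the router treats
    // as "return only generated_text from the compat route".
    let compat_return_full_text =
        matches!(model_info["pipeline_tag"].as_str(), Some("text-generation"));

    assert!(compat_return_full_text);
}
```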

router/src/server.rs

@@ -29,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
 /// Compatibility route with api-inference and AzureML
 #[instrument(skip(infer))]
 async fn compat_generate(
+    return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
@@ -39,9 +40,20 @@ async fn compat_generate(
             .await
             .into_response())
     } else {
+        let mut add_prompt = None;
+        if return_full_text.0 {
+            add_prompt = Some(req.inputs.clone());
+        }
+
         let (headers, generation) = generate(infer, Json(req.into())).await?;
+        let mut generation = generation.0;
+        if let Some(prompt) = add_prompt {
+            generation.generated_text = prompt + &generation.generated_text;
+        };
+
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }
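
The effect of the new `return_full_text` extension on the compat route is easiest to see with the prompt handling pulled out on its own. A rough sketch (`compat_output` is an illustrative helper, not part of the diff):

```rust
// Sketch only: the prompt-prepend behavior added to compat_generate,
// reduced to plain strings.
fn compat_output(return_full_text: bool, inputs: &str, generated_text: &str) -> String {
    if return_full_text {
        // pipeline_tag == text-generation: prompt + generated_text
        format!("{inputs}{generated_text}")
    } else {
        // any other pipeline tag: generated_text only
        generated_text.to_string()
    }
}

fn main() {
    assert_eq!(compat_output(true, "Once upon", " a time"), "Once upon a time");
    assert_eq!(compat_output(false, "Once upon", " a time"), " a time");
}
```

Note that `req.inputs` has to be cloned before `req.into()`, since the conversion consumes the request.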
@@ -345,6 +357,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
@@ -429,8 +442,9 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
-        .layer(Extension(infer))
         .route("/metrics", get(metrics))
+        .layer(Extension(compat_return_full_text))
+        .layer(Extension(infer))
         .layer(Extension(prom_handle))
         .layer(opentelemetry_tracing_layer())
         .layer(cors_layer);
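
On the reordering in the last hunk: in axum, `Router::layer` only wraps routes registered before the call, so the `Extension` layers are added after every `.route(...)` that needs them. A standalone sketch of how a layered `Extension<bool>` reaches a handler, assuming axum 0.6 and tokio as in the router (names and port are illustrative):

```rust
use axum::{extract::Extension, routing::get, Router};

// Mirrors how compat_return_full_text reaches compat_generate.
async fn show_flag(Extension(return_full_text): Extension<bool>) -> String {
    format!("compat_return_full_text = {return_full_text}")
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/flag", get(show_flag))
        // Layers apply to the routes added above this point.
        .layer(Extension(true));

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}
```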