Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
feat(router): ask hf.co for pipeline to make informed decision on compat_return_full_text
parent 21340f24ba
commit 9579bf165a
Cargo.lock (generated):

@@ -2268,6 +2268,7 @@ dependencies = [
  "opentelemetry-otlp",
  "parking_lot",
  "rand",
+ "reqwest",
  "serde",
  "serde_json",
  "text-generation-client",
router/Cargo.toml:

@@ -26,6 +26,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 parking_lot = "0.12.1"
 rand = "0.8.5"
+reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
router/src/main.rs:

@@ -87,7 +87,7 @@ fn main() -> Result<(), std::io::Error> {
     // This will only be used to validate payloads
     //
     // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();

     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
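The added `.clone()` appears to be needed because `from_pretrained` takes its identifier by value, and `tokenizer_name` is reused below for the hub request. A minimal illustration of the move, with a hypothetical function standing in for `from_pretrained`:

fn consumes(name: String) { /* e.g. downloads the tokenizer */ let _ = name; }

fn main() {
    let tokenizer_name = String::from("gpt2"); // placeholder model id
    consumes(tokenizer_name.clone()); // clone keeps the original usable
    println!("{tokenizer_name}"); // without the clone this would not compile
}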
@@ -97,6 +97,36 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

+            // Get pipeline tag
+            let model_info = reqwest::get(format!(
+                "https://api-inference.huggingface.co/models/{tokenizer_name}"
+            ))
+            .await
+            .expect("Could not connect to hf.co")
+            .text()
+            .await
+            .expect("error when retrieving model info from hf.co");
+            let model_info: serde_json::Value =
+                serde_json::from_str(&model_info).expect("unable to parse model info");
+
+            // if pipeline-tag == text-generation we return prompt + generated_text from the / route
+            let compat_return_full_text = match model_info["pipeline_tag"].as_str() {
+                None => {
+                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+                    tracing::warn!("returning only generated_text from the compat route");
+                    false
+                }
+                Some(pipeline_tag) => {
+                    if pipeline_tag == "text-generation" {
+                        tracing::info!("returning prompt + generated_text from the compat route");
+                        true
+                    } else {
+                        tracing::info!("returning only generated_text from the compat route");
+                        false
+                    }
+                }
+            };
+
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
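Taken out of the router, the lookup boils down to the following. This is a minimal standalone sketch, assuming tokio, reqwest, and serde_json as dependencies; "gpt2" is just a placeholder model id, and the error handling mirrors the `expect`-based style of the hunk above:

use serde_json::Value;

#[tokio::main]
async fn main() {
    let tokenizer_name = "gpt2"; // placeholder model id

    // Fetch the model metadata from the hub.
    let body = reqwest::get(format!(
        "https://api-inference.huggingface.co/models/{tokenizer_name}"
    ))
    .await
    .expect("Could not connect to hf.co")
    .text()
    .await
    .expect("error when retrieving model info from hf.co");

    let model_info: Value = serde_json::from_str(&body).expect("unable to parse model info");

    // `as_str()` returns None when the field is missing or not a string,
    // which falls through to the conservative default of `false`.
    let compat_return_full_text =
        model_info["pipeline_tag"].as_str() == Some("text-generation");
    println!("compat_return_full_text = {compat_return_full_text}");
}

Note that this runs before the Tokio server starts, so the router now needs network access to hf.co at boot and panics if the lookup fails.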
@@ -113,6 +143,7 @@ fn main() -> Result<(), std::io::Error> {

             // Run server
             server::run(
+                compat_return_full_text,
                 max_concurrent_requests,
                 max_stop_sequences,
                 max_input_length,
router/src/server.rs:

@@ -29,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
 /// Compatibility route with api-inference and AzureML
 #[instrument(skip(infer))]
 async fn compat_generate(
+    return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
@@ -39,9 +40,20 @@ async fn compat_generate(
         .await
         .into_response())
     } else {
+        let mut add_prompt = None;
+        if return_full_text.0 {
+            add_prompt = Some(req.inputs.clone());
+        }
+
         let (headers, generation) = generate(infer, Json(req.into())).await?;
+
+        let mut generation = generation.0;
+        if let Some(prompt) = add_prompt {
+            generation.generated_text = prompt + &generation.generated_text;
+        };
+
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }

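The prepending logic itself is easy to exercise in isolation. A sketch under the assumption of a stand-in `GenerateResponse` struct (not the router's real response type):

struct GenerateResponse {
    generated_text: String,
}

// Mirrors the compat-route behaviour: when return_full_text is set, the
// original prompt is prepended to the model output, and the single
// generation is wrapped in a Vec to match api-inference's response shape.
fn compat_body(
    inputs: &str,
    mut generation: GenerateResponse,
    return_full_text: bool,
) -> Vec<GenerateResponse> {
    if return_full_text {
        generation.generated_text = inputs.to_owned() + &generation.generated_text;
    }
    vec![generation]
}

fn main() {
    let out = compat_body(
        "Hello",
        GenerateResponse { generated_text: ", world".to_string() },
        true,
    );
    assert_eq!(out[0].generated_text, "Hello, world");
}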
@@ -345,6 +357,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
@@ -429,8 +442,9 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
-        .layer(Extension(infer))
         .route("/metrics", get(metrics))
+        .layer(Extension(compat_return_full_text))
+        .layer(Extension(infer))
         .layer(Extension(prom_handle))
         .layer(opentelemetry_tracing_layer())
         .layer(cors_layer);
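The flag reaches the handler through axum's request extensions. A minimal sketch of the pattern, assuming axum 0.6-era APIs; the route and handler names are illustrative:

use axum::{routing::get, Extension, Router};

// The Extension extractor pulls the boolean back out of the request,
// just as `compat_generate` does with `return_full_text: Extension<bool>`.
async fn compat(Extension(return_full_text): Extension<bool>) -> String {
    format!("return_full_text = {return_full_text}")
}

#[tokio::main]
async fn main() {
    let compat_return_full_text = true; // would come from the hub lookup

    let app = Router::new()
        .route("/", get(compat))
        // `.layer` only applies to routes registered before it, which is
        // why the diff moves the Extension layers after the last `.route`.
        .layer(Extension(compat_return_full_text));

    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}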