mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
feat(router): ask hf.co for pipeline to make informed desicion on compat_return_full_text
This commit is contained in:
parent
21340f24ba
commit
9579bf165a
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2268,6 +2268,7 @@ dependencies = [
|
||||
"opentelemetry-otlp",
|
||||
"parking_lot",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text-generation-client",
|
||||
|
@ -26,6 +26,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
parking_lot = "0.12.1"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.14", features = [] }
|
||||
serde = "1.0.152"
|
||||
serde_json = "1.0.93"
|
||||
thiserror = "1.0.38"
|
||||
|
@ -87,7 +87,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||
// This will only be used to validate payloads
|
||||
//
|
||||
// We need to download it outside of the Tokio runtime
|
||||
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
|
||||
let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
|
||||
|
||||
// Launch Tokio runtime
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
@ -97,6 +97,36 @@ fn main() -> Result<(), std::io::Error> {
|
||||
.block_on(async {
|
||||
init_logging(otlp_endpoint, json_output);
|
||||
|
||||
// Get pipeline tag
|
||||
let model_info = reqwest::get(format!(
|
||||
"https://api-inference.huggingface.co/models/{tokenizer_name}"
|
||||
))
|
||||
.await
|
||||
.expect("Could not connect to hf.co")
|
||||
.text()
|
||||
.await
|
||||
.expect("error when retrieving model info from hf.co");
|
||||
let model_info: serde_json::Value =
|
||||
serde_json::from_str(&model_info).expect("unable to parse model info");
|
||||
|
||||
// if pipeline-tag == text-generation we return prompt + generated_text from the / route
|
||||
let compat_return_full_text = match model_info["pipeline_tag"].as_str() {
|
||||
None => {
|
||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||
tracing::warn!("returning only generated_text from the compat route");
|
||||
false
|
||||
}
|
||||
Some(pipeline_tag) => {
|
||||
if pipeline_tag == "text-generation" {
|
||||
tracing::info!("returning prompt + generated_text from the compat route");
|
||||
true
|
||||
} else {
|
||||
tracing::info!("returning only generated_text from the compat route");
|
||||
false
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Instantiate sharded client from the master unix socket
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
@ -113,6 +143,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
compat_return_full_text,
|
||||
max_concurrent_requests,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
|
@ -29,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
||||
/// Compatibility route with api-inference and AzureML
|
||||
#[instrument(skip(infer))]
|
||||
async fn compat_generate(
|
||||
return_full_text: Extension<bool>,
|
||||
infer: Extension<Infer>,
|
||||
req: Json<CompatGenerateRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||
@ -39,9 +40,20 @@ async fn compat_generate(
|
||||
.await
|
||||
.into_response())
|
||||
} else {
|
||||
let mut add_prompt = None;
|
||||
if return_full_text.0 {
|
||||
add_prompt = Some(req.inputs.clone());
|
||||
}
|
||||
|
||||
let (headers, generation) = generate(infer, Json(req.into())).await?;
|
||||
|
||||
let mut generation = generation.0;
|
||||
if let Some(prompt) = add_prompt {
|
||||
generation.generated_text = prompt + &generation.generated_text;
|
||||
};
|
||||
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(vec![generation.0])).into_response())
|
||||
Ok((headers, Json(vec![generation])).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
@ -345,6 +357,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
compat_return_full_text: bool,
|
||||
max_concurrent_requests: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
@ -429,8 +442,9 @@ pub async fn run(
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/", get(health))
|
||||
.route("/health", get(health))
|
||||
.layer(Extension(infer))
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
.layer(Extension(infer))
|
||||
.layer(Extension(prom_handle))
|
||||
.layer(opentelemetry_tracing_layer())
|
||||
.layer(cors_layer);
|
||||
|
Loading…
Reference in New Issue
Block a user