mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

Fmt

This commit is contained in:
  parent e19d6f3589
  commit a7a98c0253
@@ -982,14 +982,16 @@ fn spawn_shards(
     Ok(())
 }

-fn compute_type(num_shard: usize) -> Option<String>{
-    let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?;
+fn compute_type(num_shard: usize) -> Option<String> {
+    let output = Command::new("nvidia-smi")
+        .args(["--query-gpu=gpu_name", "--format=csv"])
+        .output()
+        .ok()?;
     let output = String::from_utf8(output.stdout).ok()?;
     let fullname = output.split("\n").nth(1)?;
     let cardname = fullname.replace(" ", "-").to_lowercase();
     let compute_type = format!("{num_shard}-{cardname}");
     Some(compute_type)
-
 }

 fn spawn_webserver(
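For reference, the helper reformatted above derives a compute-type tag from the first GPU name that nvidia-smi reports. A self-contained sketch of the same logic, assuming nvidia-smi is on PATH and lists at least one GPU; the example card name and the fallback string in main() are only illustrative:

use std::process::Command;

// Build a "<num_shard>-<card-name>" tag, e.g. "2-nvidia-a100-sxm4-80gb".
// Returns None if nvidia-smi is unavailable or reports no GPUs.
fn compute_type(num_shard: usize) -> Option<String> {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    let output = String::from_utf8(output.stdout).ok()?;
    // Row 0 is the CSV header ("name"); row 1 is the first GPU.
    let fullname = output.split("\n").nth(1)?;
    let cardname = fullname.replace(" ", "-").to_lowercase();
    let compute_type = format!("{num_shard}-{cardname}");
    Some(compute_type)
}

fn main() {
    // Illustrative fallback when detection fails.
    let tag = compute_type(1).unwrap_or_else(|| "unknown".to_string());
    println!("{tag}");
}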
@@ -1086,7 +1088,7 @@ fn spawn_webserver(
     // Parse Compute type
     if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
         envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
-    }else if let Some(compute_type) = compute_type(num_shard){
+    } else if let Some(compute_type) = compute_type(num_shard) {
         envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
     }

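The hunk above gives an explicit COMPUTE_TYPE environment variable precedence over the detected value when the launcher assembles the webserver's environment. A minimal sketch of that selection, with a stubbed detector standing in for the nvidia-smi helper shown earlier:

use std::env;

// Hypothetical stand-in for the nvidia-smi based helper defined above.
fn compute_type(num_shard: usize) -> Option<String> {
    Some(format!("{num_shard}-nvidia-a100-sxm4-80gb"))
}

fn main() {
    let num_shard = 2;
    let mut envs: Vec<(String, String)> = Vec::new();
    // An explicit COMPUTE_TYPE wins; otherwise fall back to detection; if
    // both fail, no COMPUTE_TYPE is passed down at all.
    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
        envs.push(("COMPUTE_TYPE".into(), compute_type));
    } else if let Some(compute_type) = compute_type(num_shard) {
        envs.push(("COMPUTE_TYPE".into(), compute_type));
    }
    println!("{envs:?}");
}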
@@ -1283,8 +1285,8 @@ fn main() -> Result<(), LauncherError> {
         return Ok(());
     }

-    let mut webserver =
-        spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
+    let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
+        .map_err(|err| {
             shutdown_shards(shutdown.clone(), &shutdown_receiver);
             err
         })?;
@@ -67,7 +67,7 @@ async fn compat_generate(

     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer,compute_type, Json(req.into()))
+        Ok(generate_stream(infer, compute_type, Json(req.into()))
             .await
             .into_response())
     } else {
@@ -372,7 +372,7 @@ async fn generate_stream_internal(
     let compute_characters = req.inputs.chars().count();

     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type",compute_type.parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-characters",
         compute_characters.to_string().parse().unwrap(),
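The two insertions above follow the usual HeaderMap pattern (the same http type axum re-exports as axum::http::HeaderMap): parse() builds a HeaderValue and can only fail on bytes that are not legal in an HTTP header, hence the unwrap(). A small sketch with assumed example values; in the router the tag comes from the ComputeType set up in run() and the character count from req.inputs:

use http::HeaderMap;

fn main() {
    // Example values for illustration only.
    let compute_type = "2-nvidia-a100-sxm4-80gb".to_string();
    let compute_characters = 42usize;

    let mut headers = HeaderMap::new();
    // parse() yields a HeaderValue; unwrap() is safe for plain ASCII tags.
    headers.insert("x-compute-type", compute_type.parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    assert!(headers.get("x-compute-type").is_some());
}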
@@ -649,13 +649,22 @@ async fn chat_completions(
             )
         };

-        let (headers, response_stream) =
-            generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
-            generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;

         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
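In the streaming branch above, the headers returned by generate_stream_internal travel alongside an SSE body built with axum's Sse and KeepAlive types. A stripped-down sketch of that response shape; the handler name and its fixed event payload are made up for illustration:

use axum::response::sse::{Event, KeepAlive, Sse};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;

// Hypothetical handler: wrap a ready-made event stream in an SSE response
// with the default keep-alive, mirroring the streaming branch above.
async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let events: Vec<Result<Event, Infallible>> = vec![
        Ok(Event::default().data("first token")),
        Ok(Event::default().data("second token")),
    ];
    Sse::new(stream::iter(events)).keep_alive(KeepAlive::default())
}

Mounted on a Router (for example with axum::routing::get(sse_handler)), this would serve a complete, if static, streaming endpoint.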
@@ -943,7 +952,8 @@ pub async fn run(
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
     };

-    let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
+    let compute_type =
+        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

     // Combine routes and layers
     let app = Router::new()
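On the router side, the same variable is read once at startup with gpu+optimized as the fallback tag, then handed to the handlers (it appears above as Extension(compute_type) in chat_completions). A sketch assuming ComputeType is a plain newtype over String; its real definition is not part of this diff:

use std::env;

// Assumed shape of the wrapper used above; the actual definition lives
// elsewhere in the router and is not shown in this diff.
#[derive(Clone, Debug)]
struct ComputeType(String);

fn main() {
    let compute_type =
        ComputeType(env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
    println!("{compute_type:?}");
}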