Mirror of https://github.com/huggingface/text-generation-inference.git
Fmt
This commit is contained in:
parent e19d6f3589
commit a7a98c0253
@@ -982,14 +982,16 @@ fn spawn_shards(
     Ok(())
 }
 
-fn compute_type(num_shard: usize) -> Option<String>{
-    let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?;
+fn compute_type(num_shard: usize) -> Option<String> {
+    let output = Command::new("nvidia-smi")
+        .args(["--query-gpu=gpu_name", "--format=csv"])
+        .output()
+        .ok()?;
     let output = String::from_utf8(output.stdout).ok()?;
     let fullname = output.split("\n").nth(1)?;
     let cardname = fullname.replace(" ", "-").to_lowercase();
     let compute_type = format!("{num_shard}-{cardname}");
     Some(compute_type)
-
 }
 
 fn spawn_webserver(
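
For reference, a minimal standalone sketch of the string that compute_type() ends up producing, assuming nvidia-smi --query-gpu=gpu_name --format=csv prints a CSV header row followed by the GPU name; the card name and shard count below are placeholders, not values taken from this diff.

// Illustrative only: mirrors the parsing done by compute_type().
fn main() {
    let sample = "name\nNVIDIA A10G\n"; // stand-in for captured nvidia-smi output
    let num_shard = 2usize;
    let fullname = sample.split('\n').nth(1).unwrap(); // skip the CSV header row
    let cardname = fullname.replace(' ', "-").to_lowercase();
    println!("{num_shard}-{cardname}"); // prints "2-nvidia-a10g"
}
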
@@ -1086,7 +1088,7 @@ fn spawn_webserver(
     // Parse Compute type
     if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
         envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
-    }else if let Some(compute_type) = compute_type(num_shard){
+    } else if let Some(compute_type) = compute_type(num_shard) {
         envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
     }
 
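
The hunk above keeps the existing precedence: an explicit COMPUTE_TYPE environment variable wins, and the nvidia-smi based detection is only a fallback. A minimal sketch of that precedence, with a placeholder detection result standing in for compute_type(num_shard):

use std::env;

// Placeholder for compute_type(num_shard) from the launcher.
fn detected_compute_type() -> Option<String> {
    Some("2-nvidia-a10g".to_string())
}

fn main() {
    // Environment variable first, hardware detection second.
    let compute_type = env::var("COMPUTE_TYPE").ok().or_else(detected_compute_type);
    println!("{compute_type:?}");
}
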
@@ -1283,8 +1285,8 @@ fn main() -> Result<(), LauncherError> {
         return Ok(());
     }
 
-    let mut webserver =
-        spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
+    let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
+        .map_err(|err| {
             shutdown_shards(shutdown.clone(), &shutdown_receiver);
             err
         })?;
@@ -67,7 +67,7 @@ async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer,compute_type, Json(req.into()))
+        Ok(generate_stream(infer, compute_type, Json(req.into()))
             .await
             .into_response())
     } else {
@@ -372,7 +372,7 @@ async fn generate_stream_internal(
     let compute_characters = req.inputs.chars().count();
 
     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type",compute_type.parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-characters",
         compute_characters.to_string().parse().unwrap(),
@@ -649,13 +649,22 @@ async fn chat_completions(
             )
         };
 
-        let (headers, response_stream) =
-            generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
-            generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
 
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -943,7 +952,8 @@ pub async fn run(
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
     };
 
-    let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
+    let compute_type =
+        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
 
     // Combine routes and layers
     let app = Router::new()