Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-05-22 02:02:07 +00:00
Integrate flash attention for starcoder2 in TGI through Habana, with some enabling fixes (#198)
parent 0ca54b55f8
commit 9b71343328
@@ -8,16 +8,19 @@ pub(crate) struct Env {
     docker_label: &'static str,
     nvidia_env: String,
     xpu_env: String,
+    hpu_env: String,
 }
 
 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
         let xpu_env = xpu_smi();
+        let hpu_env = hl_smi();
 
         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
             xpu_env: xpu_env.unwrap_or("N/A".to_string()),
+            hpu_env: hpu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -35,7 +38,8 @@ impl fmt::Display for Env {
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
         writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
-        write!(f, "xpu-smi:\n{}", self.xpu_env)?;
+        writeln!(f, "xpu-smi:\n{}", self.xpu_env)?;
+        write!(f, "hpu-smi:\n{}", self.hpu_env)?;
 
         Ok(())
     }
@@ -54,3 +58,10 @@ fn xpu_smi() -> Option<String> {
     let output = xpu_smi.replace('\n', "\n ");
     Some(output.trim().to_string())
 }
+
+fn hl_smi() -> Option<String> {
+    let output = Command::new("hl-smi").output().ok()?;
+    let hl_smi = String::from_utf8(output.stdout).ok()?;
+    let output = hl_smi.replace('\n', "\n ");
+    Some(output.trim().to_string())
+}
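Taken together, the three hunks above touch the launcher's environment report: the Env struct gains an hpu_env field, new() populates it through a new hl_smi() probe, and the Display impl switches the xpu-smi line from write! to writeln! so the appended hpu-smi section starts on its own line. For anyone checking a Gaudi host by hand, here is a minimal Python sketch of the same probe; the function name simply mirrors the Rust helper, and the sketch is an illustration, not code from the repo:

import subprocess

def hl_smi() -> str | None:
    """Run `hl-smi` and return its output, or None if it is unavailable.

    Mirrors the launcher's new Rust helper: the caller is expected to
    render a None as the literal string "N/A" in the report.
    """
    try:
        result = subprocess.run(["hl-smi"], capture_output=True)
    except OSError:  # binary missing from PATH, not executable, etc.
        return None
    try:
        text = result.stdout.decode("utf-8")
    except UnicodeDecodeError:
        return None
    # Indent continuation lines for display, like replace('\n', "\n ") above.
    return text.replace("\n", "\n ").strip()

With this in place, the launcher's --env report should print an hpu-smi section alongside the existing nvidia-smi and xpu-smi sections.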
@@ -684,9 +684,11 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-        if model.config.model_type in ["llama", "mistral"]:
-            kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
+        if model.config.model_type in ["llama", "mistral", "starcoder2"]:
+            kwargs["trim_logits"] = True
+        if model.config.model_type in ["llama", "mistral"]:
+            kwargs["attn_softmax_bf16"] = True
+            kwargs["trim_logits"] = True
 
         if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
             kwargs["use_flash_attention"] = True
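The net effect of this hunk: on Gaudi, starcoder2 now receives trim_logits like llama and mistral, attn_softmax_bf16 stays restricted to llama and mistral, and flash attention remains opt-in through the USE_FLASH_ATTENTION environment variable for all of them. A condensed sketch of the resulting gating (a standalone illustration, not the server's actual API; the kwargs dict is abbreviated):

import os

def habana_generation_kwargs(model_type: str) -> dict:
    # Abbreviated: the real dict carries more entries than return_dict.
    kwargs = {"return_dict": True}
    if model_type in ["llama", "mistral", "starcoder2"]:
        kwargs["trim_logits"] = True
    if model_type in ["llama", "mistral"]:
        kwargs["attn_softmax_bf16"] = True
        kwargs["trim_logits"] = True  # redundant second write, mirrors the diff
    if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
        kwargs["use_flash_attention"] = True
    return kwargs

# With USE_FLASH_ATTENTION=true, a starcoder2 model ends up with
# {"return_dict": True, "trim_logits": True, "use_flash_attention": True}.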
@@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Type
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
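Finally, the import hunk swaps a package-level import for an import from the defining module. A plausible reading (an assumption, not stated in the commit) is that the models package's __init__ guards CUDA-dependent re-exports, so the name may be absent at package level on a Gaudi build, while the defining module is always importable:

# Before: depends on text_generation_server/models/__init__.py re-exporting
# the class (assumed to be conditional on this build; see note above).
# from text_generation_server.models import FlashCausalLM

# After: import straight from the module that defines FlashCausalLM.
from text_generation_server.models.flash_causal_lm import FlashCausalLM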