Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

Add support for rope_scaling and remove is_optimized_for_gaudi (#112)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent bf5263b88b · commit 7342baa2eb
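In short: model loading gains an optional rope_scaling dict driven by the ROPE_SCALING and ROPE_FACTOR environment variables, the inline DeepSpeed setup moves into a new get_deepspeed_model helper, and the is_optimized_for_gaudi flag disappears everywhere. The static-shape Gaudi path becomes the only path, and model types outside MODELS_OPTIMIZED_WITH_STATIC_SHAPES now raise a ValueError instead of falling back.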
@@ -17,14 +17,12 @@ class BloomCausalLMBatch(CausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
         batch = super().from_pb(
             pb=pb,
             tokenizer=tokenizer,
             dtype=dtype,
             device=device,
-            is_optimized_for_gaudi=is_optimized_for_gaudi,
         )
         batch.keys_head_dim_last = False
         return batch
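Both deletions fall out of the same signature change: CausalLMBatch.from_pb no longer accepts the flag, so the Bloom subclass stops declaring and forwarding it. The remaining hunks below are in the CausalLM module itself.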
@@ -1,28 +1,33 @@
+import bisect
+from dataclasses import dataclass
+from functools import wraps
+import itertools
+import math
 import os
 import tempfile
-import itertools
-import bisect
-import math
+from typing import Dict, List, Optional, Tuple, Type

 import torch

-from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
-from typing import Optional, Tuple, List, Type, Dict

 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from contextlib import nullcontext
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from optimum.habana.utils import HabanaProfile

 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
     get_repo_root,
     model_on_meta,
     write_checkpoints_json,
 )
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    PreTrainedTokenizerBase,
+    AutoConfig,
+)

 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
@@ -34,11 +39,13 @@ from text_generation_server.models.types import (
     TopTokens,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
+from text_generation_server.utils import (
+    HeterogeneousNextTokenChooser,
+    StoppingCriteria,
+    make_tokenizer_optional,
+    is_tokenizer_transparent,
+)
 from text_generation_server.utils.debug import dbg_trace
-from loguru import logger
-from functools import wraps


 tracer = trace.get_tracer(__name__)
@@ -384,7 +391,7 @@ class CausalLMBatch(Batch):
         parameters = [r.data.parameters for r in flat_requests]
         if len(flat_requests) < new_bs:
             for i in range(new_bs-len(flat_requests)) :
-                #append the dummy parameters for dummy request
+                # append the dummy parameters for dummy request
                 parameters.append(parameters[0])

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -423,7 +430,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -474,19 +480,16 @@ class CausalLMBatch(Batch):
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]

-        if is_optimized_for_gaudi:
-            # Allocate space for first token
-            input_ids = torch.nn.functional.pad(
-                input_ids, (left_padding, 1), value=tokenizer.pad_token_id
-            )
-            attention_mask = torch.nn.functional.pad(
-                attention_mask, (left_padding, 1), value=0
-            )
-            all_input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
-            ).T.split(1, dim=1)
-        else:
-            all_input_ids = input_ids.clone().T.split(1, dim=1)
+        # Allocate space for first token
+        input_ids = torch.nn.functional.pad(
+            input_ids, (left_padding, 1), value=tokenizer.pad_token_id
+        )
+        attention_mask = torch.nn.functional.pad(
+            attention_mask, (left_padding, 1), value=0
+        )
+        all_input_ids = torch.nn.functional.pad(
+            input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
+        ).T.split(1, dim=1)

         # New input length after left padding
         input_len = bucket_size
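What the now-unconditional block does: every prompt is padded on the left up to its bucket size, one slot is reserved on the right for the first generated token, and all_input_ids pre-allocates room for the whole generation. A minimal sketch of the same torch.nn.functional.pad semantics, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

pad_token_id = 0     # illustrative pad id
left_padding = 3     # fill the prompt up to the bucket size
max_new_tokens = 4   # generation slots reserved up front

# one request of 5 tokens, shape (batch=1, seq=5)
input_ids = torch.tensor([[11, 12, 13, 14, 15]])

# pad=(left, right) applies to the last dim: bucket padding on the left,
# one free slot on the right for the first generated token
input_ids = F.pad(input_ids, (left_padding, 1), value=pad_token_id)
print(input_ids)  # tensor([[ 0,  0,  0, 11, 12, 13, 14, 15,  0]])

# reserve room for the whole generation, then split into one
# (total_len, 1) column per request, as from_pb does
all_input_ids = F.pad(input_ids, (0, max_new_tokens), value=pad_token_id).T.split(1, dim=1)
print(len(all_input_ids), all_input_ids[0].shape)  # 1 torch.Size([13, 1])
```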
@@ -562,16 +565,9 @@ class CausalLM(Model):
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
     ):
-        device = torch.device("hpu")
-        if hq_env.is_quantization_enabled:
-            htorch.core.hpu_set_env()
-
-        dtype = torch.bfloat16 if dtype is None else dtype
-
-        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-
         adapt_transformers_to_gaudi()

         # Create tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -580,79 +576,42 @@ class CausalLM(Model):
         )
         make_tokenizer_optional(tokenizer)

-        model_kwargs = {
-            "revision": revision,
-        }
-
         # Create model
         world_size = int(os.getenv("WORLD_SIZE", "1"))
         rank = int(os.getenv("RANK", "0"))
-        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
-        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        dtype = torch.bfloat16 if dtype is None else dtype
+        device = torch.device("hpu")
+
+        if hq_env.is_quantization_enabled:
+            htorch.core.hpu_set_env()
+
         if world_size > 1:
-            import habana_frameworks.torch.hpu as torch_hpu
-
-            # Get world size, rank and local rank
-            from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
-
-            world_size, rank, local_rank = initialize_distributed_hpu()
-            import deepspeed
-
-            # Initialize process(es) for DeepSpeed
-            deepspeed.init_distributed(dist_backend="hccl")
-            logger.info(
-                "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
+            model = self.get_deepspeed_model(
+                model_id, dtype, revision
             )
-            config = AutoConfig.from_pretrained(model_id, **model_kwargs)
-            load_to_meta = model_on_meta(config)
-
-            if load_to_meta:
-                # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
-                with deepspeed.OnDevice(dtype=dtype, device="meta"):
-                    model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
-            else:
-                get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
-                # TODO: revisit placement on CPU when auto-injection is possible
-                with deepspeed.OnDevice(dtype=dtype, device="cpu"):
-                    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
-            model = model.eval()
-
-            # Initialize the model
-            ds_inference_kwargs = {"dtype": dtype}
-            ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
-            ds_inference_kwargs["enable_cuda_graph"] = False
-
-            if load_to_meta:
-                # model loaded to meta is managed differently
-                checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-                write_checkpoints_json(model_id, local_rank, checkpoints_json)
-                ds_inference_kwargs["checkpoint"] = checkpoints_json.name
-            model = deepspeed.init_inference(model, **ds_inference_kwargs)
-            model = model.module
             model = self.prepare_model_for_quantization(model)
-            model = remove_kv_cache_from_output(model)
-            if self.enable_hpu_graph:
-                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
-
         else:
             get_repo_root(model_id)
+            rope_scaling = self.get_rope_scaling()
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
+                rope_scaling=rope_scaling
             )
             model = self.prepare_model_for_quantization(model)
             model = model.eval().to(device)
-            # wrap in hpu_graph only if self.enable_hpu_graph is set
-            model = remove_kv_cache_from_output(model)
-            if self.enable_hpu_graph:
-                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)

+        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
+        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        model = remove_kv_cache_from_output(model)
+        if self.enable_hpu_graph:
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+
         model = self.setup_quantization(model)

-        if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
-            self.is_optimized_for_gaudi = True
-        else:
-            self.is_optimized_for_gaudi = False
+        if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
+            raise ValueError(f"Model type {model.config.model_type} is not supported!")

         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
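Three things happen in this hunk: the dtype/device defaults and the quantization environment setup move here from the lines removed above; the inline DeepSpeed block collapses into a call to the new get_deepspeed_model helper (defined below); and KV-cache stripping plus HPU-graph wrapping now run once after the branch instead of once per branch. The is_optimized_for_gaudi attribute is gone entirely: support is all-or-nothing, enforced by the ValueError.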
@@ -682,24 +641,85 @@ class CausalLM(Model):
             rank=rank,
             kwargs=kwargs,
         )
-        prof_ranks = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
-        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in prof_ranks else 0
-        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in prof_ranks else 0
-        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))

+        # Create profiler
+        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
         record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
         output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
+        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
         if self.profiling_steps > 0:
             self.hb_profiler = HabanaProfile(
                 wait=self.profiling_wait_steps,
                 warmup=self.profiling_warmup_steps,
                 active=self.profiling_steps,
-                output_dir=output_dir, record_shapes=record_shapes
+                output_dir=output_dir,
+                record_shapes=record_shapes
             )
             self.hb_profiler.start()
         else:
             self.hb_profiler = None
         self.step = 0
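Profiling stays opt-in through environment variables; only ranks listed in PROF_RANKS get a non-zero schedule, and PROF_STEP > 0 is the switch that builds and starts HabanaProfile. A runnable sketch of the same parsing, with illustrative values:

```python
import os

# illustrative environment, as the launcher might set it
os.environ.update({
    "PROF_RANKS": "0,1",      # ranks to profile
    "PROF_WAITSTEP": "2",     # steps to skip before warmup
    "PROF_WARMUPSTEP": "3",   # warmup steps
    "PROF_STEP": "5",         # active (recorded) steps; 0 disables profiling
    "PROF_PATH": "/tmp/hpu_profile",
    "PROF_RECORD_SHAPES": "false",
})

rank = 0
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))

# with profiling_steps > 0 the server builds HabanaProfile(wait=..., warmup=...,
# active=..., output_dir=..., record_shapes=...) and starts it
print(profiling_wait_steps, profiling_warmup_steps, profiling_steps)  # 2 3 5
```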
+    def get_deepspeed_model(
+        self,
+        model_id: str,
+        dtype: torch.dtype,
+        revision: Optional[str] = None
+    ) -> torch.nn.Module:
+        import deepspeed
+        from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
+
+        world_size, rank, local_rank = initialize_distributed_hpu()
+        model_kwargs = {
+            "revision": revision,
+            'rope_scaling': self.get_rope_scaling()
+        }
+
+        # Initialize process(es) for DeepSpeed
+        deepspeed.init_distributed(dist_backend="hccl")
+        logger.info(
+            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
+        )
+        config = AutoConfig.from_pretrained(model_id, **model_kwargs)
+        load_to_meta = model_on_meta(config)
+
+        if load_to_meta:
+            # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
+            with deepspeed.OnDevice(dtype=dtype, device="meta"):
+                model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
+        else:
+            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
+            # TODO: revisit placement on CPU when auto-injection is possible
+            with deepspeed.OnDevice(dtype=dtype, device="cpu"):
+                model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
+        model = model.eval()
+
+        # Initialize the model
+        ds_inference_kwargs = {"dtype": dtype}
+        ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
+        ds_inference_kwargs["enable_cuda_graph"] = False
+
+        if load_to_meta:
+            # model loaded to meta is managed differently
+            checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
+            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            ds_inference_kwargs["checkpoint"] = checkpoints_json.name
+        model = deepspeed.init_inference(model, **ds_inference_kwargs)
+
+        return model.module
+
+    def get_rope_scaling(self) -> Optional[Dict]:
+        rope_scaling = os.getenv("ROPE_SCALING", None)
+        if rope_scaling is None:
+            return None
+
+        rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
+        return {
+            'type': rope_scaling, 'factor': float(rope_factor)
+        }
+
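This is the heart of the rope_scaling support: two environment variables become the dict that transformers' from_pretrained accepts (e.g. type "linear" or "dynamic" plus a scale factor). A self-contained sketch of the same logic, with illustrative values:

```python
import os
from typing import Dict, Optional

def get_rope_scaling() -> Optional[Dict]:
    # mirrors CausalLM.get_rope_scaling above
    rope_scaling = os.getenv("ROPE_SCALING", None)
    if rope_scaling is None:
        return None
    rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
    return {'type': rope_scaling, 'factor': float(rope_factor)}

os.environ["ROPE_SCALING"] = "dynamic"  # e.g. "linear" or "dynamic"
os.environ["ROPE_FACTOR"] = "2.0"
print(get_rope_scaling())  # {'type': 'dynamic', 'factor': 2.0}
# the dict is then passed straight through:
# AutoModelForCausalLM.from_pretrained(model_id, ..., rope_scaling=get_rope_scaling())
```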
     def setup_quantization(self, model):
         if hq_env.is_quantization_enabled:
             htorch.core.quantization._mark_params_as_const(model)
@@ -754,7 +774,7 @@ class CausalLM(Model):
         input_ids,
         attention_mask,
         position_ids,
-        token_idx: Optional = None,
+        token_idx,
         past_key_values: Optional = None,
         bypass_hpu_graph: Optional = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -763,11 +783,9 @@ class CausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past_key_values,
+            "token_idx": token_idx
         }

-        if self.is_optimized_for_gaudi:
-            kwargs["token_idx"] = token_idx
-
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids

@@ -794,20 +812,17 @@ class CausalLM(Model):
         logits = batch.logits
         past = batch.past
         prefill = batch.past_key_values is None
-        if self.is_optimized_for_gaudi:
-            if prefill:
-                # no right padding for prefill
-                token_idx_scalar = batch.attention_mask.shape[-1] - 1
-                token_idx = torch.tensor(token_idx_scalar).to(self.device)
-            else:
-                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
-                token_idx = torch.tensor(token_idx_scalar).to(self.device)
+        if prefill:
+            # no right padding for prefill
+            token_idx_scalar = batch.attention_mask.shape[-1] - 1
+            token_idx = torch.tensor(token_idx_scalar).to(self.device)
         else:
-            token_idx = None
+            token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+            token_idx = torch.tensor(token_idx_scalar).to(self.device)

         # Select next token
         input_length = batch.input_length
-        if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
+        if logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                 batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
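With the non-Gaudi branch gone, token_idx is always a tensor pointing at the first free slot of the right-padded buffer: the last mask position after prefill, and the position just before the remaining right padding during decode. A small sketch with made-up shapes:

```python
import torch

attention_mask = torch.ones(1, 10)  # illustrative mask of width 10

# prefill: prompts are only left-padded, so the next token lands at the end
prefill_token_idx = torch.tensor(attention_mask.shape[-1] - 1)

# decode: right_padding slots at the end of the buffer are still unused
right_padding = 4
decode_token_idx = torch.tensor(attention_mask.shape[-1] - right_padding)

print(prefill_token_idx.item(), decode_token_idx.item())  # 9 6
```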
@@ -840,20 +855,11 @@ class CausalLM(Model):

         htorch.core.mark_step()

-        if token_idx is None:
-            batch.input_ids[:, 0] = next_token_ids[:, 0]
-        else:
-            batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
-
-        # Slice unused values from prefill, use it to store next token
-        if token_idx is None:
-            batch.input_ids = batch.input_ids[:, :1]
+        # Add new token into input_ids
+        batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))

         # Update attention_mask as we added a new token to input_ids
-        if self.is_optimized_for_gaudi:
-            batch.attention_mask.index_fill_(1, token_idx, 1)
-        else:
-            batch.attention_mask[:, -batch.padding_right_offset] = 1
+        batch.attention_mask.index_fill_(1, token_idx, 1)

         # Adjust lengths
         batch.input_length += 1
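Both writes are now unconditional in-place ops, which also suits the static-shape HPU path: index_copy_ drops the new token into column token_idx, and index_fill_ opens that column in the mask. A minimal demonstration:

```python
import torch

batch_size, width = 2, 6
input_ids = torch.zeros(batch_size, width, dtype=torch.long)
attention_mask = torch.zeros(batch_size, width, dtype=torch.long)
attention_mask[:, :3] = 1                # 3 real tokens, 3 slots of right padding

token_idx = torch.tensor([3])            # first free slot
next_token_ids = torch.tensor([42, 43])  # one fresh token per request

# write the new tokens into column 3 of every row, in place
input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
# and mark that column as attended
attention_mask.index_fill_(1, token_idx, 1)

print(input_ids[:, 3].tolist())          # [42, 43]
print(attention_mask[0].tolist())        # [1, 1, 1, 1, 0, 0]
```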
@@ -886,37 +892,24 @@ class CausalLM(Model):
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'

-        if self.is_optimized_for_gaudi:
-            if prefill:
-                # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
-            else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
-            attention_mask = batch.attention_mask
-        else:
-            token_idx = None
-            # slice the attention mask to the correct shape
-            # TODO fix me!
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-
-        if not prefill and token_idx is not None:
-            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        else:
-            input_ids = batch.input_ids
-
         # Execute batch
         if prefill:
+            # no right padding for prefill
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
             batch.logits, batch.past = self.forward(
-                input_ids,
-                attention_mask,
+                batch.input_ids,
+                batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
         else:
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
             batch.logits = self.forward(
                 input_ids,
-                attention_mask,
+                batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
@@ -955,10 +948,7 @@ class CausalLM(Model):
             top_token_logprobs = req_data['top_token_logprobs']

             # Append next token to all tokens
-            if self.is_optimized_for_gaudi:
-                all_input_ids[input_length] = next_token_id
-            else:
-                all_input_ids = torch.cat([all_input_ids, next_token_id])
+            all_input_ids[input_length] = next_token_id
             new_input_length = input_length + 1

             # Generated token
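The torch.cat fallback disappears because from_pb now always pre-allocates max_new_tokens slots at the end of all_input_ids (see the padding hunk above), so appending a generated token is a plain indexed write at input_length. The last two hunks, below, are in the gRPC service.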
@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         def batch_from_pb(batch):
             return self.model.batch_type.from_pb(
-                batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+                batch, self.model.tokenizer, self.model.dtype, self.model.device
             )

         batches = [batch_from_pb(batch) for batch in request.batches]
@@ -69,7 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )
         generations, next_batch = self.model.generate_token([batch])
         self.cache.set(next_batch)
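Both gRPC entry points shrink to match: Warmup and Prefill now build batches with the trimmed from_pb signature (batch, tokenizer, dtype, device), and the server no longer reads model.is_optimized_for_gaudi.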