Add support for rope_scaling and remove is_optimized_for_gaudi (#112)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>

parent bf5263b88b
commit 7342baa2eb
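
This change reads the RoPE scaling configuration from the environment: the new CausalLM.get_rope_scaling() helper turns ROPE_SCALING and ROPE_FACTOR into a {'type': ..., 'factor': ...} dict (e.g. ROPE_SCALING=dynamic with ROPE_FACTOR=2.0 becomes {'type': 'dynamic', 'factor': 2.0}), which is passed to AutoModelForCausalLM.from_pretrained on the single-card path and through model_kwargs on the DeepSpeed path. It also removes the is_optimized_for_gaudi flag: the static-shape path is now the only code path, unsupported model types raise a ValueError, and from_pb, forward and the server call sites drop the extra argument.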
@@ -17,14 +17,12 @@ class BloomCausalLMBatch(CausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
         batch = super().from_pb(
             pb=pb,
             tokenizer=tokenizer,
             dtype=dtype,
             device=device,
-            is_optimized_for_gaudi=is_optimized_for_gaudi,
         )
         batch.keys_head_dim_last = False
         return batch
@@ -1,28 +1,33 @@
+import bisect
+from dataclasses import dataclass
+from functools import wraps
+import itertools
+import math
 import os
 import tempfile
-import itertools
-import bisect
-import math
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch
-
-from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
-from typing import Optional, Tuple, List, Type, Dict
 
 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from contextlib import nullcontext
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from optimum.habana.utils import HabanaProfile
-
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
     get_repo_root,
     model_on_meta,
     write_checkpoints_json,
 )
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    PreTrainedTokenizerBase,
+    AutoConfig,
+)
 
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
@@ -34,11 +39,13 @@ from text_generation_server.models.types import (
     TopTokens,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
+from text_generation_server.utils import (
+    HeterogeneousNextTokenChooser,
+    StoppingCriteria,
+    make_tokenizer_optional,
+    is_tokenizer_transparent,
+)
 from text_generation_server.utils.debug import dbg_trace
-from loguru import logger
-from functools import wraps
-
 
 tracer = trace.get_tracer(__name__)
 
@@ -384,7 +391,7 @@ class CausalLMBatch(Batch):
         parameters = [r.data.parameters for r in flat_requests]
         if len(flat_requests) < new_bs:
             for i in range(new_bs-len(flat_requests)) :
-                #append the dummy parameters for dummy request
+                # append the dummy parameters for dummy request
                 parameters.append(parameters[0])
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -423,7 +430,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -474,7 +480,6 @@ class CausalLMBatch(Batch):
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
 
-        if is_optimized_for_gaudi:
         # Allocate space for first token
         input_ids = torch.nn.functional.pad(
             input_ids, (left_padding, 1), value=tokenizer.pad_token_id
@@ -485,8 +490,6 @@ class CausalLMBatch(Batch):
         all_input_ids = torch.nn.functional.pad(
             input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
         ).T.split(1, dim=1)
-        else:
-            all_input_ids = input_ids.clone().T.split(1, dim=1)
 
         # New input length after left padding
         input_len = bucket_size
@@ -562,16 +565,9 @@ class CausalLM(Model):
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
     ):
-        device = torch.device("hpu")
-        if hq_env.is_quantization_enabled:
-            htorch.core.hpu_set_env()
-
-        dtype = torch.bfloat16 if dtype is None else dtype
-
-        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-
         adapt_transformers_to_gaudi()
 
+        # Create tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -580,23 +576,106 @@ class CausalLM(Model):
         )
         make_tokenizer_optional(tokenizer)
 
-        model_kwargs = {
-            "revision": revision,
-        }
-
+        # Create model
         world_size = int(os.getenv("WORLD_SIZE", "1"))
         rank = int(os.getenv("RANK", "0"))
-        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
-        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        dtype = torch.bfloat16 if dtype is None else dtype
+        device = torch.device("hpu")
+
+        if hq_env.is_quantization_enabled:
+            htorch.core.hpu_set_env()
 
         if world_size > 1:
-            import habana_frameworks.torch.hpu as torch_hpu
-
-        # Get world size, rank and local rank
+            model = self.get_deepspeed_model(
+                model_id, dtype, revision
+            )
+            model = self.prepare_model_for_quantization(model)
+        else:
+            get_repo_root(model_id)
+            rope_scaling = self.get_rope_scaling()
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                revision=revision,
+                torch_dtype=dtype,
+                rope_scaling=rope_scaling
+            )
+            model = self.prepare_model_for_quantization(model)
+            model = model.eval().to(device)
+
+        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
+        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        model = remove_kv_cache_from_output(model)
+        if self.enable_hpu_graph:
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+
+        model = self.setup_quantization(model)
+
+        if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
+            raise ValueError(f"Model type {model.config.model_type} is not supported!")
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        kwargs = {
+            "use_cache": True,
+            "return_dict": True,
+        }
+
+        if model.config.model_type == "llama":
+            kwargs["attn_softmax_bf16"] = True
+            kwargs["trim_logits"] = True
+
+        super(CausalLM, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            kwargs=kwargs,
+        )
+
+        # Create profiler
+        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
+        record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
+        output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
+        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
+        if self.profiling_steps > 0:
+            self.hb_profiler = HabanaProfile(
+                wait=self.profiling_wait_steps,
+                warmup=self.profiling_warmup_steps,
+                active=self.profiling_steps,
+                output_dir=output_dir,
+                record_shapes=record_shapes
+            )
+            self.hb_profiler.start()
+        else:
+            self.hb_profiler = None
+        self.step = 0
+
+    def get_deepspeed_model(
+        self,
+        model_id: str,
+        dtype: torch.dtype,
+        revision: Optional[str] = None
+    ) -> torch.nn.Module:
+        import deepspeed
         from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
 
         world_size, rank, local_rank = initialize_distributed_hpu()
-        import deepspeed
+        model_kwargs = {
+            "revision": revision,
+            'rope_scaling': self.get_rope_scaling()
+        }
 
         # Initialize process(es) for DeepSpeed
         deepspeed.init_distributed(dist_backend="hccl")
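
Note: the reworked constructor is driven entirely by environment variables. The sketch below only restates the knobs read in the hunk above together with their defaults; the concrete values set here are illustrative, not required.

    import os

    # HPU graph handling (defaults as in the constructor above)
    os.environ.setdefault("ENABLE_HPU_GRAPH", "true")   # wrap the model with wrap_in_hpu_graph
    os.environ.setdefault("LIMIT_HPU_GRAPH", "false")   # when "true", prefill bypasses the HPU graph

    # Profiler configuration (HabanaProfile is only created when PROF_STEP > 0)
    os.environ.setdefault("PROF_RANKS", "0")            # comma-separated ranks to profile
    os.environ.setdefault("PROF_WAITSTEP", "0")
    os.environ.setdefault("PROF_WARMUPSTEP", "0")
    os.environ.setdefault("PROF_STEP", "0")
    os.environ.setdefault("PROF_PATH", "/tmp/hpu_profile")
    os.environ.setdefault("PROF_RECORD_SHAPES", "false")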
@@ -628,78 +707,19 @@ class CausalLM(Model):
             write_checkpoints_json(model_id, local_rank, checkpoints_json)
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
-            model = model.module
-            model = self.prepare_model_for_quantization(model)
-            model = remove_kv_cache_from_output(model)
-            if self.enable_hpu_graph:
-                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
 
-        else:
-            get_repo_root(model_id)
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-            )
-            model = self.prepare_model_for_quantization(model)
-            model = model.eval().to(device)
-            # wrap in hpu_graph only if self.enable_hpu_graph is set
-            model = remove_kv_cache_from_output(model)
-            if self.enable_hpu_graph:
-                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
-        model = self.setup_quantization(model)
+        return model.module
 
-        if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
-            self.is_optimized_for_gaudi = True
-        else:
-            self.is_optimized_for_gaudi = False
+    def get_rope_scaling(self) -> Optional[Dict]:
+        rope_scaling = os.getenv("ROPE_SCALING", None)
+        if rope_scaling is None:
+            return None
 
-        if tokenizer.pad_token_id is None:
-            if model.config.pad_token_id is not None:
-                tokenizer.pad_token_id = model.config.pad_token_id
-            elif model.config.eos_token_id is not None:
-                tokenizer.pad_token_id = model.config.eos_token_id
-            elif tokenizer.eos_token_id is not None:
-                tokenizer.pad_token_id = tokenizer.eos_token_id
-            else:
-                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-
-        kwargs = {
-            "use_cache": True,
-            "return_dict": True,
+        rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
+        return {
+            'type': rope_scaling, 'factor': float(rope_factor)
         }
-
-        if model.config.model_type == "llama":
-            kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
-
-        super(CausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            kwargs=kwargs,
-        )
-        prof_ranks = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
-        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in prof_ranks else 0
-        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in prof_ranks else 0
-        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
-        record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
-        output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
-        if self.profiling_steps > 0:
-            self.hb_profiler = HabanaProfile(
-                wait=self.profiling_wait_steps,
-                warmup=self.profiling_warmup_steps,
-                active=self.profiling_steps,
-                output_dir=output_dir, record_shapes=record_shapes
-            )
-            self.hb_profiler.start()
-        else:
-            self.hb_profiler = None
-        self.step = 0
 
     def setup_quantization(self, model):
         if hq_env.is_quantization_enabled:
             htorch.core.quantization._mark_params_as_const(model)
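
Note: get_rope_scaling() depends only on the environment. A rough standalone equivalent of the helper added above, with an example of the value it produces (the function name here is ours, for illustration only):

    import os
    from typing import Dict, Optional

    def rope_scaling_from_env() -> Optional[Dict]:
        # Mirrors CausalLM.get_rope_scaling(): no ROPE_SCALING means no scaling.
        rope_scaling = os.getenv("ROPE_SCALING", None)    # e.g. "linear" or "dynamic"
        if rope_scaling is None:
            return None
        rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
        return {"type": rope_scaling, "factor": rope_factor}

    # With ROPE_SCALING=dynamic and ROPE_FACTOR=4.0 this returns
    # {"type": "dynamic", "factor": 4.0}; the constructor forwards it as
    # rope_scaling= to AutoModelForCausalLM.from_pretrained on the single-card
    # path and inside model_kwargs on the DeepSpeed path.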
@@ -754,7 +774,7 @@ class CausalLM(Model):
         input_ids,
         attention_mask,
         position_ids,
-        token_idx: Optional = None,
+        token_idx,
         past_key_values: Optional = None,
         bypass_hpu_graph: Optional = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -763,11 +783,9 @@ class CausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past_key_values,
+            "token_idx": token_idx
         }
 
-        if self.is_optimized_for_gaudi:
-            kwargs["token_idx"] = token_idx
-
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
 
@@ -794,7 +812,6 @@ class CausalLM(Model):
         logits = batch.logits
         past = batch.past
         prefill = batch.past_key_values is None
-        if self.is_optimized_for_gaudi:
         if prefill:
             # no right padding for prefill
             token_idx_scalar = batch.attention_mask.shape[-1] - 1
@@ -802,12 +819,10 @@ class CausalLM(Model):
         else:
             token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
         token_idx = torch.tensor(token_idx_scalar).to(self.device)
-        else:
-            token_idx = None
 
         # Select next token
         input_length = batch.input_length
-        if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
+        if logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                 batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
@@ -840,20 +855,11 @@ class CausalLM(Model):
 
         htorch.core.mark_step()
 
-        if token_idx is None:
-            batch.input_ids[:, 0] = next_token_ids[:, 0]
-        else:
-            batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
-
-        # Slice unused values from prefill, use it to store next token
-        if token_idx is None:
-            batch.input_ids = batch.input_ids[:, :1]
+        # Add new token into input_ids
+        batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
 
         # Update attention_mask as we added a new token to input_ids
-        if self.is_optimized_for_gaudi:
-            batch.attention_mask.index_fill_(1, token_idx, 1)
-        else:
-            batch.attention_mask[:, -batch.padding_right_offset] = 1
+        batch.attention_mask.index_fill_(1, token_idx, 1)
 
         # Adjust lengths
         batch.input_length += 1
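
Note: with the non-Gaudi branch gone, token_idx is always defined and marks the position where the next token is written. A toy illustration (made-up values, not from the repo) of what the index_copy_ call above does to a right-padded batch:

    import torch

    input_ids = torch.tensor([[5, 6, 7, 0, 0]])   # one sequence, right-padded to a bucket of 5
    token_idx = torch.tensor([3])                 # attention_mask.shape[-1] - right_padding
    next_token_ids = torch.tensor([42])

    # Write the freshly sampled token into the first padding slot.
    input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
    print(input_ids)  # tensor([[ 5,  6,  7, 42,  0]])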
@@ -886,37 +892,24 @@ class CausalLM(Model):
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
 
-        if self.is_optimized_for_gaudi:
-            if prefill:
-                # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
-            else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
-            attention_mask = batch.attention_mask
-        else:
-            token_idx = None
-            # slice the attention mask to the correct shape
-            # TODO fix me!
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-
-        if not prefill and token_idx is not None:
-            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        else:
-            input_ids = batch.input_ids
-
-        if prefill:
+        # Execute batch
+        if prefill:
+            # no right padding for prefill
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
             batch.logits, batch.past = self.forward(
-                input_ids,
-                attention_mask,
+                batch.input_ids,
+                batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
         else:
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
             batch.logits = self.forward(
                 input_ids,
-                attention_mask,
+                batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
@@ -955,10 +948,7 @@ class CausalLM(Model):
             top_token_logprobs = req_data['top_token_logprobs']
 
             # Append next token to all tokens
-            if self.is_optimized_for_gaudi:
-                all_input_ids[input_length] = next_token_id
-            else:
-                all_input_ids = torch.cat([all_input_ids, next_token_id])
+            all_input_ids[input_length] = next_token_id
             new_input_length = input_length + 1
 
             # Generated token
@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         def batch_from_pb(batch):
             return self.model.batch_type.from_pb(
-                batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+                batch, self.model.tokenizer, self.model.dtype, self.model.device
             )
 
         batches = [batch_from_pb(batch) for batch in request.batches]
@@ -69,7 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
         generations, next_batch = self.model.generate_token([batch])
         self.cache.set(next_batch)