Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00
Bump neuron SDK version (#3260)
* chore(neuron): bump version to 0.2.0
* refactor(neuron): use named parameters in inputs helpers. This hides the differences between the two backends in terms of input parameters.
* refactor(neuron): remove obsolete code paths
* fix(neuron): use neuron_config whenever possible
* fix(neuron): use new cache import path
* fix(neuron): neuron config is not stored in config anymore
* fix(nxd): adapt model retrieval to new APIs
* fix(generator): emulate greedy in sampling parameters. When on-device sampling is enabled, we need to emulate the greedy behaviour using top-k=1, top-p=1, temperature=1.
* test(neuron): update models and expectations
* feat(neuron): support on-device sampling
* fix(neuron): adapt entrypoint
* tests(neuron): remove obsolete models
* fix(neuron): adjust test expectations for llama on nxd
This commit is contained in:
parent 1ff9d185d5
commit 79183d1647
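The greedy-emulation change called out in the commit message can be pictured with a small sketch (this is not the code from the commit; the standalone helper and its name are assumptions): forcing top-k to 1 while keeping top-p, typical-p and temperature at neutral values makes on-device sampling always return the most probable token, which is exactly greedy decoding.

# Illustrative sketch only (assumed names, not part of this commit): emulating
# greedy decoding purely through sampling parameters, as described above.
def emulate_greedy(generation_config):
    # top_k=1 keeps only the most probable token, so sampling becomes deterministic;
    # the other parameters are set to neutral values so they add no extra filtering.
    generation_config.temperature = 1.0
    generation_config.top_k = 1
    generation_config.top_p = 1.0
    generation_config.typical_p = 1.0
    return generation_config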
@@ -5,7 +5,7 @@ RUN mkdir -p /tgi
 # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
 FROM alpine AS optimum-neuron
 RUN mkdir -p /optimum-neuron
-ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
+ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
 RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1

 # Build cargo components (adapted from TGI original Dockerfile)
@@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
 # Install neuronx packages
 RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
-aws-neuronx-dkms=2.19.64.0 \
-aws-neuronx-collectives=2.23.135.0-3e70920f2 \
-aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
-aws-neuronx-tools=2.20.204.0 \
+aws-neuronx-dkms=2.20.28.0 \
+aws-neuronx-collectives=2.24.59.0-838c7fc8b \
+aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
+aws-neuronx-tools=2.22.61.0 \
 libxml2 \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean
@@ -125,11 +125,10 @@ RUN pip3 install \
 --index-url https://download.pytorch.org/whl/cpu

 RUN pip3 install \
-neuronx-cc==2.16.372.0 \
-torch-neuronx==2.5.1.2.4.0 \
-transformers-neuronx==0.13.322 \
-neuronx-distributed==0.10.1 \
-libneuronxla==2.1.681.0 \
+neuronx-cc==2.17.194.0 \
+torch-neuronx==2.5.1.2.6.0 \
+neuronx-distributed==0.11.0 \
+libneuronxla==2.2.1630.0 \
 --extra-index-url=https://pip.repos.neuron.amazonaws.com

 # Install HuggingFace packages
@@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
 # Final image
 FROM neuron

-COPY backends/neuron/tgi_env.py /tgi_env.py
+COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
 COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

@@ -7,7 +7,8 @@ from typing import List, Optional, Tuple

 import torch
 from loguru import logger
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from optimum.neuron.configuration_utils import NeuronConfig
 from transformers.generation import GenerationConfig

 from optimum.neuron import NeuronModelForCausalLM
@@ -175,6 +176,12 @@ class Slot:
 self._generation_config.top_p = request.parameters.top_p
 if request.parameters.typical_p != 0:
 self._generation_config.typical_p = request.parameters.typical_p
+else:
+# Set the sampling parameters to emulate greedy decoding when using on-device sampling
+self._generation_config.temperature = 1.0
+self._generation_config.top_k = 1
+self._generation_config.top_p = 1.0
+self._generation_config.typical_p = 1.0
 if request.parameters.repetition_penalty != 0:
 self._generation_config.repetition_penalty = (
 request.parameters.repetition_penalty
@@ -211,19 +218,11 @@ class Slot:
 self._mask = attention_mask.clone()
 self._selector = selector

-def pause(self, reset_on_pause: bool):
+def pause(self):
 """Mark the current slot as paused for generation.

 Note that the KV cache for this slot will still be filled.
 """
-if reset_on_pause:
-# Drop the last token as it will be added back when resuming the slot
-self._generated_tokens -= 1
-# Since generated tokens are now part of the prefill, we need to reevaluate
-# max_new_tokens for the next generation
-self._generation_config.max_new_tokens = (
-self._max_new_tokens - self._generated_tokens
-)
 self._state = Slot.State.PAUSE

 def resume(self):
@@ -340,16 +339,27 @@ class NeuronGenerator(Generator):
 tokenizer: PreTrainedTokenizerBase,
 ):
 self.model = model
-self.rebuild_cache_on_prefill = not self.model.continuous_batching
+if not isinstance(self.model, NeuronModelForCausalLM):
+raise ValueError("The model must be a NeuronModelForCausalLM.")
+if not model.neuron_config.continuous_batching:
+raise ValueError(
+"The neuron model must be compiled with continuous_batching=True."
+)
 # Specify padding and truncation options for decoder-only architecture
 tokenizer.pad_token_id = tokenizer.eos_token_id
 tokenizer.padding_side = "left"
 tokenizer.truncation_side = "left"
 self.tokenizer = tokenizer
 self.special_tokens = self.tokenizer.all_special_ids
-self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+self.slots = [
+Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+]
 self.batch_id = 0

+@property
+def on_device_sampling(self) -> bool:
+return getattr(self.model.neuron_config, "on_device_sampling", False)
+
 @property
 def info(self) -> InfoResponse:
 """Returns the expected InfoResponse."""
@@ -371,14 +381,22 @@ class NeuronGenerator(Generator):
 The maximum number of tokens the model supports.
 """
 # Just check that the warmup request parameters match the model capacity
-batch_size = self.model.batch_size
+batch_size = self.model.neuron_config.batch_size
 if len(batch.requests) > batch_size:
 raise ValueError(
-f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
 )
 self.prefill(batch)
 self.clear()
-return self.model.batch_size * self.model.max_length
+return (
+self.model.neuron_config.batch_size
+* self.model.neuron_config.sequence_length
+)
+
+def max_prefill_length(self) -> int:
+if hasattr(self.model.neuron_config, "max_context_length"):
+return self.model.neuron_config.max_context_length
+return self.model.neuron_config.sequence_length

 def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
 """Prefill new requests.
@@ -398,7 +416,7 @@ class NeuronGenerator(Generator):
 if len(empty_slots) < len(batch.requests):
 raise ValueError(
 f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
+f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
 )
 # Assign each request to an empty slot
 logger.debug(
@@ -412,14 +430,8 @@ class NeuronGenerator(Generator):
 logger.debug(
 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
 )
-if self.rebuild_cache_on_prefill:
-# We will clear pending slots and prefill all slots
-prefill_slots = self.slots
-seq_ids = None
-else:
-# We only need to pass inputs for the new requests
-prefill_slots = new_slots
-seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+prefill_slots = new_slots
+seq_ids = torch.tensor([slot.id for slot in prefill_slots])
 # Reconstruct the full inputs (without padding) as seen by the model.
 # This comprises:
 # - the inputs for new requests,
@@ -431,8 +443,10 @@ class NeuronGenerator(Generator):
 inputs.append(slot.cached_text)
 # Apply truncation, making sure we fit into static dimensions
 if slot.truncate == 0:
-max_length = self.model.max_length
-elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+max_length = self.max_prefill_length()
+elif (
+slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+):
 max_length = slot.truncate
 # Tokenize with padding and truncation
 padded_inputs = self.tokenizer(
@@ -444,13 +458,12 @@ class NeuronGenerator(Generator):
 )
 input_ids = padded_inputs.input_ids
 attention_mask = padded_inputs.attention_mask
+sampling_params = (
+torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
+)
 # Pause previously active slots during generation
-next_tokens = []
 for slot in active_slots:
-slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-if self.rebuild_cache_on_prefill:
-# The slot will be reset, so we need to store its next token
-next_tokens.append(slot.next_token)
+slot.pause()
 # Each slot must be reset with the padded inputs and masks
 for i, slot in enumerate(prefill_slots):
 if slot.state != slot.state.EMPTY:
@@ -464,29 +477,33 @@ class NeuronGenerator(Generator):
 slot_input_ids,
 slot.generation_config,
 self.model,
-self.model.max_length,
+self.model.neuron_config.sequence_length,
 tokenizer=self.tokenizer,
 seed=slot.seed,
 )
 slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
 slot_attention_mask = attention_mask[i]
 slot.reset(slot_input_ids, slot_attention_mask, selector)
+if sampling_params is not None:
+sampling_params[i, 0] = slot.generation_config.top_k
+sampling_params[i, 1] = slot.generation_config.top_p
+sampling_params[i, 2] = slot.generation_config.temperature
 # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
 # as they have already been generated and sent back in the last decode.
 model_inputs = self.model.prepare_inputs_for_prefill(
-input_ids, attention_mask, seq_ids
+input_ids,
+attention_mask=attention_mask,
+seq_ids=seq_ids,
+sampling_params=sampling_params,
 )
-logits = self.model(**model_inputs)[0]
+tokens_or_logits = self.model(**model_inputs)[0]
 generation, next_batch = self._generate_token(
-prefill_slots, self.batch_id, logits, input_ids
+prefill_slots, self.batch_id, tokens_or_logits, input_ids
 )
 self.batch_id += 1
 # Reactivate previously active slots for the next decode
 for i, slot in enumerate(active_slots):
 slot.resume()
-if self.rebuild_cache_on_prefill:
-# Append back the next token
-slot.append(next_tokens[i])
 logger.debug("Model ready for decoding")
 if next_batch is not None:
 logger.debug(
@@ -530,12 +547,8 @@ class NeuronGenerator(Generator):
 raise ValueError(
 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
 )
-if self.model.continuous_batching:
-decode_slots = active_slots
-seq_ids = torch.tensor([slot.id for slot in decode_slots])
-else:
-decode_slots = self.slots
-seq_ids = None
+decode_slots = active_slots
+seq_ids = torch.tensor([slot.id for slot in decode_slots])
 # Reconstruct input_ids and attention_mask from decode slots
 n_slots = len(decode_slots)
 input_ids = torch.full(
@@ -545,22 +558,32 @@ class NeuronGenerator(Generator):
 for slot in decode_slots:
 max_length = max(max_length, slot.attention_mask.size(-1))
 attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
 for i, slot in enumerate(decode_slots):
 if slot.state != Slot.State.EMPTY:
 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
 input_ids[i, 0] = slot.next_token
 attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+if sampling_params is not None:
+sampling_params[i, 0] = slot.generation_config.top_k
+sampling_params[i, 1] = slot.generation_config.top_p
+sampling_params[i, 2] = slot.generation_config.temperature
 model_inputs = self.model.prepare_inputs_for_decode(
-input_ids, attention_mask, seq_ids
+input_ids,
+attention_mask=attention_mask,
+seq_ids=seq_ids,
+sampling_params=sampling_params,
+)
+tokens_or_logits = self.model(**model_inputs)[0]
+return self._generate_token(
+decode_slots, next_batch_id, tokens_or_logits, input_ids
 )
-logits = self.model(**model_inputs)[0]
-return self._generate_token(decode_slots, next_batch_id, logits, input_ids)

 def _generate_token(
 self,
 slots: List[Slot],
 next_batch_id: int,
-logits: torch.Tensor,
+tokens_or_logits: torch.Tensor,
 input_ids: torch.LongTensor,
 ) -> Tuple[List[Generation], CachedBatch]:
 generations = []
@@ -569,9 +592,12 @@ class NeuronGenerator(Generator):
 if slot.state != Slot.State.READY:
 continue
 request_id = slot.request_id
-next_token_logits = logits[i : i + 1, -1, :]
 slot_input_ids = input_ids[i : i + 1, :]
-next_token = slot.select(slot_input_ids, next_token_logits)
+if self.on_device_sampling:
+next_token = tokens_or_logits[i]
+else:
+next_token_logits = tokens_or_logits[i : i + 1, -1, :]
+next_token = slot.select(slot_input_ids, next_token_logits)
 next_token_text = slot.append(next_token)
 generated_text = None
 finish_reason = None
@@ -622,7 +648,7 @@ class NeuronGenerator(Generator):

 def _cached_batch(self, batch_id: int, request_ids: List):
 size = len(request_ids)
-max_tokens = size * self.model.max_length
+max_tokens = size * self.model.neuron_config.sequence_length
 return CachedBatch(
 id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
 )
@@ -671,8 +697,16 @@ class NeuronGenerator(Generator):
 Returns:
 A NeuronGenerator.
 """
-config = AutoConfig.from_pretrained(model_id)
-neuron_config = getattr(config, "neuron", None)
+try:
+neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+except Exception as e:
+logger.debug(
+"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+model_id,
+revision,
+e,
+)
+neuron_config = None
 start = time.time()
 if neuron_config is None:
 export_kwargs = get_export_kwargs_from_env()

@@ -6,10 +6,12 @@ from typing import Optional
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HF_HUB_CACHE
 from loguru import logger
-from transformers import AutoConfig

-from optimum.neuron import NeuronModelForCausalLM
-from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
+
+from .tgi_env import check_env_and_neuron_config_compatibility


 def get_export_kwargs_from_env():
@@ -24,7 +26,6 @@ def get_export_kwargs_from_env():
 num_cores = int(num_cores)
 auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
 return {
-"task": "text-generation",
 "batch_size": batch_size,
 "sequence_length": sequence_length,
 "num_cores": num_cores,
@@ -32,20 +33,15 @@ def get_export_kwargs_from_env():
 }


-def is_cached(model_id, neuron_config):
+def is_cached(model_id):
 # Look for cached entries for the specified model
 in_cache = False
-entries = get_hub_cached_entries(model_id, "inference")
+entries = get_hub_cached_entries(model_id)
 # Look for compatible entries
 for entry in entries:
-compatible = True
-for key, value in neuron_config.items():
-# Only weights can be different
-if key in ["checkpoint_id", "checkpoint_revision"]:
-continue
-if entry[key] != value:
-compatible = False
-if compatible:
+if check_env_and_neuron_config_compatibility(
+entry, check_compiler_version=True
+):
 in_cache = True
 break
 return in_cache
@@ -87,8 +83,16 @@ def fetch_model(
 revision = None
 # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
 # Note that the model may already be present in the cache.
-config = AutoConfig.from_pretrained(model_id, revision=revision)
-neuron_config = getattr(config, "neuron", None)
+try:
+neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+except Exception as e:
+logger.debug(
+"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+model_id,
+revision,
+e,
+)
+neuron_config = None
 if neuron_config is not None:
 if os.path.isdir(model_id):
 return model_id
@@ -99,16 +103,11 @@ def fetch_model(
 log_cache_size()
 return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
 # Model needs to be exported: look for compatible cached entries on the hub
-export_kwargs = get_export_kwargs_from_env()
-export_config = NeuronModelForCausalLM.get_export_config(
-model_id, config, revision=revision, **export_kwargs
-)
-neuron_config = export_config.neuron
-if not is_cached(model_id, neuron_config):
+if not is_cached(model_id):
 hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
 neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
 error_msg = (
-f"No cached version found for {model_id} with {neuron_config}."
+f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
 f"You can start a discussion to request it on {hub_cache_url}"
 f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
 )
@@ -121,8 +120,10 @@ def fetch_model(
 # Prefetch weights, tokenizer and generation config so that they are in cache
 log_cache_size()
 start = time.time()
-snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+snapshot_path = snapshot_download(
+model_id, revision=revision, ignore_patterns="*.bin"
+)
 end = time.time()
 logger.info(f"Model weights fetched in {end - start:.2f} s.")
 log_cache_size()
-return model_id
+return snapshot_path

backends/neuron/tgi_env.py → backends/neuron/server/text_generation_server/tgi_env.py (145 lines changed, Executable file → Normal file)
@@ -6,12 +6,11 @@ import os
 import sys
 from typing import Any, Dict, List, Optional

-from huggingface_hub import constants
-from transformers import AutoConfig
-
 from optimum.neuron.modeling_decoder import get_available_cores
-from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
 from optimum.neuron.utils.version_utils import get_neuronxcc_version
+from optimum.neuron.utils import map_torch_dtype


 logger = logging.getLogger(__name__)
@@ -24,15 +23,9 @@ tgi_router_env_vars = [
 ]
 tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]

-env_config_peering = [
-("MAX_BATCH_SIZE", "batch_size"),
-("MAX_TOTAL_TOKENS", "sequence_length"),
-("HF_AUTO_CAST_TYPE", "auto_cast_type"),
-("HF_NUM_CORES", "num_cores"),
-]

 # By the end of this script all env var should be specified properly
-env_vars = tgi_server_env_vars + tgi_router_env_vars
+tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars

 available_cores = get_available_cores()
 neuronxcc_version = get_neuronxcc_version()
@@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:


 def neuron_config_to_env(neuron_config):
+if isinstance(neuron_config, NeuronConfig):
+neuron_config = neuron_config.to_dict()
 with open(os.environ["ENV_FILEPATH"], "w") as f:
-for env_var, config_key in env_config_peering:
-f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
+f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
+f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
+f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
+config_key = (
+"auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
+)
+auto_cast_type = neuron_config[config_key]
+f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
 max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
 if not max_input_tokens:
 max_input_tokens = int(neuron_config["sequence_length"]) // 2
@@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):


 def sort_neuron_configs(dictionary):
-return -dictionary["num_cores"], -dictionary["batch_size"]
+return -dictionary["tp_degree"], -dictionary["batch_size"]


 def lookup_compatible_cached_model(
@@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
 ) -> Optional[Dict[str, Any]]:
 # Reuse the same mechanic as the one in use to configure the tgi server part
 # The only difference here is that we stay as flexible as possible on the compatibility part
-entries = get_hub_cached_entries(model_id, "inference")
+entries = get_hub_cached_entries(model_id)

 logger.debug(
 "Found %d cached entries for model %s, revision %s",
@@ -155,15 +156,15 @@ def lookup_compatible_cached_model(


 def check_env_and_neuron_config_compatibility(
-neuron_config: Dict[str, Any], check_compiler_version: bool
+neuron_config_dict: Dict[str, Any], check_compiler_version: bool
 ) -> bool:
 logger.debug(
 "Checking the provided neuron config %s is compatible with the local setup and provided environment",
-neuron_config,
+neuron_config_dict,
 )

 # Local setup compat checks
-if neuron_config["num_cores"] > available_cores:
+if neuron_config_dict["tp_degree"] > available_cores:
 logger.debug(
 "Not enough neuron cores available to run the provided neuron config"
 )
@@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(

 if (
 check_compiler_version
-and neuron_config["compiler_version"] != neuronxcc_version
+and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
 ):
 logger.debug(
 "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
 neuronxcc_version,
-neuron_config["compiler_version"],
+neuron_config_dict["neuronxcc_version"],
 )
 return False

-for env_var, config_key in env_config_peering:
-neuron_config_value = str(neuron_config[config_key])
-env_value = os.getenv(env_var, str(neuron_config_value))
+batch_size = os.getenv("MAX_BATCH_SIZE", None)
+if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
+logger.debug(
+"The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
+os.getenv("MAX_BATCH_SIZE"),
+neuron_config_dict["batch_size"],
+)
+return False
+max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
+if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
+max_total_tokens
+):
+logger.debug(
+"The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
+max_total_tokens,
+neuron_config_dict["sequence_length"],
+)
+return False
+num_cores = os.getenv("HF_NUM_CORES", None)
+if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
+logger.debug(
+"The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
+num_cores,
+neuron_config_dict["tp_degree"],
+)
+return False
+auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
+if auto_cast_type is not None:
+config_key = (
+"auto_cast_type"
+if "auto_cast_type" in neuron_config_dict
+else "torch_dtype"
+)
+neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
+env_value = map_torch_dtype(auto_cast_type)
 if env_value != neuron_config_value:
 logger.debug(
-"The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
-env_var,
-config_key,
+"The provided auto cast type and the neuron config param differ (%s != %s)",
 env_value,
 neuron_config_value,
 )
 return False

 max_input_tokens = int(
 os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
 )
 if max_input_tokens > 0:
-sequence_length = neuron_config["sequence_length"]
+if hasattr(neuron_config_dict, "max_context_length"):
+sequence_length = neuron_config_dict["max_context_length"]
+else:
+sequence_length = neuron_config_dict["sequence_length"]
 if max_input_tokens >= sequence_length:
 logger.debug(
 "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
@@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility(

 def get_env_dict() -> Dict[str, str]:
 d = {}
-for k in env_vars:
+for k in tgi_env_vars:
 d[k] = os.getenv(k)
 return d


-def main():
-"""
-This script determines proper default TGI env variables for the neuron precompiled models to
-work properly
-:return:
-"""
-args = parse_cmdline_and_set_env()
-
-for env_var in env_vars:
-if not os.getenv(env_var):
-break
-else:
-logger.info(
-"All env vars %s already set, skipping, user know what they are doing",
-env_vars,
-)
-sys.exit(0)
-
-cache_dir = constants.HF_HUB_CACHE
-
-logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
-
-config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
-neuron_config = getattr(config, "neuron", None)
+def get_neuron_config_for_model(
+model_name_or_path: str, revision: Optional[str] = None
+) -> NeuronConfig:
+try:
+neuron_config = NeuronConfig.from_pretrained(
+model_name_or_path, revision=revision
+)
+except Exception as e:
+logger.debug(
+"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+model_name_or_path,
+revision,
+e,
+)
+neuron_config = None
 if neuron_config is not None:
 compatible = check_env_and_neuron_config_compatibility(
-neuron_config, check_compiler_version=False
+neuron_config.to_dict(), check_compiler_version=False
 )
 if not compatible:
 env_dict = get_env_dict()
@@ -252,17 +276,6 @@ def main():
 logger.error(msg)
 raise Exception(msg)
 else:
-neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
+neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)

-if not neuron_config:
-msg = (
-"No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
-).format(get_env_dict(), available_cores, neuronxcc_version)
-logger.error(msg)
-raise Exception(msg)
-
-neuron_config_to_env(neuron_config)
-
-
-if __name__ == "__main__":
-main()
+return neuron_config

backends/neuron/tests/fixtures/model.py (74 lines changed, vendored)
@@ -4,14 +4,12 @@ import subprocess
 import sys
 from tempfile import TemporaryDirectory

-import huggingface_hub
+import os
 import pytest
 from transformers import AutoTokenizer

-from optimum.neuron import NeuronModelForCausalLM
-from optimum.neuron.utils import synchronize_hub_cache
-from optimum.neuron.version import __sdk_version__ as sdk_version
-from optimum.neuron.version import __version__ as version
+from optimum.neuron.cache import synchronize_hub_cache


 logging.basicConfig(
@@ -21,30 +19,14 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__file__)


 OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"


 # All model configurations below will be added to the neuron_model_config fixture
 MODEL_CONFIGURATIONS = {
-"gpt2": {
-"model_id": "gpt2",
-"export_kwargs": {
-"batch_size": 4,
-"sequence_length": 1024,
-"num_cores": 2,
-"auto_cast_type": "fp16",
-},
-},
 "llama": {
-"model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
+"model_id": "unsloth/Llama-3.2-1B-Instruct",
-"export_kwargs": {
-"batch_size": 4,
-"sequence_length": 2048,
-"num_cores": 2,
-"auto_cast_type": "fp16",
-},
-},
-"mistral": {
-"model_id": "optimum/mistral-1.1b-testing",
 "export_kwargs": {
 "batch_size": 4,
 "sequence_length": 4096,
@@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = {
 "batch_size": 4,
 "sequence_length": 4096,
 "num_cores": 2,
-"auto_cast_type": "fp16",
+"auto_cast_type": "bf16",
 },
 },
 "granite": {
@@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = {
 }


-def get_hub_neuron_model_id(config_name: str):
-return (
-f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
-)
-
-
 def export_model(model_id, export_kwargs, neuron_model_path):
 export_command = [
 "optimum-cli",
@@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path):
 def neuron_model_config(request):
 """Expose a pre-trained neuron model

-The fixture first makes sure the following model artifacts are present on the hub:
-- exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>,
-- cached artifacts under optimum-internal-testing/neuron-testing-cache.
-If not, it will export the model and push it to the hub.
-
-It then fetches the model locally and return a dictionary containing:
+The fixture exports a model locally and returns a dictionary containing:
 - a configuration name,
 - the original model id,
 - the export parameters,
-- the neuron model id,
 - the neuron model local path.

 For each exposed model, the local directory is maintained for the duration of the
 test session and cleaned up afterwards.
-The hub model artifacts are never cleaned up and persist accross sessions.
-They must be cleaned up manually when the optimum-neuron version changes.
-
 """
 config_name = request.param
 model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
 model_id = model_config["model_id"]
 export_kwargs = model_config["export_kwargs"]
-neuron_model_id = get_hub_neuron_model_id(config_name)
 with TemporaryDirectory() as neuron_model_path:
-hub = huggingface_hub.HfApi()
-if hub.repo_exists(neuron_model_id):
-logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
-hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
-else:
-export_model(model_id, export_kwargs, neuron_model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-tokenizer.save_pretrained(neuron_model_path)
-del tokenizer
-# Create the test model on the hub
-hub.create_repo(neuron_model_id, private=True)
-hub.upload_folder(
-folder_path=neuron_model_path,
-repo_id=neuron_model_id,
-ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
-)
-# Make sure it is cached
-synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+export_model(model_id, export_kwargs, neuron_model_path)
+synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer.save_pretrained(neuron_model_path)
+del tokenizer
 # Add dynamic parameters to the model configuration
 model_config["neuron_model_path"] = neuron_model_path
-model_config["neuron_model_id"] = neuron_model_id
 # Also add model configuration name to allow tests to adapt their expectations
 model_config["name"] = config_name
 # Yield instead of returning to keep a reference to the temporary directory.
 # It will go out of scope and be released only once all tests needing the fixture
 # have been completed.
 logger.info(f"{config_name} ready for testing ...")
+os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
 yield model_config
 logger.info(f"Done with {config_name}")

backends/neuron/tests/server/test_cached_model.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import os
+import pytest
+
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.model import fetch_model, is_cached
+
+
+@pytest.fixture(scope="module")
+def cached_model_id(neuron_model_config) -> str:
+"""
+Fixture to provide a cached model ID for testing.
+This assumes the model is already cached in the local environment.
+"""
+export_kwargs = neuron_model_config["export_kwargs"]
+os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+yield neuron_model_config["model_id"]
+os.environ.pop("MAX_BATCH_SIZE", None)
+os.environ.pop("MAX_TOTAL_TOKENS", None)
+os.environ.pop("HF_AUTO_CAST_TYPE", None)
+os.environ.pop("HF_NUM_CORES", None)
+
+
+def test_model_is_cached(cached_model_id):
+assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
+
+
+def test_fetch_cached_model(cached_model_id: str):
+model_path = fetch_model(cached_model_id)
+assert os.path.exists(
+model_path
+), f"Model {cached_model_id} was not fetched successfully"
+assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
+
+
+def test_generator_from_cached_model(cached_model_id: str):
+generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
+assert generator is not None, "Generator could not be created from cached model"
+assert generator.model is not None, "Generator model is not initialized"
+assert generator.tokenizer is not None, "Generator tokenizer is not initialized"
@@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
 """
 neuron_model_path = neuron_model_config["neuron_model_path"]
 generator = NeuronGenerator.from_pretrained(neuron_model_path)
-assert generator.model.batch_size > 1
+assert generator.model.neuron_config.batch_size > 1
 input_text = "Once upon a time"
 max_new_tokens = 20
 # Prefill a single request, remembering the generated token
 tokens = {0: [], 1: []}
 request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-max_length = generator.model.max_length
+max_length = generator.model.neuron_config.sequence_length
 batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
 generations, next_batch = generator.prefill(batch)
 assert next_batch.size == 1

@@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
 request = create_request(
 id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
 )
-max_length = generator.model.max_length
+max_length = generator.model.neuron_config.sequence_length
 batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
 generations, next_batch = generator.prefill(batch)
 # We already generated one token: call decode max_new_tokens - 1 times
@@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample):
 assert output.finish_reason == 0
 if do_sample:
 expected_text = {
-"gpt2": " The sun was set",
-"llama": "George Orwell, 1984",
-"mistral": "The sky was",
-"qwen2": " A young woman with",
+"llama": " I sat alone in the café",
+"qwen2": " The air was so still",
 "granite": "1984, George Orwell",
 }[config_name]
 assert expected_text in output.text
 else:
 print(output.text)
 expected_text = {
-"gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
-"llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
-"mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
+"llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
 "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
 "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
 }[config_name]

@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
|
|||||||
neuron_model_path = neuron_model_config["neuron_model_path"]
|
neuron_model_path = neuron_model_config["neuron_model_path"]
|
||||||
generator = NeuronGenerator.from_pretrained(neuron_model_path)
|
generator = NeuronGenerator.from_pretrained(neuron_model_path)
|
||||||
max_batch_size = 4
|
max_batch_size = 4
|
||||||
assert generator.model.batch_size >= max_batch_size
|
assert generator.model.neuron_config.batch_size >= max_batch_size
|
||||||
for num_requests in [1, max_batch_size]:
|
for num_requests in [1, max_batch_size]:
|
||||||
for do_sample in [True, False]:
|
for do_sample in [True, False]:
|
||||||
mode = "sample" if do_sample else "greedy"
|
mode = "sample" if do_sample else "greedy"
|
||||||
@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Let's be pessimistic when estimating max_tokens
|
# Let's be pessimistic when estimating max_tokens
|
||||||
max_length = generator.model.max_length
|
max_length = generator.max_prefill_length()
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
|
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
|
||||||
)
|
)
|
||||||
@@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
    assert len(generations) == batch_size
    if do_sample:
        expectations = {
-            "gpt2": [383, " The"],
-            "llama": [10058, " George"],
-            "mistral": [450, " The"],
-            "qwen2": [362, " A"],
+            "llama": [358, " I"],
+            "qwen2": [576, " The"],
            "granite": [308, " ("],
        }[config_name]
    else:
        expectations = {
-            "gpt2": [198, "\n"],
-            "llama": [10058, " George"],
-            "mistral": [13, "\n"],
+            "llama": [578, " The"],
            "qwen2": [358, " I"],
            "granite": [203, "\n"],
        }[config_name]
@@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config):
    config_name = neuron_model_config["name"]
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    batch_size = generator.model.batch_size
+    batch_size = generator.model.neuron_config.batch_size
    # We apply truncation to all requests but the first one
    truncate = [
        None,
@@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config):
    requests = []
    for i in range(batch_size):
        requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
    batch = Batch(
        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
    )
@@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config):
    # Even if the input text is identical for all requests, the first generated token might
    # be different because of the truncation
    expectations = {
-        "gpt2": [" He", " He", "\n", " He"],
-        "llama": [" —", " The", " He", " He"],
-        "mistral": [" He", "\n", " He", " He"],
+        "llama": [" He", "iens", "\x08", " He"],
        "qwen2": [" He", " The", " He", " He"],
        "granite": ["\n", "\n", " I", " He"],
    }[config_name]
    for i, g in enumerate(generations):
        tokens = g.tokens
-        assert tokens.texts[0] == expectations[i]
+        assert (
+            tokens.texts[0] == expectations[i]
+        ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"
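The prefill tests now read the compiled batch size from generator.model.neuron_config and the prompt length budget from generator.max_prefill_length() instead of the previous model.batch_size and model.max_length lookups. A minimal sketch of sizing a prefill batch against the new accessors follows; it is an illustration only, with the import paths and the model directory assumed rather than taken from this diff, and create_request standing in for the test helper used above.

# Illustration only: size a prefill batch using the accessors introduced in this commit.
# The import paths and the model directory below are assumptions, not part of the diff.
from text_generation_server.generator import NeuronGenerator  # assumed import path
from text_generation_server.pb.generate_pb2 import Batch  # assumed import path
from helpers import create_request  # test helper module, assumed path

generator = NeuronGenerator.from_pretrained("/data/exported-neuron-model")  # placeholder export dir
batch_size = generator.model.neuron_config.batch_size  # static batch size baked into the export
max_length = generator.max_prefill_length()  # pessimistic per-request token budget
requests = [create_request(id=i, inputs="Hello") for i in range(batch_size)]
batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length)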
backends/neuron/tests/test_entry_point.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+import os
+import pytest
+from tempfile import TemporaryDirectory
+
+from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
+from optimum.neuron.utils import map_torch_dtype
+
+from text_generation_server.tgi_env import (
+    get_neuron_config_for_model,
+    lookup_compatible_cached_model,
+    neuron_config_to_env,
+)
+
+
+def test_get_neuron_config_for_model(neuron_model_config):
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    assert neuron_config is not None
+    assert neuron_config.batch_size == export_kwargs["batch_size"]
+    assert neuron_config.sequence_length == export_kwargs["sequence_length"]
+    assert neuron_config.tp_degree == export_kwargs["num_cores"]
+    if isinstance(neuron_config, NxDNeuronConfig):
+        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+    else:
+        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+
+
+@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
+def test_lookup_compatible_cached_model(model_id: str):
+    neuron_config = lookup_compatible_cached_model(model_id, None)
+    assert neuron_config is not None
+
+
+def test_neuron_config_to_env(neuron_model_config) -> None:
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    with TemporaryDirectory() as temp_dir:
+        os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
+        neuron_config_to_env(neuron_config)
+        with open(os.environ["ENV_FILEPATH"], "r") as env_file:
+            env_content = env_file.read()
+        assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
+        assert (
+            f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
+            in env_content
+        )
+        assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
+        if hasattr(neuron_config, "torch_dtype"):
+            auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
+                "."
+            )[-1]
+        else:
+            auto_cast_type = neuron_config.auto_cast_type
+        assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
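Taken together, these tests pin down the contract of the new helpers: get_neuron_config_for_model recovers the exported neuron configuration, and neuron_config_to_env writes the matching export statements to the file named by ENV_FILEPATH, which the shell entrypoint shown below then sources. A minimal sketch of that round trip follows; the model directory and the output path are placeholders, not values from this commit.

# Sketch only: the env-file round trip exercised by the tests above.
import os

from text_generation_server.tgi_env import (
    get_neuron_config_for_model,
    neuron_config_to_env,
)

os.environ["ENV_FILEPATH"] = "/tmp/env.sh"  # placeholder; normally provided by the entrypoint environment
neuron_config = get_neuron_config_for_model("/data/exported-neuron-model")  # placeholder export dir
assert neuron_config is not None

neuron_config_to_env(neuron_config)
with open(os.environ["ENV_FILEPATH"]) as env_file:
    print(env_file.read())
# Expected to contain lines such as:
#   export MAX_BATCH_SIZE=<batch_size>
#   export MAX_TOTAL_TOKENS=<sequence_length>
#   export HF_NUM_CORES=<tp_degree>
#   export HF_AUTO_CAST_TYPE=<auto_cast_type>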
@@ -9,7 +9,7 @@ touch $ENV_FILEPATH

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

-${SCRIPT_DIR}/tgi_env.py $@
+${SCRIPT_DIR}/tgi_entry_point.py $@

source $ENV_FILEPATH

backends/neuron/tgi_entry_point.py (new executable file, 53 lines)
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import sys
+
+
+from text_generation_server.tgi_env import (
+    available_cores,
+    get_env_dict,
+    get_neuron_config_for_model,
+    neuron_config_to_env,
+    neuronxcc_version,
+    parse_cmdline_and_set_env,
+    tgi_env_vars,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+    """
+    This script determines proper default TGI env variables for the neuron precompiled models to
+    work properly
+    :return:
+    """
+    args = parse_cmdline_and_set_env()
+
+    for env_var in tgi_env_vars:
+        if not os.getenv(env_var):
+            break
+    else:
+        logger.info(
+            "All env vars %s already set, skipping, user know what they are doing",
+            tgi_env_vars,
+        )
+        sys.exit(0)
+
+    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
+
+    if not neuron_config:
+        msg = (
+            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
+        ).format(get_env_dict(), available_cores, neuronxcc_version)
+        logger.error(msg)
+        raise Exception(msg)
+
+    neuron_config_to_env(neuron_config)
+
+
+if __name__ == "__main__":
+    main()
@@ -28,15 +28,6 @@ logger = logging.getLogger(__file__)

# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
-    "gpt2": {
-        "model_id": "gpt2",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 1024,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
    "llama": {
        "model_id": "unsloth/Llama-3.2-1B-Instruct",
        "export_kwargs": {
@@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = {
            "auto_cast_type": "fp16",
        },
    },
-    "mistral": {
-        "model_id": "optimum/mistral-1.1b-testing",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 4096,
-            "num_cores": 2,
-            "auto_cast_type": "bf16",
-        },
-    },
    "qwen2": {
        "model_id": "Qwen/Qwen2.5-0.5B",
        "export_kwargs": {
@@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service):
    )
    assert response.details.generated_tokens == 17
    greedy_expectations = {
-        "gpt2": "\n\nDeep learning is a new field of research that has been around for a while",
-        "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial",
-        "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
+        "llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial",
        "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
        "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
    }
@@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load):

    assert len(responses) == 4
    expectations = {
-        "gpt2": "Deep learning is a new field of research that has been around for a while",
        "llama": "Deep learning is a subset of machine learning that uses artificial",
-        "mistral": "Deep Learning is a type of machine learning that",
        "qwen2": "Deep Learning is a subset of Machine Learning that is based on",
        "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art",
    }