Address several NeuronModelForCausalLM and TGI issues (#454)
* feat(decoder): use only the required number of cores

This makes it possible to load multiple models on the same host.

* fix(tgi): give explicit calculation for max-prefill-tokens

* feat(decoder): support custom stopping criteria

* fix(decoder): allow updating checkpoint dir when saving

* fix(tgi): improve the error message if the model is sharded

* fix(tgi): do not return inputs in generation

* Apply suggestions from code review

Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

* fix(decoder): only use available cores

* feat(decoder): ignore inconsistent num_cores variables

---------

Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
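As an illustration of the core-allocation change, a minimal sketch of the intended usage (the model id, batch size and sequence length are placeholders; `num_cores` is the export argument handled in `modeling_decoder.py` below, and this assumes a host with at least four free Neuron cores):

```python
from optimum.neuron import NeuronModelForCausalLM

# Each model only reserves the cores it was compiled for, so two models
# can now be loaded side by side on the same host.
model_a = NeuronModelForCausalLM.from_pretrained(
    "gpt2", export=True, batch_size=1, sequence_length=512, num_cores=2
)
model_b = NeuronModelForCausalLM.from_pretrained(
    "gpt2", export=True, batch_size=1, sequence_length=512, num_cores=2
)
```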
dacorvo and JingyaHuang authored Jan 31, 2024
1 parent c345de4 commit eb2a93f
Showing 7 changed files with 103 additions and 17 deletions.
14 changes: 12 additions & 2 deletions optimum/neuron/generation/token_selector.py
@@ -53,7 +53,12 @@ def __init__(

@classmethod
def create(
cls, input_ids: torch.Tensor, generation_config: GenerationConfig, model: GenerationMixin, max_seq_length: int
cls,
input_ids: torch.Tensor,
generation_config: GenerationConfig,
model: GenerationMixin,
max_seq_length: int,
stopping_criteria: Optional[StoppingCriteriaList] = None,
) -> "TokenSelector":
r"""Creates the `TokenSelector` for a specific generation configuration.
@@ -66,6 +71,9 @@ def create(
The model provides the internal helpers for selecting the logits processors and stopping criteria.
max_seq_length (`int`):
The maximum number of input + generated tokens for this model. It depends on the model compilation parameters.
stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config.
Return:
`TokenSelector`: The selector configured for this generation.
"""
@@ -110,7 +118,9 @@ def create(
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
)
stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=StoppingCriteriaList())
if stopping_criteria is None:
stopping_criteria = StoppingCriteriaList()
stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria)

# The generation requires special tokens
eos_token_id = generation_config.eos_token_id
9 changes: 8 additions & 1 deletion optimum/neuron/modeling.py
@@ -52,6 +52,7 @@
from tempfile import TemporaryDirectory

from transformers import GenerationConfig, PretrainedConfig
from transformers.generation import StoppingCriteriaList


logger = logging.getLogger(__name__)
@@ -722,6 +723,7 @@ def generate(
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
generation_config: Optional["GenerationConfig"] = None,
stopping_criteria: Optional["StoppingCriteriaList"] = None,
**kwargs,
) -> torch.LongTensor:
r"""
@@ -747,6 +749,9 @@
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~transformers.generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config.
Returns:
`torch.LongTensor`: The generated token ids.
@@ -758,7 +763,9 @@
self._validate_model_kwargs(model_kwargs)

# Instantiate a TokenSelector for the specified configuration
selector = TokenSelector.create(input_ids, generation_config, self, self.max_length)
selector = TokenSelector.create(
input_ids, generation_config, self, self.max_length, stopping_criteria=stopping_criteria
)

# Verify that the inputs are compatible with the model static input dimensions
batch_size, sequence_length = input_ids.shape
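For illustration, a minimal sketch of passing the new `stopping_criteria` argument (a time-based criterion from `transformers` is used as an arbitrary example, and `model` is assumed to be an already compiled `NeuronModelForCausalLM`):

```python
import torch
from transformers.generation import MaxTimeCriteria, StoppingCriteriaList

# Stop generating after roughly two seconds, on top of the default criteria
# built from the generation config.
criteria = StoppingCriteriaList([MaxTimeCriteria(max_time=2.0)])
input_ids = torch.ones([1, 10], dtype=torch.int64)  # dummy prompt
outputs = model.generate(input_ids=input_ids, stopping_criteria=criteria)
```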
47 changes: 43 additions & 4 deletions optimum/neuron/modeling_decoder.py
@@ -49,6 +49,30 @@ def get_exporter(config, task):
return TasksManager.get_exporter_config_constructor(model_type=config.model_type, exporter="neuron", task=task)()


@requires_transformers_neuronx
def get_available_cores() -> int:
"""A helper to get the number of available cores.
This number depends first on the actual number of cores, then on the
content of the NEURON_RT_NUM_CORES and NEURON_RT_VISIBLE_CORES variables.
"""
max_cores = len(os.listdir("/sys/class/neuron_device/")) * 2
num_cores = os.environ.get("NEURON_RT_NUM_CORES", max_cores)
if num_cores != max_cores:
num_cores = int(num_cores)
num_cores = min(num_cores, max_cores)
visible_cores = os.environ.get("NEURON_RT_VISIBLE_CORES", num_cores)
if visible_cores != num_cores:
# Assume NEURON_RT_VISIBLE_CORES is in the form '4' or '7-15'
if "-" in visible_cores:
start, end = visible_cores.split("-")
visible_cores = int(end) - int(start) + 1
else:
visible_cores = 1
visible_cores = min(visible_cores, num_cores)
return visible_cores
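To make the resolution order concrete, a hypothetical walk-through (this assumes a host exposing two Neuron devices, i.e. four cores, and is only runnable on such a host):

```python
import os

# With two devices, max_cores is 4.
os.environ["NEURON_RT_NUM_CORES"] = "4"

os.environ["NEURON_RT_VISIBLE_CORES"] = "0-1"  # a range of two cores
assert get_available_cores() == 2

os.environ["NEURON_RT_VISIBLE_CORES"] = "3"    # a single core index
assert get_available_cores() == 1
```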


class NeuronDecoderModel(OptimizedModel):
"""
Base class to convert and run pre-trained transformers decoder models on Neuron devices.
@@ -122,12 +146,27 @@ def __init__(
# Specify the path where compiled artifacts are stored before conversion
neuronx_model.load(compiled_dir)

# Compile the Neuron model (if present compiled artifacts will be reloaded instead of compiled)
# When compiling, only create a cache entry if the model comes from the hub
checkpoint_id = neuron_config.get("checkpoint_id", None)
# Only create a cache entry if the model comes from the hub
cache_entry = None if checkpoint_id is None else ModelCacheEntry(checkpoint_id, config)

# Export the model using the Optimum Neuron Cache
with hub_neuronx_cache(entry=cache_entry):
available_cores = get_available_cores()
if num_cores > available_cores:
raise ValueError(
f"The specified number of cores ({num_cores}) exceeds the number of cores available ({available_cores})."
)
neuron_rt_num_cores = os.environ.get("NEURON_RT_NUM_CORES", None)
# Restrict the number of cores used to allow multiple models on the same host
os.environ["NEURON_RT_NUM_CORES"] = str(num_cores)
# Load the model on neuron cores (if found in cache or compiled directory, the NEFF files
# will be reloaded instead of compiled)
neuronx_model.to_neuron()
if neuron_rt_num_cores is None:
os.environ.pop("NEURON_RT_NUM_CORES")
else:
os.environ["NEURON_RT_NUM_CORES"] = neuron_rt_num_cores

super().__init__(neuronx_model, config)

@@ -201,7 +240,7 @@ def get_export_config(
sequence_length = config.max_position_embeddings
if num_cores is None:
# Use all available cores
num_cores = len(os.listdir("/sys/class/neuron_device/")) * 2
num_cores = get_available_cores()
if auto_cast_type is None:
auto_cast_type = "fp32"
if config.torch_dtype == "float16":
@@ -342,7 +381,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]):

def copy_dir_to_path(src_dir: Union[str, Path, TemporaryDirectory], dst_path: Union[str, Path]):
if isinstance(src_dir, TemporaryDirectory):
shutil.copytree(src_dir.name, dst_path)
shutil.copytree(src_dir.name, dst_path, dirs_exist_ok=True)
elif not os.path.samefile(src_dir, dst_path):
os.symlink(dst_path, src_dir)
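With `dirs_exist_ok=True`, re-saving into a directory that already contains a previous checkpoint updates it instead of raising. A hedged sketch of the intended effect (the directory name is illustrative, and the model is assumed to have been compiled into a temporary directory):

```python
model.save_pretrained("neuron-gpt2")  # first save creates the directory
model.save_pretrained("neuron-gpt2")  # a second save now refreshes it in place
```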

20 changes: 20 additions & 0 deletions tests/generation/test_generate.py
@@ -18,6 +18,7 @@
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import StoppingCriteria

from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, is_trainium_test, requires_neuronx
@@ -75,6 +76,25 @@ def test_model_generation_input_dimensions(neuron_decoder_path):
_test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2)


@is_inferentia_test
@requires_neuronx
def test_decoder_generation_custom_stopping_criteria():
model_id = "hf-internal-testing/tiny-random-gpt2"
model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=1)

class CustomStoppingCriteria(StoppingCriteria):
def __init__(self):
self.called = False

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
self.called = True
return True

criteria = CustomStoppingCriteria()
model.generate(input_ids=torch.ones([1, 10], dtype=torch.int64), stopping_criteria=[criteria])
assert criteria.called, "Custom StoppingCriteria should have been called"


@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):
2 changes: 1 addition & 1 deletion text-generation-inference/README.md
@@ -169,7 +169,7 @@ This adds several restrictions to the following parameters:
- `--max-concurrent-requests` must be set to `batch size`,
- `--max-input-length` must be lower than `max_length`,
- `--max-total-tokens` must be set to `max_length` (it is per-request),
- `--max-batch-prefill-tokens` must be lower than `max_tokens`,
- `--max-batch-prefill-tokens` must be set to `batch_size * max_input_length`,
- `--max-batch-total-tokens` must be set to `max_tokens`.
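Taken together, an illustrative launch for a model compiled with `batch_size=4` and `max_length=2048`, served with a `max-input-length` of 1024 (the numbers and model id are placeholders, not a recommended configuration, and `max_tokens` is taken here to mean `batch_size * max_length`):

```shell
text-generation-launcher --model-id <neuron-compiled-model> \
  --max-concurrent-requests 4 \
  --max-input-length 1024 \
  --max-total-tokens 2048 \
  --max-batch-prefill-tokens 4096 \
  --max-batch-total-tokens 8192
```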

### Choosing the correct batch size
@@ -38,8 +38,7 @@ def serve(
Use JSON format for log serialization.
"""
if sharded:
raise ValueError("Sharding cannot be modified after the Neuron model has been compiled.")

raise ValueError("Sharding is not supported.")
# Remove default handler
logger.remove()
logger.add(
@@ -111,6 +111,7 @@ def clear(self):
self._generated_tokens = 0
self._next_text_token_start = 0
self._next_text_token_end = 0
self._generated_text = ""
self._next_text = ""

@property
@@ -126,8 +127,8 @@ def request_id(self) -> int:
return self._request_id

@property
def inputs(self) -> str:
return self._inputs
def cached_text(self) -> str:
return self._inputs + self._generated_text

@property
def generation_config(self) -> GenerationConfig:
@@ -238,8 +239,8 @@ def append(self, next_token: int) -> str:
self._mask = torch.cat([self._mask, torch.LongTensor([1])])
self._generated_tokens += 1
next_text = self._decode_next_tokens()
# Now that a new token has been generated, we can append the previous one to the inputs
self._inputs += self._next_text
# Now that a new token has been generated, we can append the previous one to the generated text
self._generated_text += self._next_text
self._next_text = next_text
return next_text

@@ -263,7 +264,7 @@ def stopped(self) -> bool:

@property
def generated_text(self) -> str:
return self._inputs + self._next_text
return self._generated_text + self._next_text

@property
def next_token(self) -> int:
@@ -314,6 +315,12 @@ def warmup(self, batch: Batch) -> int:
Return:
The maximum number of tokens the model supports.
"""
# Just check that the warmup request parameters match the model capacity
batch_size = self.model.batch_size
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
self.prefill(batch)
return self.model.batch_size * self.model.max_length

@@ -343,8 +350,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
slot = empty_slots.pop()
slot.assign(request, self.model.generation_config)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
# Reconstruct the full inputs (without padding)
inputs = [slot.inputs for slot in self.slots]
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
# - the inputs for new requests,
# - the inputs and the generated text that has already been cached (i.e. excluding the last generated token)
# for unfinished requests.
inputs = [slot.cached_text for slot in self.slots]
# Tokenize with padding
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
# If needed truncate sequences to fit into the static dimensions
