
Commit

Merge branch 'main' of https://github.com/keras-team/keras-core into ops/hyperbolic
FayazRahman committed Jul 30, 2023
2 parents 38ae1b2 + c153bac commit 003f842
Showing 11 changed files with 2,406 additions and 43 deletions.
4 changes: 2 additions & 2 deletions keras_core/applications/xception.py
@@ -48,8 +48,8 @@ def Xception(
The default input image size for this model is 299x299.
Note: each Keras Application expects a specific kind of input preprocessing.
For Xception, call `tf.keras.applications.xception.preprocess_input` on your
inputs before passing them to the model.
For Xception, call `keras_core.applications.xception.preprocess_input`
on your inputs before passing them to the model.
`xception.preprocess_input` will scale input pixels between -1 and 1.
Args:
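For reference, a minimal usage sketch of the call the updated docstring recommends (not part of the diff; it assumes keras_core is installed and uses random data in place of a real image):

import numpy as np
from keras_core.applications import xception

# Build the model without pretrained weights to keep the sketch download-free.
model = xception.Xception(weights=None)

# One fake 299x299 RGB image with pixel values in [0, 255].
images = np.random.uniform(0, 255, size=(1, 299, 299, 3)).astype("float32")

# preprocess_input scales pixels to the [-1, 1] range expected by Xception.
inputs = xception.preprocess_input(images)
preds = model.predict(inputs)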
312 changes: 303 additions & 9 deletions keras_core/utils/audio_dataset_utils.py
@@ -1,5 +1,11 @@
import numpy as np

from keras_core.api_export import keras_core_export
from keras_core.utils import dataset_utils
from keras_core.utils.module_utils import tensorflow as tf
from keras_core.utils.module_utils import tensorflow_io as tfio

ALLOWED_FORMATS = (".wav",)


@keras_core_export("keras_core.utils.audio_dataset_from_directory")
@@ -106,19 +112,307 @@ def audio_dataset_from_directory(
of shape `(batch_size, num_classes)`, representing a one-hot
encoding of the class index.
"""
# TODO: long-term, port implementation.
return tf.keras.utils.audio_dataset_from_directory(
if labels not in ("inferred", None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
"The `labels` argument should be a list/tuple of integer "
"labels, of the same size as the number of audio files in "
"the target directory. If you wish to infer the labels from "
"the subdirectory names in the target directory,"
' pass `labels="inferred"`. '
"If you wish to get a dataset that only contains audio samples "
f"(no labels), pass `labels=None`. Received: labels={labels}"
)
if class_names:
raise ValueError(
"You can only pass `class_names` if "
f'`labels="inferred"`. Received: labels={labels}, and '
f"class_names={class_names}"
)
if label_mode not in {"int", "categorical", "binary", None}:
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", '
'"binary", '
f"or None. Received: label_mode={label_mode}"
)

if ragged and output_sequence_length is not None:
raise ValueError(
"Cannot set both `ragged` and `output_sequence_length`"
)

if sampling_rate is not None:
if not isinstance(sampling_rate, int):
raise ValueError(
"`sampling_rate` should have an integer value. "
f"Received: sampling_rate={sampling_rate}"
)

if sampling_rate <= 0:
raise ValueError(
"`sampling_rate` should be higher than 0. "
f"Received: sampling_rate={sampling_rate}"
)

if not tfio.available:
raise ImportError(
"To use the argument `sampling_rate`, you should install "
"tensorflow_io. You can install it via `pip install "
"tensorflow-io`."
)

if labels is None or label_mode is None:
labels = None
label_mode = None

dataset_utils.check_validation_split_arg(
validation_split, subset, shuffle, seed
)

if seed is None:
seed = np.random.randint(1e6)

file_paths, labels, class_names = dataset_utils.index_directory(
directory,
labels,
formats=ALLOWED_FORMATS,
class_names=class_names,
shuffle=shuffle,
seed=seed,
follow_links=follow_links,
)

if label_mode == "binary" and len(class_names) != 2:
raise ValueError(
'When passing `label_mode="binary"`, there must be exactly 2 '
f"class_names. Received: class_names={class_names}"
)

if subset == "both":
train_dataset, val_dataset = get_training_and_validation_dataset(
file_paths=file_paths,
labels=labels,
validation_split=validation_split,
directory=directory,
label_mode=label_mode,
class_names=class_names,
sampling_rate=sampling_rate,
output_sequence_length=output_sequence_length,
ragged=ragged,
)

train_dataset = prepare_dataset(
dataset=train_dataset,
batch_size=batch_size,
shuffle=shuffle,
seed=seed,
class_names=class_names,
output_sequence_length=output_sequence_length,
ragged=ragged,
)
val_dataset = prepare_dataset(
dataset=val_dataset,
batch_size=batch_size,
shuffle=False,
seed=seed,
class_names=class_names,
output_sequence_length=output_sequence_length,
ragged=ragged,
)
return train_dataset, val_dataset

else:
dataset = get_dataset(
file_paths=file_paths,
labels=labels,
directory=directory,
validation_split=validation_split,
subset=subset,
label_mode=label_mode,
class_names=class_names,
sampling_rate=sampling_rate,
output_sequence_length=output_sequence_length,
ragged=ragged,
)

dataset = prepare_dataset(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
seed=seed,
class_names=class_names,
output_sequence_length=output_sequence_length,
ragged=ragged,
)
return dataset
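
For reference, a hedged usage sketch of the new implementation (not part of the commit; the directory layout and class names are hypothetical, and tensorflow-io is only required when `sampling_rate` is passed):

from keras_core.utils import audio_dataset_from_directory

# Hypothetical layout: sounds/cat/*.wav and sounds/dog/*.wav
train_ds, val_ds = audio_dataset_from_directory(
    "sounds",
    labels="inferred",
    label_mode="categorical",      # one-hot labels of shape (batch_size, num_classes)
    batch_size=32,
    output_sequence_length=16000,  # pad or trim every clip to 16000 samples
    validation_split=0.2,
    subset="both",                 # returns (train_dataset, val_dataset)
    seed=1337,
)
print(train_ds.class_names)        # e.g. ["cat", "dog"]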


def prepare_dataset(
dataset,
batch_size,
shuffle,
seed,
class_names,
output_sequence_length,
ragged,
):
dataset = dataset.prefetch(tf.data.AUTOTUNE)
if batch_size is not None:
if shuffle:
dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)

if output_sequence_length is None and not ragged:
dataset = dataset.padded_batch(
batch_size, padded_shapes=([None, None], [])
)
else:
dataset = dataset.batch(batch_size)
else:
if shuffle:
dataset = dataset.shuffle(buffer_size=1024, seed=seed)

# Users may need to reference `class_names`.
dataset.class_names = class_names
return dataset
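
A standalone sketch (toy data, TensorFlow in eager mode assumed) of why the function above uses `padded_batch` when clips keep their native lengths and `ragged=False`:

import numpy as np
import tensorflow as tf

# Two mono clips of different lengths, shaped (samples, channels), plus integer labels.
clips = [np.zeros((4, 1), "float32"), np.zeros((7, 1), "float32")]
labels = [0, 1]
ds = tf.data.Dataset.from_generator(
    lambda: zip(clips, labels),
    output_signature=(
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)

# Same padding spec as above: pad the audio on both axes, labels stay scalars.
batched = ds.padded_batch(2, padded_shapes=([None, None], []))
for audio, y in batched:
    print(audio.shape)  # (2, 7, 1): the shorter clip is zero-padded to the longest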


def get_training_and_validation_dataset(
file_paths,
labels,
validation_split,
directory,
label_mode,
class_names,
sampling_rate,
output_sequence_length,
ragged,
):
(
file_paths_train,
labels_train,
) = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, "training"
)
if not file_paths_train:
raise ValueError(
f"No training audio files found in directory {directory}. "
f"Allowed format(s): {ALLOWED_FORMATS}"
)

file_paths_val, labels_val = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, "validation"
)
if not file_paths_val:
raise ValueError(
f"No validation audio files found in directory {directory}. "
f"Allowed format(s): {ALLOWED_FORMATS}"
)

train_dataset = paths_and_labels_to_dataset(
file_paths=file_paths_train,
labels=labels_train,
label_mode=label_mode,
num_classes=len(class_names),
sampling_rate=sampling_rate,
output_sequence_length=output_sequence_length,
ragged=ragged,
)

val_dataset = paths_and_labels_to_dataset(
file_paths=file_paths_val,
labels=labels_val,
label_mode=label_mode,
num_classes=len(class_names),
sampling_rate=sampling_rate,
output_sequence_length=output_sequence_length,
ragged=ragged,
)

return train_dataset, val_dataset


def get_dataset(
file_paths,
labels,
directory,
validation_split,
subset,
label_mode,
class_names,
sampling_rate,
output_sequence_length,
ragged,
):
file_paths, labels = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, subset
)
if not file_paths:
raise ValueError(
f"No audio files found in directory {directory}. "
f"Allowed format(s): {ALLOWED_FORMATS}"
)

dataset = paths_and_labels_to_dataset(
file_paths=file_paths,
labels=labels,
label_mode=label_mode,
num_classes=len(class_names),
sampling_rate=sampling_rate,
output_sequence_length=output_sequence_length,
ragged=ragged,
)

return dataset


def read_and_decode_audio(
path, sampling_rate=None, output_sequence_length=None
):
"""Reads and decodes audio file."""
audio = tf.io.read_file(path)

if output_sequence_length is None:
output_sequence_length = -1

audio, default_audio_rate = tf.audio.decode_wav(
contents=audio, desired_samples=output_sequence_length
)
if sampling_rate is not None:
# default_audio_rate should have dtype=int64
default_audio_rate = tf.cast(default_audio_rate, tf.int64)
audio = tfio.audio.resample(
input=audio, rate_in=default_audio_rate, rate_out=sampling_rate
)
return audio
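
A small sketch of what `read_and_decode_audio` does for a single file (not part of the commit; "example.wav" is a hypothetical path to any PCM-encoded WAV file):

import tensorflow as tf

contents = tf.io.read_file("example.wav")
# desired_samples=-1 keeps the full clip; a positive value crops or zero-pads to that length.
audio, rate = tf.audio.decode_wav(contents=contents, desired_samples=-1)
print(audio.shape)  # (num_samples, num_channels), float32 values in [-1.0, 1.0]
print(int(rate))    # sample rate read from the WAV header, e.g. 16000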


def paths_and_labels_to_dataset(
file_paths,
labels,
label_mode,
num_classes,
sampling_rate,
output_sequence_length,
ragged,
):
"""Constructs a fixed-size dataset of audio and labels."""
path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
audio_ds = path_ds.map(
lambda x: read_and_decode_audio(
x, sampling_rate, output_sequence_length
),
num_parallel_calls=tf.data.AUTOTUNE,
)

if ragged:
audio_ds = audio_ds.map(
lambda x: tf.RaggedTensor.from_tensor(x),
num_parallel_calls=tf.data.AUTOTUNE,
)

if label_mode:
label_ds = dataset_utils.labels_to_dataset(
labels, label_mode, num_classes
)
audio_ds = tf.data.Dataset.zip((audio_ds, label_ds))
return audio_ds
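
Finally, a toy sketch (not part of the diff; the constants are placeholders) of the two tail steps above: converting each clip to a `tf.RaggedTensor` when `ragged=True`, then zipping audio with labels into `(audio, label)` pairs:

import tensorflow as tf

# Stand-in "decoded" clips, each shaped (samples, channels).
audio_ds = tf.data.Dataset.from_tensor_slices(
    tf.constant([[[0.1], [0.2]], [[0.3], [0.4]]], dtype=tf.float32)
)
audio_ds = audio_ds.map(tf.RaggedTensor.from_tensor)  # mirrors the ragged=True branch

label_ds = tf.data.Dataset.from_tensor_slices([0, 1])
pairs = tf.data.Dataset.zip((audio_ds, label_ds))
for clip, label in pairs:
    print(clip.shape, int(label))  # ragged clip shape (2, None), label 0 or 1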
