Add hyperbolic ops (#634)
* add hyperbolic numpy functions

* update backend modules + tests + black

* docstrings

* fix

* remove ops.nn.tanh
FayazRahman authored Jul 30, 2023
1 parent c153bac commit f1652d1
Showing 9 changed files with 353 additions and 26 deletions.
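For context, a minimal usage sketch of the ops this commit adds. It assumes a keras_core build that includes this change and that the public `keras_core.ops` namespace exports the backend functions shown in the diffs below (the export file itself, presumably `keras_core/ops/numpy.py`, is not rendered in this view):

import numpy as np
from keras_core import ops

x = np.array([-1.0, 0.0, 1.0])

# Forward hyperbolic functions
print(ops.sinh(x))  # ~[-1.1752, 0.0, 1.1752]
print(ops.cosh(x))  # ~[ 1.5431, 1.0, 1.5431]
print(ops.tanh(x))  # ~[-0.7616, 0.0, 0.7616]

# Inverse hyperbolic functions (note the domain restrictions)
print(ops.arcsinh(x))                      # defined everywhere
print(ops.arctanh(np.array([-0.5, 0.5])))  # domain: |x| < 1
print(ops.arccosh(np.array([1.0, 2.0])))   # domain: x >= 1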
20 changes: 14 additions & 6 deletions guides/distributed_training_with_jax.py
@@ -58,9 +58,9 @@ def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
     x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
-    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
-        x
-    )
+    x = keras.layers.Conv2D(
+        filters=12, kernel_size=3, padding="same", use_bias=False
+    )(x)
     x = keras.layers.BatchNormalization(scale=False, center=True)(x)
     x = keras.layers.ReLU()(x)
     x = keras.layers.Conv2D(
@@ -187,7 +187,11 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y):
 # Training step, Keras provides a pure functional optimizer.stateless_apply
 @jax.jit
 def train_step(train_state, x, y):
-    trainable_variables, non_trainable_variables, optimizer_variables = train_state
+    (
+        trainable_variables,
+        non_trainable_variables,
+        optimizer_variables,
+    ) = train_state
     (loss_value, non_trainable_variables), grads = compute_gradients(
         trainable_variables, non_trainable_variables, x, y
     )
@@ -211,7 +215,9 @@ def get_replicated_train_state(devices):
     var_replication = NamedSharding(var_mesh, P())
 
     # Apply the distribution settings to the model variables
-    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
+    trainable_variables = jax.device_put(
+        model.trainable_variables, var_replication
+    )
     non_trainable_variables = jax.device_put(
         model.non_trainable_variables, var_replication
     )
Expand Down Expand Up @@ -255,7 +261,9 @@ def get_replicated_train_state(devices):
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
for variable, value in zip(
model.non_trainable_variables, non_trainable_variables
):
variable.assign(value)

"""
10 changes: 6 additions & 4 deletions guides/distributed_training_with_torch.py
@@ -53,9 +53,9 @@ def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
     x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
-    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
-        x
-    )
+    x = keras.layers.Conv2D(
+        filters=12, kernel_size=3, padding="same", use_bias=False
+    )(x)
     x = keras.layers.BatchNormalization(scale=False, center=True)(x)
     x = keras.layers.ReLU()(x)
     x = keras.layers.Conv2D(
@@ -231,7 +231,9 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
     model = get_model()
 
     # prepare the dataloader
-    dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)
+    dataloader = prepare_dataloader(
+        dataset, current_gpu_index, num_gpu, batch_size
+    )
 
     # Instantiate the torch optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
24 changes: 24 additions & 0 deletions keras_core/backend/jax/numpy.py
@@ -107,10 +107,18 @@ def arccos(x):
     return jnp.arccos(x)
 
 
+def arccosh(x):
+    return jnp.arccosh(x)
+
+
 def arcsin(x):
     return jnp.arcsin(x)
 
 
+def arcsinh(x):
+    return jnp.arcsinh(x)
+
+
 def arctan(x):
     return jnp.arctan(x)
 
@@ -119,6 +127,10 @@ def arctan2(x1, x2):
     return jnp.arctan2(x1, x2)
 
 
+def arctanh(x):
+    return jnp.arctanh(x)
+
+
 def argmax(x, axis=None):
     return jnp.argmax(x, axis=axis)
 
@@ -171,6 +183,10 @@ def cos(x):
     return jnp.cos(x)
 
 
+def cosh(x):
+    return jnp.cosh(x)
+
+
 def count_nonzero(x, axis=None):
     return jnp.count_nonzero(x, axis=axis)
 
@@ -441,6 +457,10 @@ def sin(x):
     return jnp.sin(x)
 
 
+def sinh(x):
+    return jnp.sinh(x)
+
+
 def size(x):
     return jnp.size(x)
 
@@ -479,6 +499,10 @@ def tan(x):
     return jnp.tan(x)
 
 
+def tanh(x):
+    return jnp.tanh(x)
+
+
 def tensordot(x1, x2, axes=2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
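The commit message mentions updated tests, but no test file is rendered in this excerpt. A consistency check in that spirit (hypothetical, not the actual test code) might compare each backend module against NumPy on a shared sample:

import numpy as np

def check_hyperbolic_ops(backend_numpy, atol=1e-6):
    # `backend_numpy` stands in for any of the modules touched here,
    # e.g. keras_core.backend.jax.numpy or keras_core.backend.torch.numpy.
    x = np.linspace(-0.9, 0.9, 7).astype("float32")  # inside arctanh's domain
    for name in ("sinh", "cosh", "tanh", "arcsinh", "arctanh"):
        expected = getattr(np, name)(x)
        actual = np.asarray(getattr(backend_numpy, name)(x))
        np.testing.assert_allclose(actual, expected, atol=atol)
    # arccosh requires inputs >= 1, so shift the sample for it
    y = x + 2.0
    np.testing.assert_allclose(
        np.asarray(backend_numpy.arccosh(y)), np.arccosh(y), atol=atol
    )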
24 changes: 24 additions & 0 deletions keras_core/backend/numpy/numpy.py
@@ -84,10 +84,18 @@ def arccos(x):
     return np.arccos(x)
 
 
+def arccosh(x):
+    return np.arccosh(x)
+
+
 def arcsin(x):
     return np.arcsin(x)
 
 
+def arcsinh(x):
+    return np.arcsinh(x)
+
+
 def arctan(x):
     return np.arctan(x)
 
@@ -96,6 +104,10 @@ def arctan2(x1, x2):
     return np.arctan2(x1, x2)
 
 
+def arctanh(x):
+    return np.arctanh(x)
+
+
 def argmax(x, axis=None):
     axis = tuple(axis) if isinstance(axis, list) else axis
     return np.argmax(x, axis=axis)
@@ -157,6 +169,10 @@ def cos(x):
     return np.cos(x)
 
 
+def cosh(x):
+    return np.cosh(x)
+
+
 def count_nonzero(x, axis=None):
     axis = tuple(axis) if isinstance(axis, list) else axis
     return np.count_nonzero(x, axis=axis)
@@ -438,6 +454,10 @@ def sin(x):
     return np.sin(x)
 
 
+def sinh(x):
+    return np.sinh(x)
+
+
 def size(x):
     return np.size(x)
 
@@ -480,6 +500,10 @@ def tan(x):
     return np.tan(x)
 
 
+def tanh(x):
+    return np.tanh(x)
+
+
 def tensordot(x1, x2, axes=2):
     axes = tuple(axes) if isinstance(axes, list) else axes
    return np.tensordot(x1, x2, axes=axes)
24 changes: 24 additions & 0 deletions keras_core/backend/tensorflow/numpy.py
@@ -105,10 +105,18 @@ def arccos(x):
     return tfnp.arccos(x)
 
 
+def arccosh(x):
+    return tfnp.arccosh(x)
+
+
 def arcsin(x):
     return tfnp.arcsin(x)
 
 
+def arcsinh(x):
+    return tfnp.arcsinh(x)
+
+
 def arctan(x):
     return tfnp.arctan(x)
 
@@ -117,6 +125,10 @@ def arctan2(x1, x2):
     return tfnp.arctan2(x1, x2)
 
 
+def arctanh(x):
+    return tfnp.arctanh(x)
+
+
 def argmax(x, axis=None):
     return tfnp.argmax(x, axis=axis)
 
@@ -174,6 +186,10 @@ def cos(x):
     return tfnp.cos(x)
 
 
+def cosh(x):
+    return tfnp.cosh(x)
+
+
 def count_nonzero(x, axis=None):
     return tfnp.count_nonzero(x, axis=axis)
 
@@ -472,6 +488,10 @@ def sin(x):
     return tfnp.sin(x)
 
 
+def sinh(x):
+    return tfnp.sinh(x)
+
+
 def size(x):
     return tfnp.size(x)
 
@@ -508,6 +528,10 @@ def tan(x):
     return tfnp.tan(x)
 
 
+def tanh(x):
+    return tfnp.tanh(x)
+
+
 def tensordot(x1, x2, axes=2):
     return tfnp.tensordot(x1, x2, axes=axes)
 
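The `tfnp` alias above presumably refers to `tf.experimental.numpy`, consistent with the rest of this backend file; the wrappers stay thin because that module already accepts array-likes. A standalone check of the underlying calls (assuming a TensorFlow 2.x install):

import tensorflow.experimental.numpy as tfnp

print(tfnp.sinh([0.0, 1.0]))     # tf.Tensor([0.        1.1752012], ...)
print(tfnp.arccosh([1.0, 2.0]))  # tf.Tensor([0.        1.3169579], ...)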
30 changes: 30 additions & 0 deletions keras_core/backend/torch/numpy.py
@@ -173,11 +173,21 @@ def arccos(x):
     return torch.arccos(x)
 
 
+def arccosh(x):
+    x = convert_to_tensor(x)
+    return torch.arccosh(x)
+
+
 def arcsin(x):
     x = convert_to_tensor(x)
     return torch.arcsin(x)
 
 
+def arcsinh(x):
+    x = convert_to_tensor(x)
+    return torch.arcsinh(x)
+
+
 def arctan(x):
     x = convert_to_tensor(x)
     return torch.arctan(x)
@@ -188,6 +198,11 @@ def arctan2(x1, x2):
     return torch.arctan2(x1, x2)
 
 
+def arctanh(x):
+    x = convert_to_tensor(x)
+    return torch.arctanh(x)
+
+
 def argmax(x, axis=None):
     x = convert_to_tensor(x)
     return torch.argmax(x, dim=axis)
@@ -277,6 +292,11 @@ def cos(x):
     return torch.cos(x)
 
 
+def cosh(x):
+    x = convert_to_tensor(x)
+    return torch.cosh(x)
+
+
 def count_nonzero(x, axis=None):
     x = convert_to_tensor(x)
     if axis == () or axis == []:
@@ -729,6 +749,11 @@ def sin(x):
     return torch.sin(x)
 
 
+def sinh(x):
+    x = convert_to_tensor(x)
+    return torch.sinh(x)
+
+
 def size(x):
     x_shape = convert_to_tensor(tuple(x.shape))
     return torch.prod(x_shape)
@@ -806,6 +831,11 @@ def tan(x):
     return torch.tan(x)
 
 
+def tanh(x):
+    x = convert_to_tensor(x)
+    return torch.tanh(x)
+
+
 def tensordot(x1, x2, axes=2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     # Conversion to long necessary for `torch.tensordot`
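Unlike the other backends, each torch wrapper converts its input first: `torch.sinh` and friends require a `Tensor` and reject plain lists or scalars, whereas `jnp`, `np`, and `tfnp` accept array-likes. A standalone sketch of the same pattern in plain PyTorch, using `torch.as_tensor` in place of the backend's own `convert_to_tensor` helper:

import torch

def sinh(x):
    # torch math ops require a Tensor, so convert array-likes first
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x)
    return torch.sinh(x)

print(sinh([0.0, 1.0]))  # tensor([0.0000, 1.1752])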
15 changes: 0 additions & 15 deletions keras_core/ops/nn.py
@@ -57,21 +57,6 @@ def sigmoid(x):
     return backend.nn.sigmoid(x)
 
 
-class Tanh(Operation):
-    def call(self, x):
-        return backend.nn.tanh(x)
-
-    def compute_output_spec(self, x):
-        return KerasTensor(x.shape, dtype=x.dtype)
-
-
-@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.nn.tanh"])
-def tanh(x):
-    if any_symbolic_tensors((x,)):
-        return Tanh().symbolic_call(x)
-    return backend.nn.tanh(x)
-
-
 class Softplus(Operation):
     def call(self, x):
         return backend.nn.softplus(x)
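Removing the `Tanh` operation drops the duplicate `keras_core.ops.nn.tanh` export; `tanh` survives as a numpy-style op, since its backend implementations are added above and the corresponding public export presumably lives in a file not rendered here. A sketch of the expected behavior after this commit; the failure of the second call is an assumption based on the removed `@keras_core_export` line:

from keras_core import ops

y = ops.tanh([0.0, 1.0])  # still available as a numpy-style op

# ops.nn.tanh([0.0, 1.0])  # removed by this commit; expected to fail now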
[The remaining 2 changed files did not render in this view.]
