Merge pull request #10 from eliorc/dev
Major changes, better interfaces and corrections to masking mechanisms
eliorc authored Dec 13, 2019
2 parents a499d13 + 4ebb9a6 commit faa3fc9
Showing 11 changed files with 238 additions and 245 deletions.
4 changes: 0 additions & 4 deletions .circleci/config.yml
@@ -56,10 +56,6 @@ jobs:
source venv/bin/activate
pytest -s --cov=tavolo tests/
codecov
test-3.8:
<<: *test-template
docker:
- image: circleci/python:3.8
test-3.6:
<<: *test-template
docker:
9 changes: 4 additions & 5 deletions README.rst
@@ -6,7 +6,7 @@

------------

.. image:: https://img.shields.io/badge/python-3.5%20%7C%203.6%20%7C%203.7%20%7C%203.8-blue.svg
.. image:: https://img.shields.io/badge/python-3.5%20%7C%203.6%20%7C%203.7-blue.svg
:alt: Supported Python versions

.. image:: https://img.shields.io/badge/tensorflow-2.0-orange.svg
@@ -42,7 +42,7 @@ Showcase
| tavolo's API is straightforward and adopting its modules is as easy as it gets.
| In tavolo, you'll find implementations for basic layers like `PositionalEncoding`_ to complex modules like the Transformer's
`MultiHeadedAttention`_. You'll also find non-layer implementations that can ease development, like the `LearningRateFinder`_.
| For example, if we wanted to add head a multi-headed attention mechanism into our model and look for the optimal learning rate, it would look something like:
| For example, if we wanted to add a Yang-style attention mechanism into our model and look for the optimal learning rate, it would look something like:
.. code-block:: python3
@@ -51,12 +51,11 @@ Showcase
model = tf.keras.Sequential([
tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_size, input_length=max_len),
tvl.seq2seq.MultiHeadedAttention(n_heads=8), # <--- Add self attention
tf.keras.layers.LSTM(n_lstm_units, return_sequences=True),
tvl.seq2vec.YangAttention(n_units=64), # <--- Add Yang style attention
tf.keras.layers.Dense(n_hidden_units, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.CategoricalCrossentropy())
model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.BinaryCrossentropy())
# Run learning rate range test
lr_finder = tvl.learning.LearningRateFinder(model=model)
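The showcase snippet in the README diff is truncated by the collapsed view above. A minimal sketch of how the learning rate range test might continue, assuming ``LearningRateFinder`` exposes a ``scan`` method that returns the sampled learning rates and losses; the exact signature and the ``train_data``/``train_labels`` placeholders are assumptions, not part of this diff:

.. code-block:: python3

    # Run the range test over the training data (hypothetical placeholders)
    learning_rates, losses = lr_finder.scan(train_data, train_labels,
                                            min_lr=0.0001, max_lr=1.0,
                                            batch_size=128)

    # Plot loss against learning rate and pick a value just before the loss diverges
    import matplotlib.pyplot as plt

    plt.plot(learning_rates, losses)
    plt.xscale('log')
    plt.xlabel('Learning rate')
    plt.ylabel('Loss')
    plt.show()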
7 changes: 3 additions & 4 deletions docs/source/index.rst
@@ -16,7 +16,7 @@ Showcase
| tavolo's API is straightforward and adopting its modules is as easy as it gets.
| In tavolo, you'll find implementations for basic layers like :ref:`positional_encoding` to complex modules like the Transformer's
:ref:`multi_headed_attention`. You'll also find non-layer implementations that can ease development, like the :ref:`learning_rate_finder`.
| For example, if we wanted to add head a multi-headed attention mechanism into our model and look for the optimal learning rate, it would look something like:
| For example, if we wanted to add a Yang-style attention mechanism into our model and look for the optimal learning rate, it would look something like:
.. code-block:: python3
@@ -25,12 +25,11 @@ Showcase
model = tf.keras.Sequential([
tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_size, input_length=max_len),
tvl.seq2seq.MultiHeadedAttention(n_heads=8), # <--- Add self attention
tf.keras.layers.LSTM(n_lstm_units, return_sequences=True),
tvl.seq2vec.YangAttention(n_units=64), # <--- Add Yang style attention
tf.keras.layers.Dense(n_hidden_units, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.CategoricalCrossentropy())
model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.BinaryCrossentropy())
# Run learning rate range test
lr_finder = tvl.learning.LearningRateFinder(model=model)
2 changes: 1 addition & 1 deletion setup.py
@@ -1,6 +1,6 @@
from setuptools import setup

VERSION = '0.5.1'
VERSION = '0.6.0'

setup(name='tavolo',
version=VERSION,
2 changes: 1 addition & 1 deletion tavolo/__init__.py
@@ -1,5 +1,5 @@
__name__ = 'tavolo'
__version__ = '0.5.1'
__version__ = '0.6.0'

from . import embeddings
from . import seq2vec
148 changes: 91 additions & 57 deletions tavolo/embeddings.py
@@ -2,6 +2,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import math_ops


class PositionalEncoding(tf.keras.layers.Layer):
@@ -67,6 +68,7 @@ def __init__(self,
name: str = 'positional_encoding',
**kwargs):
"""
:param max_sequence_length: Maximum sequence length of input
:param embedding_dim: Dimensionality of the input's last dimension
:param normalize_factor: Normalize factor
@@ -104,12 +106,13 @@ def compute_mask(self, inputs, mask=None):
def call(self, inputs,
mask: Optional[tf.Tensor] = None,
**kwargs) -> tf.Tensor:
output = inputs + self.positional_encoding
output = inputs + self.positional_encoding # shape=(batch_size, time_steps, channels)

if mask is not None:
output = tf.where(tf.tile(tf.expand_dims(mask, axis=-1), multiples=[1, 1, inputs.shape[-1]]), output,
inputs)
inputs) # shape=(batch_size, time_steps, channels)

return output # shape=(batch_size, time_steps, channels)
return output

def get_config(self):
base_config = super().get_config()
@@ -132,8 +135,10 @@ class DynamicMetaEmbedding(tf.keras.layers.Layer):
Arguments
---------
- `embedding_matrices` (``List[tf.keras.layers.Embedding]``): List of embedding layers
- `embedding_matrices` (``List[np.ndarray]``): List of embedding matrices
- `output_dim` (``int``): Dimension of the output embedding
- `mask_zero` (``bool``): Whether or not the input value 0 is a special "padding" value that should be masked out
- `input_length` (``Optional[int]``): Parameter to be passed into internal ``tf.keras.layers.Embedding`` matrices
- `name` (``str``): Layer name
@@ -160,27 +165,22 @@ class DynamicMetaEmbedding(tf.keras.layers.Layer):
import tensorflow as tf
import tavolo as tvl
w2v_embedding = tf.keras.layers.Embedding(num_words,
EMBEDDING_DIM,
embeddings_initializer=tf.keras.initializers.Constant(w2v_matrix),
input_length=MAX_SEQUENCE_LENGTH,
trainable=False)
w2v_embedding = np.array(...) # Pre-trained embedding matrix
glove_embedding = tf.keras.layers.Embedding(num_words,
EMBEDDING_DIM,
embeddings_initializer=tf.keras.initializers.Constant(glove_matrix),
input_length=MAX_SEQUENCE_LENGTH,
trainable=False)
glove_embedding = np.array(...) # Pre-trained embedding matrix
model = tf.keras.Sequential([tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32'),
tvl.embeddings.DynamicMetaEmbedding([w2v_embedding, glove_embedding])]) # Use DME embeddings
tvl.embeddings.DynamicMetaEmbedding([w2v_embedding, glove_embedding],
input_length=MAX_SEQUENCE_LENGTH)]) # Use DME embeddings
Using the same example as above, it is possible to define the output's channel size
.. code-block:: python3
model = tf.keras.Sequential([tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32'),
tvl.embeddings.DynamicMetaEmbedding([w2v_embedding, glove_embedding], output_dim=200)])
tvl.embeddings.DynamicMetaEmbedding([w2v_embedding, glove_embedding],
input_length=MAX_SEQUENCE_LENGTH,
output_dim=200)])
References
@@ -193,25 +193,41 @@ class DynamicMetaEmbedding(tf.keras.layers.Layer):
"""

def __init__(self,
embedding_matrices: List[tf.keras.layers.Embedding],
embedding_matrices: List[np.ndarray],
output_dim: Optional[int] = None,
mask_zero: bool = False,
input_length: Optional[int] = None,
name: str = 'dynamic_meta_embedding',
**kwargs):
"""
:param embedding_matrices: List of embedding layers
:param embedding_matrices: List of embedding matrices
:param output_dim: Dimension of the output embedding
:param mask_zero: Whether or not the input value 0 is a special "padding" value that should be masked out
:param input_length: Parameter to be passed into internal ``tf.keras.layers.Embedding`` matrices
:param name: Layer name
"""
super().__init__(name=name, **kwargs)

self.mask_zero = mask_zero
self.input_length = input_length
self.base_matrices_shapes = [e.shape for e in embedding_matrices]

# Validate all the embedding matrices have the same vocabulary size
if not len(set((e.input_dim for e in embedding_matrices))) == 1:
if not len(set((e.shape[0] for e in embedding_matrices))) == 1:
raise ValueError('Vocabulary sizes (first dimension) of all embedding matrices must match')
if not set((e.ndim for e in embedding_matrices)) == {2}:
raise ValueError('All embedding matrices should have only 2 dimensions')

# If no output_dim is supplied, use the maximum dimension from the given matrices
self.output_dim = output_dim or min([e.output_dim for e in embedding_matrices])

self.embedding_matrices = embedding_matrices
self.output_dim = output_dim or min([e.shape[1] for e in embedding_matrices])

self.embedding_matrices = [tf.keras.layers.Embedding(input_dim=e.shape[0],
output_dim=e.shape[1],
embeddings_initializer=tf.keras.initializers.Constant(e),
input_length=input_length,
name='embedding_matrix_{}'.format(i))
for i, e in enumerate(embedding_matrices)]
self.n_embeddings = len(self.embedding_matrices)

self.projections = [tf.keras.layers.Dense(units=self.output_dim,
Expand All @@ -225,11 +241,14 @@ def __init__(self,
dtype=self.dtype)

def compute_mask(self, inputs, mask=None):
return self.projections[0].compute_mask(
inputs, mask=self.embedding_matrices[0].compute_mask(inputs, mask=mask))
if not self.mask_zero:
return None

return math_ops.not_equal(inputs, 0)

def call(self, inputs,
**kwargs) -> tf.Tensor:

batch_size, time_steps = inputs.shape[:2]

# Embedding lookup
@@ -254,16 +273,17 @@ def call(self, inputs,

def get_config(self):
base_config = super().get_config()
base_config['embedding_matrices'] = [e.get_config() for e in self.embedding_matrices]
base_config['base_matrices_shapes'] = self.base_matrices_shapes
base_config['output_dim'] = self.output_dim
base_config['mask_zero'] = self.mask_zero
base_config['input_length'] = self.input_length

return base_config

@classmethod
def from_config(cls, config: dict):
embedding_matrices = [tf.keras.layers.Embedding.from_config(e_conf) for e_conf in
config.pop('embedding_matrices')]
return cls(embedding_matrices=embedding_matrices, **config)
initial_matrices = [np.zeros(shape=s) for s in config.pop('base_matrices_shapes')]
return cls(embedding_matrices=initial_matrices, **config)
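The hunk above is the core of the interface change: ``DynamicMetaEmbedding`` now takes plain NumPy matrices instead of ``tf.keras.layers.Embedding`` layers, builds the internal embedding layers itself, and handles masking through the new ``mask_zero`` flag in ``compute_mask``. A minimal sketch of the new usage, with illustrative placeholder matrices and sizes that are not part of this diff:

.. code-block:: python3

    import numpy as np
    import tensorflow as tf
    import tavolo as tvl

    VOCAB_SIZE, MAX_SEQUENCE_LENGTH = 10000, 50

    # Stand-ins for pre-trained embedding matrices sharing a vocabulary (index 0 = padding)
    w2v_matrix = np.random.normal(size=(VOCAB_SIZE, 300)).astype('float32')
    glove_matrix = np.random.normal(size=(VOCAB_SIZE, 100)).astype('float32')

    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32'),
        tvl.embeddings.DynamicMetaEmbedding([w2v_matrix, glove_matrix],
                                            input_length=MAX_SEQUENCE_LENGTH,
                                            mask_zero=True),  # mask index 0 as padding
        tf.keras.layers.LSTM(64)])  # downstream layer can consume the propagated mask

With no ``output_dim`` given, the output dimension falls back to the smallest matrix dimension (100 here), as in the constructor shown above.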


class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
@@ -277,8 +297,10 @@ class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
Arguments
---------
- `embedding_matrices` (``List[tf.keras.layers.Embedding]``): List of embedding layers
- `embedding_matrices` (``List[np.ndarray]``): List of embedding matrices
- `output_dim` (``int``): Dimension of the output embedding
- `mask_zero` (``bool``): Whether or not the input value 0 is a special "padding" value that should be masked out
- `input_length` (``Optional[int]``): Parameter to be passed into internal ``tf.keras.layers.Embedding`` matrices
- `n_lstm_units` (``int``): Number of units in each LSTM, (notated as `m` in the original article)
- `name` (``str``): Layer name
@@ -306,27 +328,22 @@ class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
import tensorflow as tf
import tavolo as tvl
w2v_embedding = tf.keras.layers.Embedding(num_words,
EMBEDDING_DIM,
embeddings_initializer=tf.keras.initializers.Constant(w2v_matrix),
input_length=MAX_SEQUENCE_LENGTH,
trainable=False)
w2v_embedding = np.array(...) # Pre-trained embedding matrix
glove_embedding = tf.keras.layers.Embedding(num_words,
EMBEDDING_DIM,
embeddings_initializer=tf.keras.initializers.Constant(glove_matrix),
input_length=MAX_SEQUENCE_LENGTH,
trainable=False)
glove_embedding = np.array(...) # Pre-trained embedding matrix
model = tf.keras.Sequential([tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32'),
tvl.embeddings.DynamicMetaEmbedding([w2v_embedding, glove_embedding])]) # Use CDME embeddings
tvl.embeddings.ContextualDynamicMetaEmbedding([w2v_embedding, glove_embedding],
input_length=MAX_SEQUENCE_LENGTH)]) # Use CDME embeddings
Using the same example as above, it is possible to define the output's channel size and number of units in each LSTM
.. code-block:: python3
model = tf.keras.Sequential([tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32'),
tvl.embeddings.DynamicMetaEmbedding([w2v_embedding, glove_embedding], n_lstm_units=128, output_dim=200)])
tvl.embeddings.ContextualDynamicMetaEmbedding([w2v_embedding, glove_embedding],
input_length=MAX_SEQUENCE_LENGTH,
n_lstm_units=128, output_dim=200)])
References
----------
@@ -335,33 +352,47 @@ class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
.. _`Dynamic Meta-Embeddings for Improved Sentence Representations`:
https://arxiv.org/abs/1804.07983
add """
"""

def __init__(self,
embedding_matrices: List[tf.keras.layers.Embedding],
embedding_matrices: List[np.ndarray],
output_dim: Optional[int] = None,
mask_zero: bool = False,
input_length: Optional[int] = None,
n_lstm_units: int = 2,
name: str = 'contextual_dynamic_meta_embedding',
**kwargs):
"""
:param embedding_matrices: List of embedding layers
:param n_lstm_units: Number of units in each LSTM, (notated as `m` in the original article)
:param embedding_matrices: List of embedding matrices
:param output_dim: Dimension of the output embedding
:param mask_zero: Whether or not the input value 0 is a special "padding" value that should be masked out
:param input_length: Parameter to be passed into internal ``tf.keras.layers.Embedding`` matrices
:param n_lstm_units: Number of units in each LSTM, (notated as `m` in the original article)
:param name: Layer name
"""

super().__init__(name=name, **kwargs)

self.mask_zero = mask_zero
self.input_length = input_length
self.n_lstm_units = n_lstm_units
self.base_matrices_shapes = [e.shape for e in embedding_matrices]

# Validate all the embedding matrices have the same vocabulary size
if not len(set((e.input_dim for e in embedding_matrices))) == 1:
if not len(set((e.shape[0] for e in embedding_matrices))) == 1:
raise ValueError('Vocabulary sizes (first dimension) of all embedding matrices must match')
if not set((e.ndim for e in embedding_matrices)) == {2}:
raise ValueError('All embedding matrices should have only 2 dimensions')

# If no output_dim is supplied, use the maximum dimension from the given matrices
self.output_dim = output_dim or min([e.output_dim for e in embedding_matrices])

self.n_lstm_units = n_lstm_units

self.embedding_matrices = embedding_matrices
self.output_dim = output_dim or min([e.shape[1] for e in embedding_matrices])

self.embedding_matrices = [tf.keras.layers.Embedding(input_dim=e.shape[0],
output_dim=e.shape[1],
embeddings_initializer=tf.keras.initializers.Constant(e),
input_length=input_length,
name='embedding_matrix_{}'.format(i))
for i, e in enumerate(embedding_matrices)]
self.n_embeddings = len(self.embedding_matrices)

self.projections = [tf.keras.layers.Dense(units=self.output_dim,
@@ -380,8 +411,10 @@ def __init__(self,
dtype=self.dtype)

def compute_mask(self, inputs, mask=None):
return self.projections[0].compute_mask(
inputs, mask=self.embedding_matrices[0].compute_mask(inputs, mask=mask))
if not self.mask_zero:
return None

return math_ops.not_equal(inputs, 0)

def call(self, inputs,
**kwargs) -> tf.Tensor:
@@ -416,14 +449,15 @@ def call(self, inputs,

def get_config(self):
base_config = super().get_config()
base_config['embedding_matrices'] = [e.get_config() for e in self.embedding_matrices]
base_config['base_matrices_shapes'] = self.base_matrices_shapes
base_config['output_dim'] = self.output_dim
base_config['mask_zero'] = self.mask_zero
base_config['input_length'] = self.input_length
base_config['n_lstm_units'] = self.n_lstm_units

return base_config

@classmethod
def from_config(cls, config: dict):
embedding_matrices = [tf.keras.layers.Embedding.from_config(e_conf) for e_conf in
config.pop('embedding_matrices')]
return cls(embedding_matrices=embedding_matrices, **config)
initial_matrices = [np.zeros(shape=s) for s in config.pop('base_matrices_shapes')]
return cls(embedding_matrices=initial_matrices, **config)
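One consequence of the new ``get_config``/``from_config`` pair is that only ``base_matrices_shapes`` is serialized, not the embedding weights; ``from_config`` rebuilds the layer with zero matrices of those shapes. A short sketch of what a config round-trip then looks like, reusing the illustrative placeholder matrices from the sketch above:

.. code-block:: python3

    layer = tvl.embeddings.ContextualDynamicMetaEmbedding([w2v_matrix, glove_matrix],
                                                           input_length=MAX_SEQUENCE_LENGTH,
                                                           n_lstm_units=128,
                                                           mask_zero=True)

    config = layer.get_config()
    restored = tvl.embeddings.ContextualDynamicMetaEmbedding.from_config(config)

    # `restored` has embedding matrices of the same shapes, but zero-initialized;
    # the pre-trained weights must be restored separately (e.g. via
    # model.load_weights or layer.set_weights).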
