Skip to content

Commit

Permalink
Merge pull request #13 from Rishit-dagli/unit-tests
Browse files Browse the repository at this point in the history
Add Unit Tests
  • Loading branch information
Rishit-dagli authored Jan 16, 2022
2 parents 26ee73b + e6a7ea0 commit 8c7504e
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 6 deletions.
1 change: 0 additions & 1 deletion .github/workflows/linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@ jobs:
VALIDATE_ALL_CODEBASE: true
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
VALIDATE_PYTHON_BLACK: true
VALIDATE_PYTHON_ISORT: true
7 changes: 5 additions & 2 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
python-version: '3.9'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
pip install build tensorflow==2.7.0 einops rotary-embedding-tensorflow
- name: Run Tests
run: |
python -m fast_transformer.test_fast_transformer
- name: Build package
run: python -m build
- name: Publish package
Expand Down
25 changes: 25 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Run Tests

on:
push:
pull_request:
branches: [master]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tensorflow==2.7.0 einops rotary-embedding-tensorflow
- name: Run Tests
run: |
python -m fast_transformer.test_fast_transformer
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ cd fast-transformer
pip install -e .[dev]
```

To run rank and shape tests run the following:

```py
python -m fast_transformer.test_fast_transformer
```

## Usage

```python
Expand Down
2 changes: 1 addition & 1 deletion fast_transformer/fast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, dim, mult=4):

self.net = tf.keras.Sequential(
[
tf.keras.layers.Dense(dim * mult, input_dim=dim),
tf.keras.layers.Dense(dim * mult),
tf.keras.layers.Activation(tf.nn.gelu),
tf.keras.layers.Dense(dim, input_dim=dim * mult),
]
Expand Down
30 changes: 30 additions & 0 deletions fast_transformer/test_fast_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import tensorflow as tf

from fast_transformer.fast_transformer import FastTransformer


class FastTransformerTest(tf.test.TestCase):
def setUp(self):
super(FastTransformerTest, self).setUp()

mask = tf.ones([1, 4096], dtype=tf.bool)
self.model = FastTransformer(
num_tokens=20000,
dim=512,
depth=2,
max_seq_len=4096,
absolute_pos_emb=True,
mask=mask,
)

def test_shape_and_rank(self):
inputs = tf.experimental.numpy.random.randint(0, 20000, (1, 4096))
outputs = self.model(inputs)

self.assertEqual(tf.rank(outputs), 3)
self.assertShapeEqual(np.zeros((1, 4096, 20000)), outputs)


if __name__ == "__main__":
tf.test.main()
2 changes: 1 addition & 1 deletion fast_transformer/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.2.0"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_version(rel_path: str) -> str:
author="Rishit Dagli",
author_email="rishit.dagli@gmail.com",
install_requires=[
"tensorflow ~= 2.5.0",
"tensorflow >= 2.5.0",
"einops ~= 0.3.0",
"rotary-embedding-tensorflow ~= 0.1.0",
],
Expand Down

0 comments on commit 8c7504e

Please sign in to comment.