Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alternative model serializations for inference #100

Merged
merged 47 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
ca13028
Add some pre-generated outputs for prediction tests
Sep 29, 2023
8d9daa9
Refactor inference so we can add more model types
Sep 29, 2023
dfb20ef
Add onnx serialized model
Sep 29, 2023
5ceb645
Add TFLite serialized model
Sep 29, 2023
5d575a3
Add CoreML serialized model
Sep 29, 2023
dea3f6b
Add comments to tox.ini
Sep 29, 2023
60f973b
Add tests for getting audio
Sep 29, 2023
aaf023a
Fix MANIFEST.in
Sep 29, 2023
38153ec
Unpin coremltools
Sep 29, 2023
11d19f7
Unpin tflite-runtime
Sep 29, 2023
8281657
Fix mypy
Sep 29, 2023
4ee59c7
Ignore unused imports on __init__.py
Sep 29, 2023
700364e
Remove unused imports
Sep 29, 2023
6cc3267
more formatting changes
Sep 29, 2023
f5e862d
use tf by default for py>=3.11
Sep 30, 2023
2227005
Windowing is now iterator based
Sep 30, 2023
21a41e8
Reintroduce model_or_model_path
Oct 1, 2023
767e818
Add commandline flag to specify preferred model serialization
Oct 1, 2023
ce91056
Explain how the multiple model serializations are used in this repo
Oct 1, 2023
3ce2617
Move from model-type to model-serialization
Oct 1, 2023
2ee881c
Remove unused code
Oct 1, 2023
6196cb7
Add build-system requirements for scipy
Oct 1, 2023
35c2994
tf-macos does not support py37
Oct 1, 2023
30c1df8
Move from setup.cfg to pyproject.toml
Oct 1, 2023
2ab4527
Reorganize tox
Oct 1, 2023
4f8292c
More informative error
Oct 1, 2023
4831812
TFLite should be enabled if TF is installed
Oct 2, 2023
5801056
Move to macos-13 for mlprogram support
Oct 2, 2023
863be61
Dont convert numpy array to list
Oct 2, 2023
fde6787
debug
Oct 2, 2023
71c74ca
Bump setup-python
Oct 3, 2023
08673da
reconvert the coreml model
Oct 3, 2023
989baf5
Enable faulthandler
Oct 3, 2023
f87fbc2
pin sklearn
Oct 3, 2023
d75956d
Remove unittest
Oct 3, 2023
a5ab739
Include tensortype input for model
Oct 3, 2023
1609d51
Force rerun of workflows
Nov 17, 2023
b6b8b57
Try an M1 mac
Feb 27, 2024
8ef7ebc
More specific python versions
Feb 27, 2024
c5a5e21
conditionals
Feb 28, 2024
99551a1
custom macos matrix
Feb 29, 2024
3238689
Install llvm libopenmp on mac
Feb 29, 2024
76a759c
unpin scikit learn
Feb 29, 2024
c8cede9
Take care of documentation comments
Mar 12, 2024
ede09c9
Format
Mar 12, 2024
ef94d1a
Pin tf<2.16 and resolve in a future pr
Mar 12, 2024
638da87
unpin black in check-formatting
Mar 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ dependencies = [
"numpy>=1.18",
"onnxruntime; platform_system == 'Windows' and python_version < '3.11'",
"pretty_midi>=0.2.9",
"resampy>=0.2.2",
"resampy>=0.2.2,<0.4.3",
"scikit-learn",
"scipy>=1.4.1",
"tensorflow>=2.4.1; platform_system != 'Darwin' and python_version >= '3.11'",
"tensorflow>=2.4.1,<2.16; platform_system != 'Darwin' and python_version >= '3.11'",
"tensorflow-macos>=2.4.1; platform_system == 'Darwin' and python_version >= '3.11'",
"tflite-runtime; platform_system == 'Linux' and python_version < '3.11'",
"typing_extensions",
Expand Down Expand Up @@ -56,8 +56,8 @@ test = [
"pytest-mock",
]
tf = [
"tensorflow>=2.4.1; platform_system != 'Darwin'",
"tensorflow-macos>=2.4.1; platform_system == 'Darwin' and python_version > '3.7'",
"tensorflow>=2.4.1,<2.16; platform_system != 'Darwin'",
"tensorflow-macos>=2.4.1,<2.16; platform_system == 'Darwin' and python_version > '3.7'",
]
coreml = ["coremltools"]
onnx = ["onnxruntime"]
Expand Down
Binary file modified tests/resources/vocadito_10/model_output.npz
Binary file not shown.
Binary file modified tests/resources/vocadito_10/note_events.npz
Binary file not shown.
5 changes: 2 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,11 @@ def test_predict() -> None:
assert all(note_pitch_max)
assert isinstance(note_events, list)

expected_model_output = np.load("tests/resources/vocadito_10/model_output.npz")
expected_model_output = np.load("tests/resources/vocadito_10/model_output.npz", allow_pickle=True)["arr_0"].item()
bgenchel marked this conversation as resolved.
Show resolved Hide resolved
for k in expected_model_output.keys():
np.testing.assert_allclose(expected_model_output[k], model_output[k], atol=1e-4, rtol=0)

expected_note_events = np.load("tests/resources/vocadito_10/note_events.npz", allow_pickle=True)
expected_note_events = expected_note_events.get("arr_0")
expected_note_events = np.load("tests/resources/vocadito_10/note_events.npz", allow_pickle=True)["arr_0"]

assert len(expected_note_events) == len(note_events)
for expected, calculated in zip(expected_note_events, note_events):
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ requires =
[testenv]
deps = -e .[test]
commands =
pytest tests --ignore tests/test_nn.py -s {posargs}
pytest tests --ignore tests/test_nn.py {posargs}
bgenchel marked this conversation as resolved.
Show resolved Hide resolved
setenv =
SOURCE = {toxinidir}/basic_pitch
TEST_SOURCE = {toxinidir}/tests
Expand Down
Loading