Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torchbench is now a library #1933

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ old.json
te.json
logs/
scripts/scribe.py
.userbenchmark/
.userbenchmark/
torchbench.egg-info/
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ cd benchmark
python install.py
```

### Install torchbench as a library

if you're interested in running torchbench as a library you can

```bash
python install.py
pip install git+https://www.github.com:pytorch/benchmark.git
```

or

```bash
python install.py
pip install . # add -e for an editable installation
```

The above

```python
import torchbenchmark.models.densenet121
model, example_inputs = torchbenchmark.models.densenet121.Model(test="eval", device="cuda", batch_size=1).get_module()
model(*example_inputs)
```

### Building From Source
Note that when building PyTorch from source, torchvision and torchaudio must also be built from source to make sure the C APIs match.

Expand Down
16 changes: 16 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from setuptools import setup, find_packages

setup(
name='torchbench',
version='0.1',
description='Benchmarking library for PyTorch',
author='PyTorch Team',
url='https://github.com/pytorch/benchmark',
packages=find_packages(include=['torchbenchmark*', 'userbenchmark*']),
classifiers=[
'Intended Audience :: Developers',
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: BSD 3 License',
'Programming Language :: Python',
],
)
8 changes: 8 additions & 0 deletions test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torchbenchmark.models.densenet121
model, example_inputs = torchbenchmark.models.densenet121.Model(test="eval", device="cuda", batch_size=1).get_module()
model(*example_inputs)

import userbenchmark.optim
import torchbenchmark.canary_models
import torchbenchmark.models
import torchbenchmark.score
4 changes: 4 additions & 0 deletions torchbenchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

from components._impl.tasks import base as base_task
from components._impl.workers import subprocess_worker
from . import models
from . import canary_models
from . import e2e_models
from . import util

class ModelNotFoundError(RuntimeError):
pass
Expand Down
Empty file.
Empty file.
Empty file.
Empty file.