-
Notifications
You must be signed in to change notification settings - Fork 278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update and fix gnn model factory and models #2177
Conversation
For eval we don't need a staged eval test, because inference test does not include a backward pass. Also, curious what is the error message when running with the FLOPCounterMode |
Yes agreed, I was just thinking of a multi batch evaluation just to have some results but indeed it doesn't make that much sense.
It was a runtime error: Full output for gat below, it's very similar for gnn:
|
The |
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
This PR deals with a variety of issues around gnn canary models. Currently the models:
are failing directly from the installation step as they are missing the required data file,
sub_reddit.pt
.The PR checks out the data file
Reddit_minimal.tar.gz
from S3 for all these models. It also updates the requirements and installation files, for example importingpyg_lib
, since running the models without it causesNeighborSampler
to throw a deprecation warning.Lastly, this PR focuses on the updating of the gnn model factory to be more in line with both
model.py
and _invoke_staged_train_test() as it is a multi batch model. This means that it needed aforward()
,backward()
,optimizer_step()
andget_input_iter()
function. This would also make it more in line with other model factories such as the vision one.These changes allow the models to be trained with
run.py
:NOTE: Gat and Gcn cannot collect model_flops metrics because there is a bug when running these models with the FlopCounterMode context manager (here).
NOTE 2: eval is not supported yet as there is no
_invoke_staged_eval_test
() function yet, but this would be a good idea to implement for completion.