-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GNNFlux] Translate Traffic prediction
Pluto notebook to Literate
#572
base: master
Are you sure you want to change the base?
Conversation
|
||
# ## Dataset: METR-LA | ||
|
||
# We use the `METR-LA` dataset from the paper [Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926.pdf), which contains traffic data from loop detectors in the highway of Los Angeles County. The dataset contains traffic speed data from March 1, 2012 to June 30, 2012. The data is collected every 5 minutes, resulting in 12 observations per hour, from 207 sensors. Each sensor is a node in the graph, and the edges represent the distances between the sensors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does it mean the edges represent the distances between the sensors
? should be clarified
train_loader = zip(features[1:200], targets[1:200]); | ||
test_loader = zip(features[2001:2288], targets[2001:2288]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
motivate this choice of ranges
for epoch in 1:100 | ||
for (x, y) in train_loader | ||
x, y = (x, y) | ||
grads = Flux.gradient(model) do model | ||
ŷ = model(graph, x) | ||
Flux.mae(ŷ, y) | ||
end | ||
Flux.update!(opt, model, grads[1]) | ||
end | ||
|
||
if epoch % 10 == 0 | ||
loss = mean([Flux.mae(model(graph,x), y) for (x, y) in train_loader]) | ||
@show epoch, loss | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
identation
model = GNNChain(TGCN(2 => 100; add_self_loops = false), Dense(100, 1)) | ||
|
||
# ![](https://www.researchgate.net/profile/Haifeng-Li-3/publication/335353434/figure/fig4/AS:851870352437249@1580113127759/The-architecture-of-the-Gated-Recurrent-Unit-model.jpg) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here it would be useful to show the output of the model and how it is interpreted as a prediction
#545