Skip to content

Commit

Permalink
Ray Summit AI library tiny fixes (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
marwan116 authored Sep 28, 2024
1 parent 7e67d65 commit b82d650
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions templates/ray-summit-ai-libraries/2_Intro_Train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
"def data_loader_torch(batch_size: int) -> DataLoader:\n",
" transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])\n",
" train_data = MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
" train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
" train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)\n",
" return train_loader"
]
},
Expand Down Expand Up @@ -351,7 +351,7 @@
"def data_loader_ray_train(batch_size: int) -> DataLoader:\n",
" transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])\n",
" train_data = MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
" train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
" train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)\n",
"\n",
" # Add DistributedSampler to the DataLoader\n",
" train_loader = ray.train.torch.prepare_data_loader(train_loader)\n",
Expand Down Expand Up @@ -638,8 +638,9 @@
"Use the following code snippets to guide you:\n",
"\n",
"```python\n",
"def print_metrics_ray_train(metrics):\n",
" # Hint: Update the print statement to include AUROC\n",
"# Hint: Update the print function to include AUROC\n",
"def print_metrics_ray_train(...):\n",
" ...\n",
"\n",
"def train_loop_ray_train(config):\n",
" # Hint: Update the training loop to compute AUROC\n",
Expand Down
Loading

0 comments on commit b82d650

Please sign in to comment.