docs: update all docs
ZhaoyanLU23 committed Jan 12, 2024
1 parent be95424 commit 2108a50
Showing 10 changed files with 428 additions and 33 deletions.
18 changes: 0 additions & 18 deletions A/README.md
@@ -5,21 +5,3 @@
Binary classification task (using the PneumoniaMNIST dataset). The objective
is to classify an image as "Normal" (no pneumonia) or "Pneumonia"
(presence of pneumonia).

## TODOs:

* [x] Design model for task A
* [x] Choose a model: [XGBoost](https://github.com/dmlc/xgboost)
* [x] Understand [boosted trees](https://xgboost.readthedocs.io/en/stable/tutorials/model.html)
* [x] Try a demo of XGBoost
* [x] Report training, validation, and testing errors / accuracies, along with a description of any hyper-parameter tuning process.
* [ ] [Cross Validation Reference](https://scikit-learn.org/dev/modules/cross_validation.html)
* [x] Add tools to [estimate models](https://xgboost.readthedocs.io/en/stable/python/sklearn_estimator.html)
* [x] [Get the best model and all metrics](https://xgboost.readthedocs.io/en/stable/python/examples/sklearn_examples.html#sphx-glr-python-examples-sklearn-examples-py) using [cross validation](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation)
* [x] Try [early stopping](https://xgboost.readthedocs.io/en/stable/python/sklearn_estimator.html#early-stopping)
* [x] Save and load [models](https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html)
* [ ] Plot all figures needed

* The assessment will predominantly concentrate on how you articulate the choice of models, how
you develop/train/validate these models, and how you report/discuss/analyse the
results.
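
For context on the boxes ticked above (early stopping, saving the model, reading back the evaluation history), here is a minimal sketch of how these pieces fit together with the XGBoost scikit-learn API. It uses random stand-in data, not the actual PneumoniaMNIST pipeline, and is illustrative only:

```python
# Illustrative sketch only: random stand-in data, not the repo's actual
# PneumoniaMNIST pipeline.
import numpy as np
from xgboost import XGBClassifier

rng = np.random.default_rng(0)
X_train, y_train = rng.random((200, 784)), rng.integers(0, 2, 200)
X_val, y_val = rng.random((50, 784)), rng.integers(0, 2, 50)

clf = XGBClassifier(
    n_estimators=100,
    eval_metric="logloss",
    early_stopping_rounds=3,  # stop once the validation metric stops improving
)
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)

print(clf.evals_result())                      # per-round validation metrics
clf.save_model("task_a_training_model.json")   # saved in the JSON model format

clf_loaded = XGBClassifier()
clf_loaded.load_model("task_a_training_model.json")
```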
2 changes: 0 additions & 2 deletions A/solution.py
@@ -2,12 +2,10 @@
# -*- coding: utf-8 -*-
import os
import sys
from typing import List

# hack here
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from utils.logger import logger
from utils.solution import Solution
from constants import TASK_A_DIR

5 changes: 0 additions & 5 deletions B/README.md
@@ -4,8 +4,3 @@

Multi-class classification task (using the PathMNIST dataset): The objective is
to classify an image into one of 9 different types of tissues.

## TODOs:

* [x] Design model for task B
* [x] Report training, validation, and testing errors / accuracies, along with a description of any hyper-parameter tuning process.
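
As above, an illustrative sketch (random stand-in data, not PathMNIST) of what reporting training/validation/testing accuracies for the multi-class task could look like with the XGBoost scikit-learn wrapper:

```python
# Illustrative only: 9 random classes stand in for the PathMNIST tissue types.
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from xgboost import XGBClassifier

rng = np.random.default_rng(0)
splits = {
    name: (rng.random((n, 3 * 28 * 28)), rng.integers(0, 9, n))
    for name, n in [("train", 300), ("val", 60), ("test", 60)]
}

clf = XGBClassifier(n_estimators=50)  # multi-class objective is inferred from the labels
clf.fit(*splits["train"])

for name, (X, y) in splits.items():
    print(f"{name} accuracy: {accuracy_score(y, clf.predict(X)):.3f}")

print(classification_report(splits["test"][1], clf.predict(splits["test"][0])))
```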
1 change: 0 additions & 1 deletion B/solution.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import os
import sys
from typing import List

# hack here
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
23 changes: 16 additions & 7 deletions README.md
@@ -15,32 +15,40 @@

## Overview

* [ ] TODO: Add repo description here.
This repo is for ELEC0134: Applied Machine Learning Systems 23/24.

## Repo Structure
It uses [XGBoost](https://github.com/dmlc/xgboost) to accomplish the binary classification task and the multi-class classification task.

* [ ] TODO: Explain the role of each file.
## Repo Structure

```bash
# TODO: update repo structure
$ tree
.
├── A
│   ├── README.md
│   ├── config.json
│   ├── solution.py # Solution for task A
│   ├── task_a_early_stopping_rounds_3_evals_result.json
│   ├── task_a_early_stopping_rounds_None_evals_result.json
│   └── task_a_training_model.json # Model saved
├── B
│   ├── README.md
│   ├── config.json
│   ├── solution.py # Solution for task B
│   ├── task_b_early_stopping_rounds_3_evals_result.json
│   ├── task_b_early_stopping_rounds_None_evals_result.json
│   └── task_b_training_model.json # Model saved
├── Datasets # An empty dir for Datasets
├── Makefile
├── README.md
├── constants.py # global constants
├── environment.yml # conda environment config
├── images
│   ├── flowchart.png # draw.io flowchart
│   ├── task_a_learning_curve.png
│   └── task_b_learning_curve.png
├── main.py # Entrypoint of this repo
├── plot.ipynb # Plots
├── requirements.txt # python package requirements
└── utils
├── dataset.py # Dataset Loader
@@ -75,10 +83,9 @@ Datasets

## Requirements

* [ ] TODO: Add assignment requirements.
* `cuda 11.8`: If you want to use `py-xgboost-gpu`, a GPU and `cuda` are required

```bash
# TODO: update requirements.txt
$ cat requirements.txt
# format python files
black
@@ -91,9 +98,11 @@ scikit-learn
py-xgboost-gpu
# save results
pandas
# plot
matplotlib
```

**NOTE**: `py-xgboost-gpu` will use GPU. If you want to use CPU only, use [`py-xgboost-cpu`](https://xgboost.readthedocs.io/en/stable/install.html#conda) and set the device using `--device cpu`.
**NOTE**: `py-xgboost-gpu` will use GPU. If you want to use CPU instead, use [`py-xgboost-cpu`](https://xgboost.readthedocs.io/en/stable/install.html#conda) and set the device using `--device cpu`.

## Usage

1 change: 1 addition & 0 deletions environment.yml
@@ -8,4 +8,5 @@ dependencies:
- numpy
- py-xgboost-gpu
- pandas
- matplotlib
prefix: /home/uceezl8/.conda/envs/amls-final-zhaoyanlu
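
Both dependency files above pull in `py-xgboost-gpu`, which trains on the GPU by default. A minimal sketch of how the `--device cpu` switch mentioned in the README note could be passed through to XGBoost 2.x's `device` parameter; this is an assumption about the wiring, and the actual handling in `main.py` may differ:

```python
# Sketch of a --device flag feeding XGBoost's `device` parameter
# (an assumption about the wiring, not the repo's exact code).
import argparse
from xgboost import XGBClassifier

parser = argparse.ArgumentParser()
parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda")
args = parser.parse_args()

# With py-xgboost-gpu installed, device="cuda" runs training on the GPU;
# with py-xgboost-cpu, only device="cpu" is available.
clf = XGBClassifier(tree_method="hist", device=args.device)
```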
401 changes: 401 additions & 0 deletions plot.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions requirements.txt
@@ -9,3 +9,5 @@ scikit-learn
py-xgboost-gpu
# save results
pandas
# plot
matplotlib
5 changes: 5 additions & 0 deletions utils/dataset.py
@@ -21,15 +21,19 @@ def __init__(self, path: str) -> None:
logger.info(
"We concatenate the training set and validation set together for a better cross validation."
)
# old_X_train is used for training
self.old_X_train = data.get("train_images", np.asarray([]))
self.old_X_train = self._reshape_to_2_dims(self.old_X_train)
X_val = data.get("val_images", np.asarray([]))
self.X_val = self._reshape_to_2_dims(X_val)
# X_train is used for cross validation
self.X_train = np.concatenate((self.old_X_train, self.X_val), axis=0)
logger.debug(f"X_train shape: {self.X_train.shape}")

# old_y_train is used for training
self.old_y_train = data.get("train_labels", np.asarray([]))
self.y_val = data.get("val_labels", np.asarray([]))
# y_train is used for cross validation
self.y_train = np.concatenate((self.old_y_train, self.y_val), axis=0)
logger.debug(f"y_train shape: {self.y_train.shape}")

@@ -42,6 +46,7 @@ def __init__(self, path: str) -> None:
logger.info(f"Dataset loaded: {path}")

def _reshape_to_2_dims(self, ndarray: np.ndarray) -> np.ndarray:
"""Reshape a ndarry to the shape (n_samples, n_features)."""
from functools import reduce
import operator

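
A minimal sketch of the two preprocessing steps shown above, using assumed image shapes rather than the real MedMNIST arrays: flatten each image to a feature vector, then concatenate the training and validation splits so that cross-validation can re-split them:

```python
# Assumed shapes, not the repo's exact code.
from functools import reduce
import operator

import numpy as np


def reshape_to_2_dims(arr: np.ndarray) -> np.ndarray:
    """Reshape an ndarray to the shape (n_samples, n_features)."""
    n_features = reduce(operator.mul, arr.shape[1:], 1)
    return arr.reshape(arr.shape[0], n_features)


train_images = np.zeros((100, 28, 28, 3))  # stand-in for (n_samples, H, W, C) images
val_images = np.zeros((20, 28, 28, 3))

X_train = np.concatenate(
    (reshape_to_2_dims(train_images), reshape_to_2_dims(val_images)), axis=0
)
print(X_train.shape)  # (120, 2352)
```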
3 changes: 3 additions & 0 deletions utils/solution.py
@@ -72,6 +72,7 @@ def solve(self, stages: List[str]):
logger.info(f"----------{self.task_name} finished!-----------")

def val(self):
"""Cross Validation"""
logger.info(f"[{self.task_name}] [Validation] Running on {self.device}...")
val_config: dict = self.config.get("validation", {})
param_grid: dict = val_config.get("param_grid", {})
@@ -111,6 +112,7 @@ def val(self):
df.to_csv(cv_result_path)

def train(self):
"""Training"""
logger.info(f"[{self.task_name}] [Training] Running on {self.device}...")
training_config: dict = self.config.get("training", {})
random_state: int = self.config.get("random_state", DEFAULT_RANDOM_STATE)
@@ -161,6 +163,7 @@ def train(self):
logger.info(f"classification report for training:\n{report}")

def test(self):
"""Testing"""
logger.info(f"[{self.task_name}] [Testing] Running on {self.device}...")

if not self.classifier:
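
To make the `val()` / `train()` / `test()` stages above concrete, here is a hedged sketch of the cross-validation step: a grid search over XGBoost hyper-parameters whose full results are written to CSV with pandas. The parameter values, data, and file name are assumptions, not the repo's `config.json` contents:

```python
# Hedged sketch of a grid-search cross-validation step; values are assumptions.
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier

rng = np.random.default_rng(0)
X_train, y_train = rng.random((200, 784)), rng.integers(0, 2, 200)

param_grid = {"max_depth": [3, 6], "n_estimators": [50, 100]}
search = GridSearchCV(XGBClassifier(), param_grid, cv=5, scoring="accuracy")
search.fit(X_train, y_train)

print(search.best_params_, search.best_score_)     # best model and its CV score
pd.DataFrame(search.cv_results_).to_csv("cv_results.csv")  # all metrics per candidate
```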
