Generic deployable XGBoost predictor

Description

This small toy project exposes a simple yet effective model through a FastAPI endpoint and deploys the model via Docker.

The implemented models are listed in an Enum class under app/utils/utils_classifier.py and can be extended with the corresponding changes in app/model.py.
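
For illustration, a minimal sketch of what such an Enum could look like (the class and member names here are hypothetical; the actual definition lives in app/utils/utils_classifier.py):

from enum import Enum

class ClassifierType(Enum):
    # Each member identifies one of the supported model types
    XGBOOST = "xgboost"
    SVM = "svm"
    MLP = "mlp"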

Implementation details

Data

The data can be found on Kaggle and retrieved through the Kaggle API:

cd data/raw && kaggle datasets download -d purumalgi/music-genre-classification

Model(s)

The main model is an XGBoost classifier with the DART booster. A multiclass SVM and an MLP classifier are also implemented and briefly tested.
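
As a rough, self-contained sketch of how a DART-boosted XGBoost classifier can be configured (the hyperparameters and toy data below are illustrative, not the project's actual settings):

import numpy as np
from xgboost import XGBClassifier

# Toy data just to make the snippet runnable; the project uses the Kaggle music-genre features
X = np.random.rand(100, 10)
y = np.random.randint(0, 3, size=100)

# DART randomly drops trees during boosting, which can reduce overfitting
model = XGBClassifier(
    booster="dart",
    objective="multi:softprob",
    n_estimators=50,
    rate_drop=0.1,
)
model.fit(X, y)
preds = model.predict(X)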

Logic

flowchart TD
    A[Parse CSV Data] -->|data cleaning| B[Transformed Dataset]
    B -->|encoding & transformation| O[Numerical features & labels]
    O --> Z{Stage}
    Z -->|Training| D[Entire dataset]
    Z -->|Inference| E[Train & Test splits]
    D --> K[Predictions]
    E --> K
    K --> I[Save to CSV] 
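
In code, the flow above corresponds roughly to the following sketch (the cleaning and encoding steps are illustrative; the project's actual preprocessing lives in the app package):

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

df = pd.read_csv("data/raw/train.csv")                   # parse CSV data
df = df.dropna()                                          # data cleaning (illustrative)
y = LabelEncoder().fit_transform(df["Class"])             # encode labels
X = df.drop(columns=["Class"]).select_dtypes("number")    # numerical features

# Depending on the stage, either the entire dataset or a train/test split is used
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)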

Service Endpoints

  • /: Get the service status.
  • /train_predict: Preprocess the data, train the model, and retrieve the predictions.
  • /predict: Load a pretrained model and preprocessor and retrieve predictions.
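
A rough sketch of how these endpoints could be wired up in FastAPI follows; the request models reuse the field names from the curl examples below, but the handler names and response shapes are assumptions:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class TrainRequest(BaseModel):
    csv_train: str
    label_col_name: str

class PredictRequest(BaseModel):
    csv_path: str
    pretrained_model_path: str

@app.get("/")
def status():
    # Service status check
    return {"status": "ok"}

@app.post("/train_predict")
def train_predict(request: TrainRequest):
    # Preprocess the data, train the model and return the predictions
    ...

@app.post("/predict")
def predict(request: PredictRequest):
    # Load a pretrained model and preprocessor, return predictions
    ...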

How to use

The model can be trained and used for inference as a service accessible on localhost:8000 or via Docker.

Note that, as per the requirements, the produced CSV is saved under data/raw/train.csv by default.

As a service

python -m virtualenv venv
source venv/bin/activate
pip install -r requirements.txt

python main.py

The service will be accessible via http://127.0.0.1:8000.
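
Once the service is up, a quick status check against the root endpoint looks like this (the exact response body depends on the implementation):

curl http://127.0.0.1:8000/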

Docker

The model can also be deployed on Docker:

docker build -t toy_class_predictor .
docker run -d --name toy_class_predictor_container -p 8000:8000 -v ./data:/dockerized_model/data toy_class_predictor # mount the data folder to have direct access to the output file

Again, the service listens on http://127.0.0.1:8000.

Testing

To train the model and retrieve the predictions for the first time, use a payload such as:

curl --location 'http://127.0.0.1:8000/train_predict' --header 'Content-Type: application/json' --data '{
    "csv_train": "data/raw/train.csv", "label_col_name": "Class"
}'

After the model has been trained, the /predict endpoint can be used to perform the same operation on other datasets:

curl --location 'http://127.0.0.1:8000/predict' \
--header 'Content-Type: application/json' \
--data '{
    "csv_path": "data/raw/data.csv",
    "pretrained_model_path": "model_XGBOOST.joblib"
}'

NOTE: If the /predict endpoint is called before a model has been trained, an exception will be raised.
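
One plausible shape for such a guard, purely as an illustration (the actual check and error type are defined in the application code):

from pathlib import Path

import joblib
from fastapi import HTTPException

def load_pretrained(path: str):
    # Hypothetical guard: refuse to predict before /train_predict has produced a model file
    if not Path(path).exists():
        raise HTTPException(status_code=400, detail="No trained model found; call /train_predict first.")
    return joblib.load(path)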

NOTES

If you want to read the CSV in e.g. MS Excel, use comma as the delimiter and ^ as the escape character. The escapechar="^" parameter was added to the stage that saves the final CSV, as Excel could not parse it otherwise.
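
For reference, the relevant pandas call looks roughly like this (the DataFrame content is a toy placeholder, and the index=False flag is an assumption):

import pandas as pd

predictions = pd.DataFrame({"Class": [1, 0, 2]})
# Without escapechar="^", Excel could not parse the resulting file
predictions.to_csv("data/raw/train.csv", index=False, escapechar="^")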
