Skip to content

Commit

Permalink
add instance-cost-sensitive classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
kjgm committed Apr 10, 2024
1 parent 48b9f90 commit 743d821
Show file tree
Hide file tree
Showing 27 changed files with 953 additions and 59 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ set(STREED_HEADER_FILES
${PROJECT_INCLUDE_DIR}/tasks/tasks.h
${PROJECT_INCLUDE_DIR}/tasks/optimization_task.h
${PROJECT_INCLUDE_DIR}/tasks/cost_sensitive.h
${PROJECT_INCLUDE_DIR}/tasks/instance_cost_sensitive.h
${PROJECT_INCLUDE_DIR}/tasks/f1score.h
${PROJECT_INCLUDE_DIR}/tasks/group_fairness.h
${PROJECT_INCLUDE_DIR}/tasks/eq_opp.h
Expand Down Expand Up @@ -97,6 +98,7 @@ set(STREED_SRC_FILES

${PROJECT_SOURCE_DIR}/tasks/optimization_task.cpp
${PROJECT_SOURCE_DIR}/tasks/cost_sensitive.cpp
${PROJECT_SOURCE_DIR}/tasks/instance_cost_sensitive.cpp
${PROJECT_SOURCE_DIR}/tasks/f1score.cpp
${PROJECT_SOURCE_DIR}/tasks/group_fairness.cpp
${PROJECT_SOURCE_DIR}/tasks/eq_opp.cpp
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ Note that currently `STreeDCostSensitiveClassifier` does not support automatic b

See [examples/cost_sensitive_example.py](examples/cost_sensitive_example.py) for an example.

### Instance-Cost-Sensitive Classification
`STreeDInstanceCostSensitiveClassifier` implements an instance-cost-sensitive classifier. Each instance can have a different misclassification cost per label.

The costs can be specified with a `CostVector` object. For each instance, initialize a `CostVector` object with a list of the costs for each possible label.

See [examples/instance_cost_sensitive_example.py](examples/instance_cost_sensitive_example.py) for an example.

### Classification under a Group Fairness constraint
`STreeDGroupFairnessClassifier` implements a classifier that satisfies a group fairness constraint.
The maximum amount of discrimination on the training data can be specified by the `discrimination-limit` parameter, e.g., 0.01 for maximum of 1% discrimination.
Expand Down
24 changes: 24 additions & 0 deletions data/instance-cost-sensitive/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
https://maxsat-evaluations.github.io/2021/index.html
This dataset contains data about 439 out of 539 instances from the MaxSAT 2021 unweighted complete track.
The excluded instances were removed as no algorithm was able to solve them before timing out.
For each instance 32 features listed below were extracted and binarized according the threshold listed in the info file.

Each row is one instance, the first value is the index of the best performing algorithm, followed by the differences in runtime between between all algorithms and the best and the binary feature values, with a 1 indicating the feature is smaller than the threshold, or 0 otherwise.

Algorithms in order: MaxHS, CASHWMaxSAT, EvalMaxSAT, UWrMaxSAT, Open-WBO-RES-MergeSAT, Open-WBO-RES-Glucose ,Pacose, Exact

FEATURES:
* Size features
- 0 -> number of variables v
- 1 -> number of clauses c
- 2-4 -> ratios v/c (v/c)^2 (v/c )^3
- 5-7 -> reciprocals of above
- 27-31 -> length of clauses, mean, var, min, max, entropy
* Balance features
- 8-12 -> fraction of positive to total per clause, mean. var, min, max, entropy
- 13-15 -> fraction of unary, binary, ternary clauses
- 16-20 -> fraction of positive occurrence for each variable, mean, var, min, max
* Horn features
- 21 -> Fraction of horn clauses
- 22-26 -> occurrences in horn clauses per variables mean, var, min, max, entropy

439 changes: 439 additions & 0 deletions data/instance-cost-sensitive/maxsat21-32f.txt

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions data/instance-cost-sensitive/maxsat21-32f_info.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
feature index: 0 threshold: 13980.0
feature index: 1 threshold: 83832.0
feature index: 2 threshold: 0.21964298831998824
feature index: 3 threshold: 0.04824304231813449
feature index: 4 threshold: 0.010596245980402713
feature index: 5 threshold: 4.552842809364549
feature index: 6 threshold: 20.728377646782477
feature index: 7 threshold: 94.37304511894645
feature index: 8 threshold: 0.45109388458225663
feature index: 9 threshold: 0.0651367501572155
feature index: 10 threshold: 0.0
feature index: 11 threshold: 1.0
feature index: 12 threshold: 10.439579908918455
feature index: 13 threshold: 0.034280817026139124
feature index: 14 threshold: 0.4864387270398651
feature index: 15 threshold: 0.17018532233770772
feature index: 16 threshold: 0.4603800662334622
feature index: 17 threshold: 0.013677687801700178
feature index: 18 threshold: 0.023809523809523808
feature index: 19 threshold: 0.9333333333333333
feature index: 20 threshold: 9.470678001840568
feature index: 21 threshold: 0.29340573213873516
feature index: 22 threshold: 2.8584124720717283
feature index: 23 threshold: 8.238888888888699
feature index: 24 threshold: 0.0
feature index: 25 threshold: 36.0
feature index: 26 threshold: 7.918714299188948
feature index: 27 threshold: 2.525265066681101
feature index: 28 threshold: 1.5079626017459662
feature index: 29 threshold: 1.0
feature index: 30 threshold: 8.0
feature index: 31 threshold: 11.2336204336583
31 changes: 31 additions & 0 deletions examples/instance_cost_sensitive_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pystreed import STreeDInstanceCostSensitiveClassifier
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Read the data
df = pd.read_csv("data/instance-cost-sensitive/maxsat21-32f.txt", sep=" ", header=None)
# In this file,
# Column 0 is the optimal label (with lowest cost)
# Columns 1..8 are the costs of assigning the labels 0...7
# Columns 9... are the binary features
X = df[df.columns[9:]].values
y = df[0].values
costs = df[df.columns[1:9]].values

X_train, X_test, y_train, y_test, costs_train, costs_test = train_test_split(X, y, costs, test_size=0.20, random_state=42)

# Fit the model by passing the cost vector
model = STreeDInstanceCostSensitiveClassifier(max_depth = 5, time_limit=600, verbose=True)
model.fit(X_train, costs_train)

model.print_tree()

# Obtain the test predictions
yhat = model.predict(X_test)

# Obtain the test accuracy
print(f"Test Accuracy Score: {accuracy_score(y_test, yhat) * 100}%")

# Obtain the classification costs on the test set through model.score
print(f"Test Average Outcome: {model.score(X_test, costs_test)}")
105 changes: 105 additions & 0 deletions include/tasks/instance_cost_sensitive.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#pragma once
#include "tasks/optimization_task.h"

namespace STreeD {

template <class OT>
struct PairWorstCount;

struct InstanceCostSensitiveData {

InstanceCostSensitiveData() : costs() {}

InstanceCostSensitiveData(std::vector<double>& costs) : costs(costs) {
worst = *std::max_element(std::begin(costs), std::end(costs));
}

static InstanceCostSensitiveData ReadData(std::istringstream& iss, int num_labels);

// Get the cost for classifying this instance with the given label
inline double GetLabelCost(int label) const { return costs.at(label); }

// Add the label cost to the cost vector
inline void AddLabelCost(double cost) { costs.push_back(cost); }

// Return the number of possible labels
inline int NumLabels() const { return int(costs.size()); }

// Get the worst possible score for this instance (max of costs)
inline double GetWorst() const { return worst; }

// The costs for classifying this instance with each of the labels
std::vector<double> costs;
// The worst cost (max of costs)
double worst{ 0 };
};

class InstanceCostSensitive : public Classification {
public:
using SolType = double;
using SolD2Type = double;
using TestSolType = double;

using ET = InstanceCostSensitiveData;

static const bool total_order = true;
static const bool custom_leaf = false;
static const bool preprocess_train_test_data = true;
static const bool custom_similarity_lb = true;
static constexpr double worst = DBL_MAX;
static constexpr double best = 0;

explicit InstanceCostSensitive(const ParameterHandler& parameters)
: num_labels(int(parameters.GetIntegerParameter("num-extra-cols"))),
Classification(parameters) {}

inline void UpdateParameters(const ParameterHandler& parameters) {
num_labels = int(parameters.GetIntegerParameter("num-extra-cols"));
}

// Compute the leaf costs for the data in the context when assigning label
double GetLeafCosts(const ADataView& data, const BranchContext& context, int label) const;

// Compute the test leaf costs for the data in the context when assigning label
double GetTestLeafCosts(const ADataView& data, const BranchContext& context, int label) const;

// Compute the leaf costs for an instance given a assigned label
void GetInstanceLeafD2Costs(const AInstance* instance, int org_label, int label, double& costs, int multiplier) const;

// Compute the solution value from a terminal solution value
void ComputeD2Costs(const double& d2costs, int count, double& costs) const;

// Return true if the terminal solution value is zero
inline bool IsD2ZeroCost(const double& d2costs) const { return d2costs <= 1e-6 && d2costs >= -1e-6; }

// Get a bound on the worst contribution to the objective of a single instance with label
inline double GetWorstPerLabel(int label) const { return worst_per_label.at(label); }

// Inform the task on what training data is used
void InformTrainData(const ADataView& train_data, const DataSummary& train_summary);

// Compute the train score from the training solution value
inline double ComputeTrainScore(double train_value) const { return train_value ; }

// Compute the test score on the training data from the test solution value
inline double ComputeTrainTestScore(double train_value) const { return train_value; }

// Compute the test score on the test data from the test solution value
inline double ComputeTestTestScore(double test_value) const { return test_value; }

// Compare two score values. Lower is better
inline static bool CompareScore(double score1, double score2) { return score1 - score2 <= 0; } // return true if score1 is better than score2

// Provide a custom lower bound based on the worst values possible per instance
PairWorstCount<InstanceCostSensitive> ComputeSimilarityLowerBound(const ADataView& data_old, const ADataView& data_new) const;

// Preprocess the training and test data
void PreprocessTrainData(ADataView& train_data);
void PreprocessTestData(ADataView& test_data);

private:
std::vector<double> worst_per_label;
int num_labels{ 1 };
};

}
1 change: 1 addition & 0 deletions include/tasks/tasks.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "tasks/accuracy/cost_complex_accuracy.h"

#include "tasks/cost_sensitive.h"
#include "tasks/instance_cost_sensitive.h"
#include "tasks/f1score.h"
#include "tasks/group_fairness.h"
#include "tasks/eq_opp.h"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pystreed"
version = "1.2.2"
version = "1.2.3"
requires-python = ">=3.8"
description = "Python Implementation of STreeD: Dynamic Programming Approach for Optimal Decision Trees with Separable objectives and Constraints"
license= {file = "LICENSE"}
Expand Down
1 change: 1 addition & 0 deletions pystreed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from pystreed.survival_analysis import STreeDSurvivalAnalysis
from pystreed.prescriptive_policy_generation import STreeDPrescriptivePolicyGenerator
from pystreed.group_fair import STreeDGroupFairnessClassifier
from pystreed.instance_cost_sensitive_classification import STreeDInstanceCostSensitiveClassifier
from pystreed.data import *
2 changes: 1 addition & 1 deletion pystreed/data.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .cstreed import SAData, PPGData, FeatureCostSpecifier, CostSpecifier
from .cstreed import SAData, PPGData, FeatureCostSpecifier, CostSpecifier, CostVector
Loading

0 comments on commit 743d821

Please sign in to comment.