add instance-cost-sensitive classifier
Showing 27 changed files with 953 additions and 59 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
https://maxsat-evaluations.github.io/2021/index.html | ||
This dataset contains data about 439 out of 539 instances from the MaxSAT 2021 unweighted complete track. | ||
The excluded instances were removed as no algorithm was able to solve them before timing out. | ||
For each instance, the 32 features listed below were extracted and binarized according to the thresholds listed in the info file. | ||
|
||
Each row is one instance. The first value is the index of the best-performing algorithm, followed by the difference in runtime between each algorithm and the best one, and then the binary feature values, where 1 indicates that the feature is smaller than its threshold and 0 otherwise. | ||
|
||
Algorithms in order: MaxHS, CASHWMaxSAT, EvalMaxSAT, UWrMaxSAT, Open-WBO-RES-MergeSAT, Open-WBO-RES-Glucose, Pacose, Exact | ||
|
||
FEATURES: | ||
* Size features | ||
- 0 -> number of variables v | ||
- 1 -> number of clauses c | ||
- 2-4 -> ratios v/c, (v/c)^2, (v/c)^3 | ||
- 5-7 -> reciprocals of above | ||
- 27-31 -> length of clauses: mean, var, min, max, entropy | ||
* Balance features | ||
- 8-12 -> fraction of positive to total literals per clause: mean, var, min, max, entropy | ||
- 13-15 -> fraction of unary, binary, ternary clauses | ||
- 16-20 -> fraction of positive occurrences for each variable: mean, var, min, max, entropy | ||
* Horn features | ||
- 21 -> Fraction of horn clauses | ||
- 22-26 -> occurrences in Horn clauses per variable: mean, var, min, max, entropy | ||
|
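The row layout described in the README above can be illustrated with a small parser. This is a minimal sketch, assuming space-separated values (as in the example script's `sep=" "`); the helper `parse_row` and the sample row are made up for illustration and are not part of the repository.

```python
import numpy as np

# Algorithm order as listed in the README
ALGORITHMS = [
    "MaxHS", "CASHWMaxSAT", "EvalMaxSAT", "UWrMaxSAT",
    "Open-WBO-RES-MergeSAT", "Open-WBO-RES-Glucose", "Pacose", "Exact",
]
NUM_FEATURES = 32


def parse_row(line):
    """Split one row into (best algorithm index, cost vector, binary features).

    Hypothetical helper: assumes whitespace-separated values.
    """
    values = line.split()
    best = int(values[0])
    costs = np.array(values[1:1 + len(ALGORITHMS)], dtype=float)
    features = np.array(values[1 + len(ALGORITHMS):], dtype=float).astype(int)
    if len(features) != NUM_FEATURES:
        raise ValueError(f"expected {NUM_FEATURES} features, got {len(features)}")
    return best, costs, features


# Made-up row: best index, 8 runtime differences, 32 binary features
row = "2 5.0 12.5 0.0 3.0 40.0 7.5 1.0 60.0 " + " ".join(["1", "0"] * 16)
best, costs, features = parse_row(row)
print(ALGORITHMS[best], costs[best], int(features.sum()))
```

Because the costs are runtime differences to the best algorithm, the entry at the best index is zero and is the minimum of the cost vector.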
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
feature index: 0 threshold: 13980.0 | ||
feature index: 1 threshold: 83832.0 | ||
feature index: 2 threshold: 0.21964298831998824 | ||
feature index: 3 threshold: 0.04824304231813449 | ||
feature index: 4 threshold: 0.010596245980402713 | ||
feature index: 5 threshold: 4.552842809364549 | ||
feature index: 6 threshold: 20.728377646782477 | ||
feature index: 7 threshold: 94.37304511894645 | ||
feature index: 8 threshold: 0.45109388458225663 | ||
feature index: 9 threshold: 0.0651367501572155 | ||
feature index: 10 threshold: 0.0 | ||
feature index: 11 threshold: 1.0 | ||
feature index: 12 threshold: 10.439579908918455 | ||
feature index: 13 threshold: 0.034280817026139124 | ||
feature index: 14 threshold: 0.4864387270398651 | ||
feature index: 15 threshold: 0.17018532233770772 | ||
feature index: 16 threshold: 0.4603800662334622 | ||
feature index: 17 threshold: 0.013677687801700178 | ||
feature index: 18 threshold: 0.023809523809523808 | ||
feature index: 19 threshold: 0.9333333333333333 | ||
feature index: 20 threshold: 9.470678001840568 | ||
feature index: 21 threshold: 0.29340573213873516 | ||
feature index: 22 threshold: 2.8584124720717283 | ||
feature index: 23 threshold: 8.238888888888699 | ||
feature index: 24 threshold: 0.0 | ||
feature index: 25 threshold: 36.0 | ||
feature index: 26 threshold: 7.918714299188948 | ||
feature index: 27 threshold: 2.525265066681101 | ||
feature index: 28 threshold: 1.5079626017459662 | ||
feature index: 29 threshold: 1.0 | ||
feature index: 30 threshold: 8.0 | ||
feature index: 31 threshold: 11.2336204336583 |
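The thresholds above can be applied to raw feature values as follows. This is a minimal sketch, assuming the info file uses exactly the `feature index: i threshold: t` line format shown and that a feature is encoded as 1 when it is strictly smaller than its threshold, as the README states. The helpers `parse_thresholds` and `binarize` and the raw values are hypothetical.

```python
import re

import numpy as np

INFO_LINE = re.compile(r"feature index:\s*(\d+)\s+threshold:\s*(\S+)")


def parse_thresholds(text):
    """Read 'feature index: i threshold: t' lines into an array ordered by index."""
    thresholds = {}
    for line in text.splitlines():
        match = INFO_LINE.search(line)
        if match:
            thresholds[int(match.group(1))] = float(match.group(2))
    return np.array([thresholds[i] for i in range(len(thresholds))])


def binarize(raw, thresholds):
    """Encode each feature as 1 if it is smaller than its threshold, else 0."""
    return (np.asarray(raw, dtype=float) < thresholds).astype(int)


# First three lines of the info file shown above
info = """feature index: 0 threshold: 13980.0
feature index: 1 threshold: 83832.0
feature index: 2 threshold: 0.21964298831998824"""

thresholds = parse_thresholds(info)
raw = [[1000.0, 90000.0, 0.5], [20000.0, 5000.0, 0.1]]  # made-up raw feature values
binary = binarize(raw, thresholds)
print(binary)
```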
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pystreed import STreeDInstanceCostSensitiveClassifier | ||
import pandas as pd | ||
from sklearn.metrics import accuracy_score | ||
from sklearn.model_selection import train_test_split | ||
|
||
# Read the data | ||
df = pd.read_csv("data/instance-cost-sensitive/maxsat21-32f.txt", sep=" ", header=None) | ||
# In this file, | ||
# Column 0 is the optimal label (with lowest cost) | ||
# Columns 1..8 are the costs of assigning the labels 0...7 | ||
# Columns 9... are the binary features | ||
X = df[df.columns[9:]].values | ||
y = df[0].values | ||
costs = df[df.columns[1:9]].values | ||
|
||
X_train, X_test, y_train, y_test, costs_train, costs_test = train_test_split(X, y, costs, test_size=0.20, random_state=42) | ||
|
||
# Fit the model by passing the cost vector | ||
model = STreeDInstanceCostSensitiveClassifier(max_depth=5, time_limit=600, verbose=True) | ||
model.fit(X_train, costs_train) | ||
|
||
model.print_tree() | ||
|
||
# Obtain the test predictions | ||
yhat = model.predict(X_test) | ||
|
||
# Obtain the test accuracy | ||
print(f"Test Accuracy Score: {accuracy_score(y_test, yhat) * 100}%") | ||
|
||
# Obtain the classification costs on the test set through model.score | ||
print(f"Test Average Outcome: {model.score(X_test, costs_test)}") |
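The two metrics printed by the script measure different things: accuracy counts how often the predicted label is the cheapest one, while the cost outcome looks up what each prediction actually costs. The sketch below computes both by hand on made-up data. That `model.score` returns exactly this average is an assumption based on the "Test Average Outcome" label, not something the diff confirms; `average_cost` is a hypothetical helper.

```python
import numpy as np


def average_cost(costs, predictions):
    """Mean cost incurred when instance i is assigned label predictions[i]."""
    costs = np.asarray(costs, dtype=float)
    predictions = np.asarray(predictions, dtype=int)
    return costs[np.arange(len(predictions)), predictions].mean()


# Made-up cost matrix: one row per instance, one column per label
costs_demo = np.array([
    [0.0, 4.0, 9.0],
    [2.0, 0.0, 1.0],
    [5.0, 3.0, 0.0],
])
y_demo = costs_demo.argmin(axis=1)   # optimal label per instance
yhat_demo = np.array([0, 2, 1])      # hypothetical predictions

accuracy = float((yhat_demo == y_demo).mean())
avg = float(average_cost(costs_demo, yhat_demo))
print(f"accuracy={accuracy:.3f}, average cost={avg:.3f}")
```

A misclassification with a small runtime difference hurts accuracy as much as one with a large difference, but it contributes far less to the average cost, which is why the cost-sensitive objective is reported separately.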
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#pragma once | ||
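#include <algorithm> // std::max_element | ||
#include <cfloat> // DBL_MAX | ||
#include <sstream> // std::istringstream | ||
#include <vector> // std::vector | ||
#include "tasks/optimization_task.h" | ||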
|
||
namespace STreeD { | ||
|
||
template <class OT> | ||
struct PairWorstCount; | ||
|
||
struct InstanceCostSensitiveData { | ||
|
||
InstanceCostSensitiveData() : costs() {} | ||
|
||
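InstanceCostSensitiveData(const std::vector<double>& costs) : costs(costs) { | ||
// Guard against an empty cost vector: dereferencing the end iterator is undefined behavior | ||
if (!costs.empty()) worst = *std::max_element(std::begin(costs), std::end(costs)); | ||
} | ||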
|
||
static InstanceCostSensitiveData ReadData(std::istringstream& iss, int num_labels); | ||
|
||
// Get the cost for classifying this instance with the given label | ||
inline double GetLabelCost(int label) const { return costs.at(label); } | ||
|
||
// Add the label cost to the cost vector | ||
inline void AddLabelCost(double cost) { costs.push_back(cost); } | ||
|
||
// Return the number of possible labels | ||
inline int NumLabels() const { return int(costs.size()); } | ||
|
||
// Get the worst possible score for this instance (max of costs) | ||
inline double GetWorst() const { return worst; } | ||
|
||
// The costs for classifying this instance with each of the labels | ||
std::vector<double> costs; | ||
// The worst cost (max of costs) | ||
double worst{ 0 }; | ||
}; | ||
|
||
class InstanceCostSensitive : public Classification { | ||
public: | ||
using SolType = double; | ||
using SolD2Type = double; | ||
using TestSolType = double; | ||
|
||
using ET = InstanceCostSensitiveData; | ||
|
||
static const bool total_order = true; | ||
static const bool custom_leaf = false; | ||
static const bool preprocess_train_test_data = true; | ||
static const bool custom_similarity_lb = true; | ||
static constexpr double worst = DBL_MAX; | ||
static constexpr double best = 0; | ||
|
||
explicit InstanceCostSensitive(const ParameterHandler& parameters) | ||
: Classification(parameters), | ||
num_labels(int(parameters.GetIntegerParameter("num-extra-cols"))) {} | ||
|
||
inline void UpdateParameters(const ParameterHandler& parameters) { | ||
num_labels = int(parameters.GetIntegerParameter("num-extra-cols")); | ||
} | ||
|
||
// Compute the leaf costs for the data in the context when assigning label | ||
double GetLeafCosts(const ADataView& data, const BranchContext& context, int label) const; | ||
|
||
// Compute the test leaf costs for the data in the context when assigning label | ||
double GetTestLeafCosts(const ADataView& data, const BranchContext& context, int label) const; | ||
|
||
// Compute the leaf costs for an instance given an assigned label | ||
void GetInstanceLeafD2Costs(const AInstance* instance, int org_label, int label, double& costs, int multiplier) const; | ||
|
||
// Compute the solution value from a terminal solution value | ||
void ComputeD2Costs(const double& d2costs, int count, double& costs) const; | ||
|
||
// Return true if the terminal solution value is zero | ||
inline bool IsD2ZeroCost(const double& d2costs) const { return d2costs <= 1e-6 && d2costs >= -1e-6; } | ||
|
||
// Get a bound on the worst contribution to the objective of a single instance with label | ||
inline double GetWorstPerLabel(int label) const { return worst_per_label.at(label); } | ||
|
||
// Inform the task on what training data is used | ||
void InformTrainData(const ADataView& train_data, const DataSummary& train_summary); | ||
|
||
// Compute the train score from the training solution value | ||
inline double ComputeTrainScore(double train_value) const { return train_value; } | ||
|
||
// Compute the test score on the training data from the test solution value | ||
inline double ComputeTrainTestScore(double train_value) const { return train_value; } | ||
|
||
// Compute the test score on the test data from the test solution value | ||
inline double ComputeTestTestScore(double test_value) const { return test_value; } | ||
|
||
// Compare two score values. Lower is better | ||
inline static bool CompareScore(double score1, double score2) { return score1 - score2 <= 0; } // return true if score1 is at least as good as score2 (ties included) | ||
|
||
// Provide a custom lower bound based on the worst values possible per instance | ||
PairWorstCount<InstanceCostSensitive> ComputeSimilarityLowerBound(const ADataView& data_old, const ADataView& data_new) const; | ||
|
||
// Preprocess the training and test data | ||
void PreprocessTrainData(ADataView& train_data); | ||
void PreprocessTestData(ADataView& test_data); | ||
|
||
private: | ||
std::vector<double> worst_per_label; | ||
int num_labels{ 1 }; | ||
}; | ||
|
||
} |
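The header above only declares `GetLeafCosts`, so its implementation is not visible in this diff. As a rough illustration of the idea behind instance-cost-sensitive leaves, the sketch below assumes the cost of a leaf with a given label is the sum of that label's cost over all instances in the leaf, and that the per-instance worst value is the maximum of its costs (as in `InstanceCostSensitiveData`). The function names are hypothetical and this is not the library's code.

```python
import numpy as np


def leaf_costs(costs):
    """Total cost of assigning each label to every instance in a leaf."""
    return np.asarray(costs, dtype=float).sum(axis=0)


def best_leaf_label(costs):
    """Label with the lowest total cost in the leaf, and that cost (lower is better)."""
    totals = leaf_costs(costs)
    label = int(np.argmin(totals))
    return label, float(totals[label])


def worst_per_instance(costs):
    """Worst possible cost of each instance, i.e. the max over its label costs."""
    return np.asarray(costs, dtype=float).max(axis=1)


# Made-up leaf with three instances and three labels
leaf = np.array([
    [0.0, 2.0, 5.0],
    [3.0, 0.0, 1.0],
    [4.0, 1.0, 0.0],
])
label, cost = best_leaf_label(leaf)
print(label, cost, worst_per_instance(leaf))
```

Note that the best leaf label need not be the majority optimal label: here each instance prefers a different label, and label 1 wins because its mistakes are the cheapest.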
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
-from .cstreed import SAData, PPGData, FeatureCostSpecifier, CostSpecifier | ||
+from .cstreed import SAData, PPGData, FeatureCostSpecifier, CostSpecifier, CostVector |