Skip to content

Commit

Permalink
Merge pull request #18 from ibois-epfl/ttool_classification_new_weights
Browse files Browse the repository at this point in the history
Ttool classification new weights
  • Loading branch information
9and3 authored Jan 11, 2024
2 parents 0489d82 + 4dd9962 commit 313ac2c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
Binary file modified ai/torchscripts/efficientnet.pt
100644 → 100755
Binary file not shown.
11 changes: 11 additions & 0 deletions ai/torchscripts/label_map.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
auger_drill_bit_34_235: 0
chain_swordsaw_blade_200: 1
spade_drill_bit_25_150: 2
brad_point_drill_bit_20_150: 3
chain_saw_blade_f_250: 4
self_feeding_bit_40_90: 5
circular_saw_blade_makita_190: 6
self_feeding_bit_50_90: 7
twist_drill_bit_32_165: 8
saber_saw_blade_makita_t_300: 9
auger_drill_bit_20_235: 10
14 changes: 10 additions & 4 deletions assets/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@ histOffset: 100
histRad: 40
searchRad: 25
classifierModelPath: "ai/torchscripts/efficientnet.pt"

# The order of classifierLabels is determined by the classifier during training
# the classifier once trained output a .pt (weights) and a .txt file for the order
# see the .txt file in ai/torchscripts to see the order to recreate here below
classifierLabels:
- "auger_drill_bit_20_235"
- "auger_drill_bit_34_235"
- "chain_swordsaw_blade_200"
- "spade_drill_bit_25_150"
- "brad_point_drill_bit_20_150"
- "chain_saw_blade_f_250"
- "circular_saw_blade_makita_190"
- "saber_saw_blade_makita_t_300"
- "self_feeding_bit_40_90"
- "circular_saw_blade_makita_190"
- "self_feeding_bit_50_90"
- "spade_drill_bit_25_150"
- "twist_drill_bit_32_165"
- "saber_saw_blade_makita_t_300"
- "auger_drill_bit_20_235"
classifierImageSize: 384
classifierImageChannels: 3
classifierMean:
Expand Down
3 changes: 2 additions & 1 deletion src/classifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ ttool::ML::Classifier::Classifier(std::string modelPath,
std::vector<float> std)
: IMAGE_SIZE(imageSize), IMAGE_CHANNEL(imageChannel), m_Pred2Label(pred2Label), m_Mean(mean), m_Std(std)
{
m_Module = torch::jit::load(modelPath);
torch::Device device(torch::kCPU);
m_Module = torch::jit::load(modelPath, device);

// // Dry run to initialize the model
// cv::Mat image = cv::Mat::zeros(IMAGE_SIZE, IMAGE_SIZE, CV_8UC3);
Expand Down

0 comments on commit 313ac2c

Please sign in to comment.