
C native interface

Alexander Medvedev edited this page Aug 15, 2021 · 65 revisions

The library has a C interface. The network architecture and node parameters are passed in and out as JSON strings.

Operators

Examples of use


Create net

Creating a Network Architecture

 /// create net
 /// @param[in] jnNet - network architecture in JSON
 /// @param[out] out_err - parse error of jnNet; "" means ok. The memory is allocated by the user
 /// @param[in] statusCBack - status callback. Optional
 /// @param[in] udata - user data. Optional
 /// @return object net
 SUNNET_API sunnet snCreateNet(const char* jnNet,
                               char* out_err /*sz 256*/,
                               snStatusCBack = nullptr,  
                               snUData = nullptr);

Example:

    stringstream ss;
    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1\""
        "},"

        "\"Nodes\":"
        "["

        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"P1\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"32\"}"
        "},"

        "{"
        "\"NodeName\":\"P1\","
        "\"NextNodes\":\"F1\","
        "\"OperatorName\":\"Pooling\""
        "},"

        "{"
        "\"NodeName\":\"F1\","
        "\"NextNodes\":\"F3\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"128\","
                            "\"batchNorm\":\"beforeActive\"}"
        "},"

        "{"
        "\"NodeName\":\"F3\","
        "\"NextNodes\":\"LS\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"10\","
                            "\"weightInit\":\"uniform\","
                            "\"active\":\"none\"}"
        "},"

        "{"
        "\"NodeName\":\"LS\","
        "\"NextNodes\":\"EndNet\","
        "\"OperatorName\":\"LossFunction\","
        "\"OperatorParams\":{\"loss\":\"softMaxToCrossEntropy\"}"
        "}"

        "],"

        "\"EndNet\":"                      
        "{"
        "\"PrevNode\":\"LS\""              
        "}"
        "}";

    char err[256]; err[0] = '\0';
    auto snet = SN_API::snCreateNet(ss.str().c_str(), err);

Let's examine the JSON example schema.

 "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1\""
        "}"

First comes the start node of the network. Its name is fixed, "BeginNet", and there must be exactly one in the schema.
'NextNodes' gives the name(s) of the node(s) that receive its data next.

 "\"Nodes\":"
        "["
         .
         .
        "]"  

Next comes the list of network nodes.

 "{"
        "\"NodeName\":\"F1\","
        "\"NextNodes\":\"F3\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"128\","
                            "\"batchNorm\":\"beforeActive\"}"
 "}"

Each node has the same set of fields:

  • 'NodeName' - name of the current node. Can be arbitrary
  • 'NextNodes' - names of the following nodes, separated by spaces
  • 'OperatorName' - name of the node's operator. Must be one of the operators described on this page
  • 'OperatorParams' - parameters of the node's operator. The allowed keys depend on the operator

  "\"EndNet\":"
        "{"
        "\"PrevNode\":\"LS\""
        "}"

At the end, after the list of nodes, there must be an 'EndNet' node.
It has a single field, 'PrevNode': the name of the one node that feeds it.
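Escaping every quote by hand, as in the example above, is easy to get wrong. Below is a minimal sketch of the same architecture written as a C++11 raw string literal; the function name 'exampleNetJson' is illustrative. It assumes the library's JSON parser accepts ordinary whitespace and line breaks, as standard JSON parsers do.

```cpp
#include <cassert>
#include <string>

// The example architecture above, written as a C++11 raw string
// literal R"(...)" so that no quote needs a backslash escape.
std::string exampleNetJson() {
    return R"({
  "BeginNet": { "NextNodes": "C1" },
  "Nodes": [
    { "NodeName": "C1", "NextNodes": "P1",
      "OperatorName": "Convolution",
      "OperatorParams": { "filters": "32" } },
    { "NodeName": "P1", "NextNodes": "F1",
      "OperatorName": "Pooling" },
    { "NodeName": "F1", "NextNodes": "F3",
      "OperatorName": "FullyConnected",
      "OperatorParams": { "units": "128", "batchNorm": "beforeActive" } },
    { "NodeName": "F3", "NextNodes": "LS",
      "OperatorName": "FullyConnected",
      "OperatorParams": { "units": "10", "weightInit": "uniform", "active": "none" } },
    { "NodeName": "LS", "NextNodes": "EndNet",
      "OperatorName": "LossFunction",
      "OperatorParams": { "loss": "softMaxToCrossEntropy" } }
  ],
  "EndNet": { "PrevNode": "LS" }
})";
}
```

With the library linked, it would be used as 'auto snet = SN_API::snCreateNet(exampleNetJson().c_str(), err);'. The temporary string lives until the end of that statement, which is enough if 'snCreateNet' parses the text during the call.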

Training net

You can train a network in two ways:

  • by calling a single function, 'snTraining'
  • or the standard way: run the forward pass with 'snForward', compute the error yourself, and pass it back with 'snBackward'.

Let's see the first option.

        /// training - a cycle forward-back with auto-correction of weights
        /// @param[in] sunnet - object net
        /// @param[in] lr - learning rate
        /// @param[in] isz - input layer size
        /// @param[in] iLayer - input layer NCHW(bsz, ch, h, w)        
        /// @param[in] osz - size of target and result. Used for verification
        /// @param[out] outData - result; its size must match osz. The memory is allocated by the user
        /// @param[in] targetData - target; its size must match osz. The memory is allocated by the user
        /// @param[out] outAccurate - accuracy on the current batch
        /// @return true - ok
        SUNNET_API bool snTraining(sunnet, 
                                   snFloat lr,
                                   snLSize isz,
                                   const snFloat* iLayer,  
                                   snLSize osz,
                                   snFloat* outData,                                   
                                   const snFloat* targetData,
                                   snFloat* outAccurate);
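The buffers 'iLayer', 'outData' and 'targetData' are flat arrays in NCHW order (batch, channel, height, width), as the parameter comments state. The sketch below shows how an element's offset is computed under that layout; the helper names 'nchwIndex' and 'nchwSize' are hypothetical and not part of the library.

```cpp
#include <cassert>
#include <cstddef>

// Offset of element (n, c, y, x) in a flat NCHW buffer whose images
// have ch channels, h rows and w columns.
std::size_t nchwIndex(std::size_t n, std::size_t c, std::size_t y, std::size_t x,
                      std::size_t ch, std::size_t h, std::size_t w) {
    assert(c < ch && y < h && x < w);
    return ((n * ch + c) * h + y) * w + x;
}

// Total number of floats for bsz images of size ch x h x w.
std::size_t nchwSize(std::size_t bsz, std::size_t ch, std::size_t h, std::size_t w) {
    return bsz * ch * h * w;
}
```

Note that the MNIST example at the end of this page builds the size as 'snLSize(w, h, 1, batchSz)', width first, and fills image 'i' starting at 'inLayer + i * w * h', which is the same rule with one channel.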

The function takes a batch of input data and the target values.
It returns the network output and an accuracy estimate for the batch.
The accuracy is calculated as follows:

    snFloat* targetData = targetTens->getData();
    snFloat* outData = outTens->getData();

    // an output element counts as correct if it is within 0.1 of its target
    size_t accCnt = 0, osz = outTens->size().size();
    for (size_t i = 0; i < osz; ++i){

        if (std::fabs(outData[i] - targetData[i]) < 0.1F)
            ++accCnt;
    }

    return (accCnt * 1.F) / osz;
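The rule above counts individual output elements, not samples. A self-contained version over plain arrays (the name 'elementAccuracy' is illustrative) makes the consequence visible: with one-hot targets, every near-zero output in a wrong class also counts as a hit, so this value is usually higher than the argmax accuracy that the MNIST example at the end of this page computes itself.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Fraction of output elements that lie within 'tol' of their target,
// following the rule described above (tol = 0.1).
float elementAccuracy(const float* out, const float* target, std::size_t n,
                      float tol = 0.1F) {
    assert(n > 0);
    std::size_t hits = 0;
    for (std::size_t i = 0; i < n; ++i)
        if (std::fabs(out[i] - target[i]) < tol)
            ++hits;
    return static_cast<float>(hits) / n;
}
```

For example, outputs {0.95, 0.02, 0.5, 0.0} against the target {1, 0, 0, 0} give 0.75: three of the four elements are within 0.1.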

Architecture of net

Get the network structure as JSON.

        /// get architecture of net
        /// @param[in] sunnet - object net
        /// @param[out] jnNet - architecture of net in JSON
        /// @return true - ok
        SUNNET_API bool snGetArchitecNet(sunnet,
                                         char** jnNet);

Save and load layer weights

        /// get weight of node
        /// @param[in] sunnet - object net
        /// @param[in] nodeName - name node
        /// @param[out] wsz - output size
        /// @param[out] wData - output data NCHW(bsz, ch, h, w). Pass a pointer set to NULL on the first call, then pass the same pointer on later calls
        /// @return true - ok
        SUNNET_API bool snGetWeightNode(sunnet,
                                        const char* nodeName,
                                        snLSize* wsz,
                                        snFloat** wData);

        /// set weight of node
        /// @param[in] sunnet - object net
        /// @param[in] nodeName - name node
        /// @param[in] wsz - size
        /// @param[in] wData - weight NCHW(bsz, ch, h, w)        
        /// @return true - ok
        SUNNET_API bool snSetWeightNode(sunnet,
                                        const char* nodeName,
                                        snLSize wsz,
                                        const snFloat* wData);

To save the weights after training, use the 'snGetWeightNode' function. It returns the array of weights of the selected node.
To load previously trained weights before running the network, use the 'snSetWeightNode' function. It takes an array of weights for the selected node.

        /// save all weights (and batch norm data, if present) to file
        /// @param[in] sunnet - object net
        /// @param[in] filePath - path to file
        /// @return true - ok
        SUNNET_API bool snSaveAllWeightToFile(sunnet, const char* filePath);

        /// load all weights (and batch norm data, if present) from file
        /// @param[in] sunnet - object net
        /// @param[in] filePath - path to file
        /// @return true - ok
        SUNNET_API bool snLoadAllWeightFromFile(sunnet, const char* filePath);

You can also save and load all weights to and from disk with the functions 'snSaveAllWeightToFile' and 'snLoadAllWeightFromFile'.
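The built-in file functions cover the whole network. If you only need one node's weights, for example to copy a trained layer into another network with 'snSetWeightNode', you can store the array returned by 'snGetWeightNode' yourself. Below is a minimal sketch; the 'WeightBlob' struct, the file layout (four 64-bit dimensions followed by raw floats in host byte order) and both function names are hypothetical, not a format used by the library.

```cpp
#include <cassert>
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Hypothetical container for one node's weights.
struct WeightBlob {
    std::uint64_t dims[4] = {0, 0, 0, 0};  // w, h, ch, bsz
    std::vector<float> data;               // dims[0]*dims[1]*dims[2]*dims[3] floats
};

// Writes the four dimensions followed by the raw floats (host byte order).
bool saveWeightBlob(const std::string& path, const WeightBlob& b) {
    assert(b.data.size() == b.dims[0] * b.dims[1] * b.dims[2] * b.dims[3]);
    std::ofstream os(path, std::ios::binary);
    os.write(reinterpret_cast<const char*>(b.dims), sizeof b.dims);
    os.write(reinterpret_cast<const char*>(b.data.data()),
             static_cast<std::streamsize>(b.data.size() * sizeof(float)));
    return static_cast<bool>(os);
}

// Reads a file written by saveWeightBlob; returns false if it is missing or short.
bool loadWeightBlob(const std::string& path, WeightBlob& b) {
    std::ifstream is(path, std::ios::binary);
    if (!is.read(reinterpret_cast<char*>(b.dims), sizeof b.dims)) return false;
    b.data.resize(b.dims[0] * b.dims[1] * b.dims[2] * b.dims[3]);
    is.read(reinterpret_cast<char*>(b.data.data()),
            static_cast<std::streamsize>(b.data.size() * sizeof(float)));
    return static_cast<bool>(is);
}
```

Because the floats are written in host byte order, such a file is only portable between machines with the same endianness.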

Set and get layer parameters

        /// set params of node
        /// @param[in] sunnet - object net
        /// @param[in] nodeName - name node
        /// @param[in] jnParam - params of node in JSON. 
        /// @return true - ok
        SUNNET_API bool snSetParamNode(sunnet,
                                       const char* nodeName,
                                       const char* jnParam);

You can specify some parameters of the nodes after creating the network. To do this, call the 'snSetParamNode' function.
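All parameter values on this page are written as JSON strings, even numbers. A small helper for building the 'jnParam' argument (the name 'makeParamJson' is hypothetical) avoids hand-escaping; it assumes keys and values contain no quotes or backslashes.

```cpp
#include <cassert>
#include <map>
#include <string>

// Builds a flat JSON object such as {"dropOut":"0.5","freeze":"1"} from
// key/value pairs. Every value is written as a string, as in the
// examples on this page. std::map orders the keys alphabetically.
std::string makeParamJson(const std::map<std::string, std::string>& params) {
    std::string json = "{";
    for (const auto& kv : params) {
        // no escaping is done, so quotes and backslashes are not allowed
        assert(kv.first.find_first_of("\"\\") == std::string::npos);
        assert(kv.second.find_first_of("\"\\") == std::string::npos);
        if (json.size() > 1) json += ",";
        json += "\"" + kv.first + "\":\"" + kv.second + "\"";
    }
    return json + "}";
}
```

For instance, 'SN_API::snSetParamNode(snet, "C1", makeParamJson({{"freeze", "1"}}).c_str());' would freeze the weights of node C1, assuming 'freeze' is among the parameters that may be changed after creation.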

        /// get params of node
        /// @param[in] sunnet - object net
        /// @param[in] nodeName - name node
        /// @param[out] jnParam - params of node in JSON. The memory is allocated by the user 
        /// @return true - ok
        SUNNET_API bool snGetParamNode(sunnet,
                                       const char* nodeName,
                                       char* jnParam /*minsz 256*/);

You can get the current node parameters using the 'snGetParamNode' function.

Monitoring gradients and weights

To read the current values of a node's data, you can use the following functions:

  • 'snGetWeightNode' - node weights
  • 'snGetGradientNode' - node gradients
  • 'snGetOutputNode' - output values of the node

You can also register your own callback function and insert a 'UserLayer' node after the node of interest; the callback then receives the data passing through that node.

        /// userCallBack for 'userLayer' node
        /// @param[in] cbname - name of the user callback
        /// @param[in] node - name of the node
        /// @param[in] fwdBwd - current pass: forward (true) or backward (false)
        /// @param[in] insz - input layer size - received from the previous node
        /// @param[in] in - input layer - received from the previous node
        /// @param[out] outsz - output layer size - sent to the next node
        /// @param[out] out - output layer - sent to the next node
        /// @param[in] ud - auxiliary user data
        typedef void(*snUserCBack)(const char* cbname,
                                   const char* node,
                                   bool fwdBwd,
                                   snLSize insz,
                                   snFloat* in,
                                   snLSize* outsz,
                                   snFloat** out,
                                   snUData ud);

        /// add user callBack for 'userLayer' node
        /// @param[in] sunnet - object net
        /// @param[in] cbName - name callBack
        /// @param[in] snUserCBack - callBack
        /// @param[in] snUData - user data
        /// @return true - ok
        SUNNET_API bool snAddUserCallBack(sunnet, const char* cbName, snUserCBack, snUData = nullptr);

Example:

  
 void userCBack(const char* cbname,
                const char* node,
                bool fwdBwd,
                snLSize insz,
                snFloat* in,
                snLSize* outsz,
                snFloat** out,
                snUData ud){
     // your code: process 'in' (size 'insz') and
     // pass the result to the next node through *outsz and *out
 }

  int main(int argc, char* argv[]){

    stringstream ss;
    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1\""
        "},"

        "\"Nodes\":"
        "["

        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"UserNd\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"28\", \"batchNorm\":\"beforeActive\","
        "\"freeze\":\"1\"}"
        "},"

        "{"
        "\"NodeName\":\"UserNd\","
        "\"NextNodes\":\"P1\","
        "\"OperatorName\":\"UserLayer\","
        "\"OperatorParams\":{\"cbackName\":\"ucb\"}"

        "},"

        "{"
        "\"NodeName\":\"P1\","
        "\"NextNodes\":\"F1\","
        "\"OperatorName\":\"Pooling\""
        "},"
         .
         . 
        char err[256]; err[0] = '\0';
        auto net = SN_API::snCreateNet(ss.str().c_str(), err);

        SN_API::snAddUserCallBack(net, "ucb", userCBack);
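A typical use of a 'UserLayer' is the monitoring described above: it sits after the node of interest, records something about the data passing through, and hands the data on unchanged. The sketch shows only the summary part, on a plain float array, so that it compiles without the library; 'TapStats' and 'summarize' are hypothetical names. Inside a real 'userCBack' you would call it with 'in' and the number of floats described by 'insz', then forward the data by setting '*outsz = insz' and '*out = in'; whether reusing the input buffer as the output is allowed is an assumption to check against the library.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>

// Statistics a monitoring 'UserLayer' might collect.
struct TapStats {
    float mean = 0.F;
    float maxAbs = 0.F;
};

// Summarizes the n values arriving from the previous node: activations
// in the forward pass, gradients in the backward pass.
TapStats summarize(const float* in, std::size_t n) {
    assert(in != nullptr && n > 0);
    TapStats s;
    double sum = 0.0;  // accumulate in double to limit rounding error
    for (std::size_t i = 0; i < n; ++i) {
        sum += in[i];
        s.maxAbs = std::max(s.maxAbs, std::fabs(in[i]));
    }
    s.mean = static_cast<float>(sum / n);
    return s;
}
```

A steadily shrinking 'maxAbs' of the gradients in the backward pass is one way to spot vanishing gradients at that point of the network.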

Input

The input node receives the user data and passes it on along the chain.

 "BeginNet":
        {
        "NextNodes":"C1"
        }

There is one main input node, 'BeginNet', which starts the network,
and there can be any number of additional input nodes.

"Nodes":[
 {
  "OperatorName":"Input",
  "NextNodes":".."
 }  
 .
 .       

Additional input nodes are specified in the 'Nodes' section.

    /// set input node (relevant for additional inputs)
    /// @param[in] sunnet - object net
    /// @param[in] nodeName - name node
    /// @param[in] isz - data size
    /// @param[in] inData - data       
    /// @return true - ok
    SUNNET_API bool snSetInputNode(sunnet,
                                    const char* nodeName,
                                    snLSize isz,
                                    const snFloat* inData);

Set data to a node using the 'snSetInputNode' function.
Input nodes cannot be listed as recipients ('NextNodes') of other nodes.
Input nodes do not have additional parameters.

Output

The output node of the network only receives the resulting data.

"EndNet":{
          "PrevNode":"LS"              
         } 

There is one main output node, and there may be additional output nodes.

"Nodes":[
 .
 . 
 {
  "OperatorName":"Output",
  "PrevNode":".."
 }        

Additional output nodes are specified in the 'Nodes' section.

    /// get output node (relevant for additional outputs)
    /// @param[in] sunnet - object net
    /// @param[in] nodeName - name node
    /// @param[out] osz - data size
    /// @param[out] outData - data. Pass a pointer set to NULL on the first call, then pass the same pointer on later calls
    /// @return true - ok
    SUNNET_API bool snGetOutputNode(sunnet,
                                    const char* nodeName,
                                    snLSize* osz,
                                    snFloat** outData);

Get data from a node using the 'snGetOutputNode' function.
Output nodes cannot pass data to other nodes.
Output nodes do not have additional parameters.

FullyConnected

"OperatorName":"FullyConnected",
"OperatorParams":{
 "units":"10",            Number of output neurons. !Required parameter (> 0)
 "active":"relu",         Activation function type. Optional parameter
 "weightInit":"he",       Type of initialization of weights: "he", "lecun", "xavier" or "uniform". Optional parameter
 "batchNorm":"none",      Type of batch norm: "beforeActive", "postActive" or "none". Optional parameter
 "batchNormLr":"0.001",   Learning rate for batch norm coef. Optional parameter  [0..]
 "optimizer":"adam",      Optimizer of weights: "adam", "sgd", "sgdMoment", "adagrad" or "RMSprop". Optional parameter
 "decayMomentDW":"0.9",   Optimizer of weights moment change. Optional parameter  [0..1.F]
 "decayMomentWGr":"0.99", Optimizer of weights moment change of prev. Optional parameter  [0..1.F]
 "lmbRegular":"0.001",    L2 regularization coefficient of weights. Optional parameter  [0..1.F]
 "dropOut":"0.0",         Dropout: fraction of neurons randomly disconnected. Optional parameter [0..1.F]
 "freeze":"0",            Do not change weights. Optional parameter [0/1]
 "useBias":"1",           Use bias. Optional parameter [0/1]
 "gpuDeviceId":"0"        GPU Id. Optional parameter
 }        

The values shown are the defaults.
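With the default parameters, each of the 'units' output neurons computes a weighted sum of all inputs plus a bias and applies ReLU. Below is a minimal sketch of that computation for a single sample; the function name and the row-major weight layout are illustrative assumptions and do not describe how the library stores its weights.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Forward pass of a fully connected layer for one input vector:
// out[j] = relu(b[j] + sum_i W[j][i] * x[i]), with W stored row-major
// as units x inputs.
std::vector<float> denseRelu(const std::vector<float>& x,
                             const std::vector<float>& W,
                             const std::vector<float>& b) {
    const std::size_t units = b.size(), inputs = x.size();
    assert(W.size() == units * inputs);
    std::vector<float> out(units);
    for (std::size_t j = 0; j < units; ++j) {
        float acc = b[j];  // "useBias":"1"
        for (std::size_t i = 0; i < inputs; ++i)
            acc += W[j * inputs + i] * x[i];
        out[j] = std::max(acc, 0.F);  // "active":"relu"
    }
    return out;
}
```

Such a layer holds units x inputs weights plus 'units' biases, which is why the first fully connected layer after a convolution usually dominates the parameter count.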

Convolution

"OperatorName":"Convolution",
"OperatorParams":{
 "filters":"10",          Number of output layers. !Required parameter (> 0)
 "fWidth":"3",            Width of mask. Optional parameter (> 0)
 "fHeight":"3",           Height of mask. Optional parameter (> 0)
 "padding":"0",           Padding around the edges. "-1" - the output layer will be the same size as the input. Optional parameter
 "stride":"1",            Mask movement step. Optional parameter (> 0)
 "dilate":"1",            Mask dilation. Optional parameter (> 0)
 "active":"relu",         Activation function type. Optional parameter
 "weightInit":"he",       Type of initialization of weights: "he", "lecun", "xavier" or "uniform". Optional parameter
 "batchNorm":"none",      Type of batch norm: "beforeActive", "postActive" or "none". Optional parameter
 "batchNormLr":"0.001",   Learning rate for batch norm coef. Optional parameter [0..]
 "optimizer":"adam",      Optimizer of weights: "adam", "sgd", "sgdMoment", "adagrad" or "RMSprop". Optional parameter
 "decayMomentDW":"0.9",   Optimizer of weights moment change. Optional parameter [0..1.F]
 "decayMomentWGr":"0.99", Optimizer of weights moment change of prev. Optional parameter [0..1.F]
 "lmbRegular":"0.001",    L2 regularization coefficient of weights. Optional parameter [0..1.F]
 "dropOut":"0.0",         Dropout: fraction of neurons randomly disconnected. Optional parameter [0..1.F]
 "freeze":"0",            Do not change weights. Optional parameter [0/1]
 "useBias":"1",           Use bias. Optional parameter [0/1]
 "gpuDeviceId":"0",       GPU Id. Optional parameter
 "checkPadding":"0"       Check that the padding and mask size are consistent. Optional parameter [0/1]
 }        

The values shown are the defaults.
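The spatial size of the convolution output follows from the mask size, padding, stride and dilation. The sketch uses the usual formula for one axis; that the library computes sizes exactly this way is an assumption, and 'convOutSize' is a hypothetical name. With the defaults (3x3 mask, padding 0, stride 1) a 28-pixel side shrinks to 26.

```cpp
#include <cassert>

// Spatial output size of a convolution along one axis:
//   out = (in + 2*pad - dilate*(f - 1) - 1) / stride + 1
// pad == -1 models "padding":"-1" from the table above: the output
// has the same size as the input.
int convOutSize(int in, int f, int pad, int stride, int dilate) {
    assert(in > 0 && f > 0 && stride > 0 && dilate > 0 && pad >= -1);
    if (pad == -1) return in;
    const int effective = dilate * (f - 1) + 1;  // extent of the dilated mask
    assert(in + 2 * pad >= effective);           // the mask must fit
    return (in + 2 * pad - effective) / stride + 1;
}
```

Integer division drops a trailing partial step, so with stride 2 a 28-pixel side and a 3x3 mask give 13, not 14.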

Deconvolution

"OperatorName":"Deconvolution",
"OperatorParams":{
 "filters":"10",          Number of output layers. !Required parameter (> 0)
 "fWidth":"3",            Width of mask. Optional parameter (> 0)
 "fHeight":"3",           Height of mask. Optional parameter (> 0)
 "stride":"1",            Mask movement step. Optional parameter (> 0)
 "active":"relu",         Activation function type. Optional parameter
 "weightInit":"he",       Type of initialization of weights: "he", "lecun", "xavier" or "uniform". Optional parameter
 "batchNorm":"none",      Type of batch norm: "beforeActive", "postActive" or "none". Optional parameter
 "batchNormLr":"0.001",   Learning rate for batch norm coef. Optional parameter [0..]
 "optimizer":"adam",      Optimizer of weights: "adam", "sgd", "sgdMoment", "adagrad" or "RMSprop". Optional parameter
 "decayMomentDW":"0.9",   Optimizer of weights moment change. Optional parameter [0..1.F]
 "decayMomentWGr":"0.99", Optimizer of weights moment change of prev. Optional parameter [0..1.F]
 "lmbRegular":"0.001",    L2 regularization coefficient of weights. Optional parameter [0..1.F]
 "dropOut":"0.0",         Dropout: fraction of neurons randomly disconnected. Optional parameter [0..1.F]
 "freeze":"0",            Do not change weights. Optional parameter [0/1]
 "gpuDeviceId":"0"        GPU Id. Optional parameter
 }        

The values shown are the defaults.
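The deconvolution table has no padding parameter, so the output size along one axis presumably depends only on the input size, mask size and stride. The sketch uses the standard relation for a transposed convolution without padding; that this is the library's formula is an assumption, and 'deconvOutSize' is a hypothetical name. With the defaults (3x3 mask, stride 1) a 26-pixel side grows back to 28.

```cpp
#include <cassert>

// Spatial output size of a transposed convolution ("deconvolution")
// without padding along one axis: out = (in - 1) * stride + f.
int deconvOutSize(int in, int f, int stride) {
    assert(in > 0 && f > 0 && stride > 0);
    return (in - 1) * stride + f;
}
```

With stride 2 and a 2x2 mask, a 14-pixel side doubles to 28, which is the usual way to undo a 2x2 pooling step in size.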

Pooling

"OperatorName":"Pooling",
"OperatorParams":{
 "kernel":"2",           Size of the square mask. Optional parameter (> 0)
 "stride":"2",           Mask movement step. Optional parameter (> 0)
 "pool":"max"            Pooling type: "max" or "avg". Optional parameter
}        

The values shown are the defaults.
If the mask does not fit completely within the image, the image is automatically extended at the edges.
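One reading of the edge rule above is that windows running past the border are kept and only the pixels inside the image are pooled, which gives ceil((in - kernel) / stride) + 1 outputs per axis. Below is a minimal sketch of max pooling under that assumption; the function names are hypothetical, and the sketch requires kernel >= stride, which holds for the defaults.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Output size along one axis when windows that run past the edge are
// kept: ceil((in - k) / s) + 1.
int poolOutSize(int in, int k, int s) {
    assert(in > 0 && k > 0 && s > 0 && k >= s);
    if (in <= k) return 1;
    return (in - k + s - 1) / s + 1;
}

// Max pooling of one h x w channel stored row-major. Positions outside
// the image are ignored, which equals padding with -infinity.
std::vector<float> maxPool(const std::vector<float>& img, int h, int w, int k, int s) {
    assert(static_cast<int>(img.size()) == h * w);
    const int oh = poolOutSize(h, k, s), ow = poolOutSize(w, k, s);
    std::vector<float> out(static_cast<std::size_t>(oh) * ow);
    for (int oy = 0; oy < oh; ++oy)
        for (int ox = 0; ox < ow; ++ox) {
            float m = -std::numeric_limits<float>::infinity();
            for (int y = oy * s; y < std::min(oy * s + k, h); ++y)
                for (int x = ox * s; x < std::min(ox * s + k, w); ++x)
                    m = std::max(m, img[y * w + x]);
            out[oy * ow + ox] = m;
        }
    return out;
}
```

With the defaults (kernel 2, stride 2) a 28x28 MNIST image becomes 14x14, and a 3x3 input becomes 2x2.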

LossFunction

Operator that computes the error automatically.
Depending on the task the network solves, it supports the following loss types:

  • "softMaxToCrossEntropy" - for multiclass classification
  • "binaryCrossEntropy" - for binary classification
  • "regressionMSE" - regression of a function with a least-squares (mean squared error) estimate
  • "userLoss" - user-defined loss

"OperatorName":"LossFunction",
"OperatorParams":{
 "loss":"softMaxToCrossEntropy",   Loss type: "softMaxToCrossEntropy", "binaryCrossEntropy", "regressionMSE" or "userLoss". !Required parameter
}        
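For "softMaxToCrossEntropy", the loss for one sample is the cross-entropy between the target distribution t and the softmax of the raw outputs z: L = -sum_i t_i * log(softmax(z)_i). Below is a minimal sketch of that standard definition; that the library implements it in exactly this form is an assumption, and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax of one sample's raw outputs, shifted by the maximum so that
// std::exp cannot overflow.
std::vector<float> softmax(const std::vector<float>& z) {
    assert(!z.empty());
    const float zmax = *std::max_element(z.begin(), z.end());
    std::vector<float> p(z.size());
    float sum = 0.F;
    for (std::size_t i = 0; i < z.size(); ++i) {
        p[i] = std::exp(z[i] - zmax);
        sum += p[i];
    }
    for (float& v : p) v /= sum;
    return p;
}

// Cross-entropy between a one-hot (or probability) target t and softmax(z).
float softmaxCrossEntropy(const std::vector<float>& z, const std::vector<float>& t) {
    assert(z.size() == t.size());
    const std::vector<float> p = softmax(z);
    float loss = 0.F;
    for (std::size_t i = 0; i < p.size(); ++i)
        if (t[i] > 0.F) loss -= t[i] * std::log(p[i]);
    return loss;
}
```

For two equal outputs and the target {1, 0}, softmax gives {0.5, 0.5} and the loss is ln 2 (about 0.693).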

Switch

Operator for passing data to several nodes at once.
It can receive data from only one node.

"OperatorName":"Switch",
"OperatorParams":{
 "nextWay":"..."     Names of the next nodes, separated by spaces. Optional parameter
}     

Example:

    stringstream ss;
    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"SW\""
        "},"

        "\"Nodes\":"
        "["

        "{"
        "\"NodeName\":\"SW\","
        "\"NextNodes\":\"C1 L1\","
        "\"OperatorName\":\"Switch\""
        "},"

        // branch 1
        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"P1\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"28\"}"
        "},"

        // branch 2
        "{"
        "\"NodeName\":\"L1\","
        "\"NextNodes\":\"F2\","
        "\"OperatorName\":\"Lock\""
        "},"
         .
         .

Lock

Operator that stops further computation at this point of the network.
It lets you dynamically disconnect parallel branches of the network at run time.

"OperatorName":"Lock",
"OperatorParams":{
 "state":"lock"     Blocking activity: "lock" or "unlock". Optional parameter
}        

Example:

    stringstream ss;
    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1 L1\""
        "},"

        "\"Nodes\":"
        "["

        // branch 1: C1 -> P1 -> O1
        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"P1\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"28\"}"
        "},"

        "{"
        "\"NodeName\":\"P1\","
        "\"NextNodes\":\"O1\","
        "\"OperatorName\":\"Pooling\""
        "},"

        "{"
        "\"NodeName\":\"O1\","
        "\"PrevNode\":\"P1\","
        "\"OperatorName\":\"Output\""
        "},"

        // branch 2: L1 (Lock) -> F2 -> O2
        "{"
        "\"NodeName\":\"L1\","
        "\"NextNodes\":\"F2\","
        "\"OperatorName\":\"Lock\""
        "},"

        "{"
        "\"NodeName\":\"F2\","
        "\"NextNodes\":\"O2\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"28\"}"
        "},"

        "{"
        "\"NodeName\":\"O2\","
        "\"PrevNode\":\"F2\","
        "\"OperatorName\":\"Output\""
        "}"

        "]"
        "}";

Summator

The operator combines the values of two layers element by element.
The combination can be one of: "summ", "diff", "mean".
The input layers must have the same dimensions.

"OperatorName":"Summator",
"OperatorParams":{
 "type":"summ"     Combination type: "summ", "diff" or "mean". Optional parameter
}        
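Below is a minimal sketch of the three combination types, applied element by element to two layers of equal size. The function name is hypothetical, and the operand order for "diff" (first input minus second) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Elementwise combination of two equally sized layers, following the
// "type" options of the Summator operator.
std::vector<float> summate(const std::vector<float>& a, const std::vector<float>& b,
                           const std::string& type = "summ") {
    assert(a.size() == b.size());  // dimensions must match
    std::vector<float> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        if (type == "summ")      out[i] = a[i] + b[i];
        else if (type == "diff") out[i] = a[i] - b[i];
        else if (type == "mean") out[i] = (a[i] + b[i]) / 2.F;
        else throw std::invalid_argument("unknown Summator type: " + type);
    }
    return out;
}
```

Because the operation is elementwise, both inputs must describe the same number of values in the same layout, which is what the size requirement above enforces.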

Example:

    stringstream ss;
    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1 F1\""
        "},"

        "\"Nodes\":"
        "["

        // branch 1: C1 -> P1 -> S1
        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"P1\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"28\"}"
        "},"

        "{"
        "\"NodeName\":\"P1\","
        "\"NextNodes\":\"S1\","
        "\"OperatorName\":\"Pooling\""
        "},"

        // branch 2: F1 -> F2 -> S1
        "{"
        "\"NodeName\":\"F1\","
        "\"NextNodes\":\"F2\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"28\"}"
        "},"

        "{"
        "\"NodeName\":\"F2\","
        "\"NextNodes\":\"S1\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"28\"}"
        "},"

        // S1 combines the outputs of P1 and F2
        "{"
        "\"NodeName\":\"S1\","
        "\"NextNodes\":\"EndNet\","
        "\"OperatorName\":\"Summator\""
        "}"

        "],"

        "\"EndNet\":"
        "{"
        "\"PrevNode\":\"S1\""
        "}"
        "}";

Crop

Crops a region of interest (ROI) from every channel of every image.

"OperatorName":"Crop",
"OperatorParams":{
 "roi":"x y w h"     Region of interest, values separated by spaces. Required parameter (> 0)
}     
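Below is a minimal sketch of what cropping does to one channel, assuming x and y are the column and row of the ROI's top-left corner and w and h its width and height; the function name is hypothetical. The operator applies the same cut to every channel of every image in the batch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Copies the region of interest (x, y, w, h) out of one channel of an
// imgH x imgW image stored row-major.
std::vector<float> cropChannel(const std::vector<float>& img, int imgW, int imgH,
                               int x, int y, int w, int h) {
    assert(static_cast<int>(img.size()) == imgW * imgH);
    assert(x >= 0 && y >= 0 && w > 0 && h > 0 && x + w <= imgW && y + h <= imgH);
    std::vector<float> out;
    out.reserve(static_cast<std::size_t>(w) * h);
    for (int r = y; r < y + h; ++r)       // rows of the ROI
        for (int c = x; c < x + w; ++c)   // columns of the ROI
            out.push_back(img[r * imgW + c]);
    return out;
}
```

The result is a w x h channel, so the layer after a Crop node sees the smaller size.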

Concat

The operator concatenates several layers along the channel axis.

"OperatorName":"Concat",
"OperatorParams":{
  "sequence":"nd1 nd2.. "    Order in which the layers are joined (node names, separated by spaces)
}      
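Below is a minimal sketch of channel concatenation for NCHW data: for each image, the channels of the first layer in 'sequence' come first, followed by those of the second. Both inputs must share batch size, height and width. The function name and parameters are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Concatenates two NCHW batches along the channel axis. 'a' has chA
// channels, 'b' has chB, both have bsz images of hw = h * w pixels per
// channel. The result has chA + chB channels per image.
std::vector<float> concatChannels(const std::vector<float>& a, std::size_t chA,
                                  const std::vector<float>& b, std::size_t chB,
                                  std::size_t bsz, std::size_t hw) {
    assert(a.size() == bsz * chA * hw && b.size() == bsz * chB * hw);
    std::vector<float> out;
    out.reserve(a.size() + b.size());
    for (std::size_t n = 0; n < bsz; ++n) {
        // all channels of image n from 'a', then all channels of image n from 'b'
        out.insert(out.end(), a.begin() + n * chA * hw, a.begin() + (n + 1) * chA * hw);
        out.insert(out.end(), b.begin() + n * chB * hw, b.begin() + (n + 1) * chB * hw);
    }
    return out;
}
```

Because NCHW keeps each image contiguous, the copy works image by image rather than appending one whole buffer after the other.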

Resize

Changes the number of channels.
Works in conjunction with "Concat".

"OperatorName":"Resize",
"OperatorParams":{
 "fwdDiap":"0 0",     The range of layers to skip. Required parameter (> 0)
 "bwdDiap":"0 0"
}     

Example:

    stringstream ss;
    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1 C2\""
        "},"

        "\"Nodes\":"
        "["

        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"R1\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"40\"}"
        "},"

        "{"
        "\"NodeName\":\"R1\","
        "\"NextNodes\":\"CT\","
        "\"OperatorName\":\"Resize\","
        "\"OperatorParams\":{\"fwdDiap\":\"0 40\","
                            "\"bwdDiap\":\"0 40\"}"
        "},"

        "{"
        "\"NodeName\":\"C2\","
        "\"NextNodes\":\"R2\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"40\"}"
        "},"

        "{"
        "\"NodeName\":\"R2\","
        "\"NextNodes\":\"CT\","
        "\"OperatorName\":\"Resize\","
        "\"OperatorParams\":{\"fwdDiap\":\"0 40\","
                            "\"bwdDiap\":\"40 80\"}"
        "},"

        "{"
        "\"NodeName\":\"CT\","
        "\"NextNodes\":\"EndNet\","
        "\"OperatorName\":\"Concat\","
        "\"OperatorParams\":{\"sequence\":\"R1 R2\"}"
        "}"

        "],"

        "\"EndNet\":"
        "{"
        "\"PrevNode\":\"CT\""
        "}"
        "}";

Activation

Activation function operator.

"OperatorName":"Activation",
"OperatorParams":{
   "active":"relu"        Activation function type. Optional parameter: relu, sigmoid, leakyRelu, elu, none
}     
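Below are scalar versions of the activation types listed above. The leakyRelu slope (0.01) and the elu alpha (1.0) are common defaults chosen for the sketch, not values taken from the library; 'activate' is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <string>

// Applies the named activation to one value.
float activate(float x, const std::string& type) {
    if (type == "relu")      return x > 0.F ? x : 0.F;
    if (type == "sigmoid")   return 1.F / (1.F + std::exp(-x));
    if (type == "leakyRelu") return x > 0.F ? x : 0.01F * x;             // assumed slope
    if (type == "elu")       return x > 0.F ? x : 1.F * (std::exp(x) - 1.F);  // assumed alpha
    assert(type == "none" && "unknown activation type");
    return x;  // "none": identity
}
```

"none" is what the first example on this page uses for the last FullyConnected layer before the loss, so the loss operator receives the raw outputs.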

BatchNorm

"OperatorName":"BatchNorm",
"OperatorParams":{}     

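The operator has no parameters. The sketch shows the standard training-mode batch normalization of one feature across a batch, y = gamma * (x - mean) / sqrt(var + eps) + beta; that the library uses exactly this form, and the value eps = 1e-5, are assumptions, and 'batchNorm' here is a hypothetical function name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Normalizes one feature over the batch, then scales by gamma and
// shifts by beta (the learned coefficients).
std::vector<float> batchNorm(const std::vector<float>& x, float gamma = 1.F,
                             float beta = 0.F, float eps = 1e-5F) {
    assert(!x.empty());
    double mean = 0.0, var = 0.0;
    for (float v : x) mean += v;
    mean /= x.size();
    for (float v : x) var += (v - mean) * (v - mean);
    var /= x.size();  // biased (population) variance over the batch
    const double inv = 1.0 / std::sqrt(var + eps);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = static_cast<float>(gamma * (x[i] - mean) * inv + beta);
    return y;
}
```

With the default gamma = 1 and beta = 0 the output has mean 0 and variance close to 1.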
UserLayer

Custom layer.
The callback is provided by the user with the 'snAddUserCallBack' function.

"OperatorName":"UserLayer",
"OperatorParams":{
 "cbackName":" "     name callBack. Required parameter
}        

MNIST

#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <filesystem>
#include <iostream>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "../sunnet/sunnet.h"

#include "Lib/OpenCV_3.3.0/opencv2/core/core_c.h"
#include "Lib/OpenCV_3.3.0/opencv2/core/core.hpp"
#include "Lib/OpenCV_3.3.0/opencv2/imgproc/imgproc_c.h"
#include "Lib/OpenCV_3.3.0/opencv2/imgproc/imgproc.hpp"
#include "Lib/OpenCV_3.3.0/opencv2/highgui/highgui_c.h"
#include "Lib/OpenCV_3.3.0/opencv2/highgui/highgui.hpp"

using namespace std;

bool loadImage(string& imgPath, int classCnt, vector<vector<string>>& imgName, vector<int>& imgCntDir, map<string, cv::Mat>& images){

    namespace fs = std::filesystem;

    for (int i = 0; i < classCnt; ++i){

        // each class is stored in its own subdirectory: imgPath/0/, imgPath/1/, ...
        const fs::path dir = imgPath + to_string(i) + "/";
        if (!fs::exists(dir)) continue;

        for (const auto& entry : fs::directory_iterator(dir)){

            const fs::path& p = entry.path();
            if (fs::is_regular_file(p) && (p.extension() == ".png"))
                imgName[i].push_back(p.filename().string());
        }

        // count only the .png files actually collected, so that a random
        // index below imgCntDir[i] is always valid in imgName[i]
        imgCntDir[i] = static_cast<int>(imgName[i].size());
    }

    return true;
}

int main(int argc, char* argv[])
{
    namespace sn = SN_API;

    stringstream ss;

    ss << "{"

        "\"BeginNet\":"
        "{"
        "\"NextNodes\":\"C1\""
        "},"

        "\"Nodes\":"
        "["

        "{"
        "\"NodeName\":\"C1\","
        "\"NextNodes\":\"C2\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"15\"}"
        "},"

        "{"
        "\"NodeName\":\"C2\","
        "\"NextNodes\":\"P1\","
        "\"OperatorName\":\"Convolution\","
        "\"OperatorParams\":{\"filters\":\"15\"}"
        "},"

        "{"
        "\"NodeName\":\"P1\","
        "\"NextNodes\":\"FC1\","
        "\"OperatorName\":\"Pooling\""
        "},"

        "{"
        "\"NodeName\":\"FC1\","
        "\"NextNodes\":\"FC2\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"128\"}"
        "},"

        "{"
        "\"NodeName\":\"FC2\","
        "\"NextNodes\":\"LS\","
        "\"OperatorName\":\"FullyConnected\","
        "\"OperatorParams\":{\"units\":\"10\"}"
        "},"

        "{"
        "\"NodeName\":\"LS\","
        "\"NextNodes\":\"EndNet\","
        "\"OperatorName\":\"LossFunction\","
        "\"OperatorParams\":{\"loss\":\"softMaxToCrossEntropy\"}"
        "}"

        "],"
        "\"EndNet\":"
        "{"
        "\"PrevNode\":\"LS\""
        "}"
        "}";

    char err[256]{'\0'};
    auto snet = sn::snCreateNet(ss.str().c_str(), err);
    if (!snet){
        cout << "Error 'snCreateNet' " << err << endl;
        system("pause");
        return -1;
    }

    string imgPath = "c:\\cpp\\sunnet\\example\\mnist\\images\\";
  
    int batchSz = 100, classCnt = 10, w = 28, h = 28; float lr = 0.001F;
    vector<vector<string>> imgName(classCnt);
    vector<int> imgCntDir(classCnt);
    map<string, cv::Mat> images;       
   
    if (!loadImage(imgPath, classCnt, imgName, imgCntDir, images)){
        cout << "Error 'loadImage' imgPath: " << imgPath << endl;
        system("pause");
        return -1;
    }
             
    sn::snFloat* inLayer = new sn::snFloat[w * h * batchSz];
    sn::snFloat* targetLayer = new sn::snFloat[classCnt * batchSz];
    sn::snFloat* outLayer = new sn::snFloat[classCnt * batchSz];

    float accuratSumm = 0;
    srand(static_cast<unsigned>(time(nullptr)));  // seed the generator once, not on every iteration
    for (int k = 0; k < 1000; ++k){

        fill_n(targetLayer, classCnt * batchSz, 0.F);
       
        for (int i = 0; i < batchSz; ++i){

            // directory
            int ndir = rand() % classCnt;
            while (imgCntDir[ndir] == 0) ndir = rand() % classCnt;

            // image
            int nimg = rand() % imgCntDir[ndir];

            // read image
            cv::Mat img; string nm = imgName[ndir][nimg];
            if (images.find(nm) != images.end())
                img = images[nm];
            else{
                img = cv::imread(imgPath + to_string(ndir) + "/" + nm, cv::IMREAD_GRAYSCALE);
                images[nm] = img;
            }

            float* refData = inLayer + i * w * h;
            double mean = cv::mean(img)[0];
            size_t nr = img.rows, nc = img.cols;
            for (size_t r = 0; r < nr; ++r){
                uchar* pt = img.ptr<uchar>(r);
                for (size_t c = 0; c < nc; ++c)
                    refData[r * nc + c] = pt[c] - mean;
            } 

            float* tarData = targetLayer + i * classCnt;

            tarData[ndir] = 1;
        }

        // training
        float accurat = 0;
        sn::snTraining(snet,
                       lr,
                       sn::snLSize(w, h, 1, batchSz),
                       inLayer,
                       sn::snLSize(classCnt, 1, 1, batchSz),
                       outLayer,
                       targetLayer,                       
                       &accurat);
          
        // calc error
        size_t accCnt = 0;
        for (int i = 0; i < batchSz; ++i){

            float* refTarget = targetLayer + i * classCnt;
            float* refOutput = outLayer + i * classCnt;

            int maxOutInx = distance(refOutput, max_element(refOutput, refOutput + classCnt)),
                maxTargInx = distance(refTarget, max_element(refTarget, refTarget + classCnt));

            if (maxTargInx == maxOutInx)
                ++accCnt;                       
        }

        accuratSumm += (accCnt * 1.F) / batchSz;

        cout << k << " accurate " << accuratSumm / (k + 1) << endl;
    }

    delete[] inLayer;
    delete[] targetLayer;
    delete[] outLayer;

    system("pause");
    return 0;
}