-
Notifications
You must be signed in to change notification settings - Fork 2
/
conv_layer.c
89 lines (63 loc) · 3.02 KB
/
conv_layer.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include "conv_layer.h"
#include <stdio.h>
/**
 * Allocate a convolution layer and initialise its filter weights.
 *
 * Output size along each axis is ((in - f_size + 2 * padd) / stride) + 1,
 * using integer division (a stride that does not divide evenly truncates).
 *
 * @param in_c    number of input channels (> 0)
 * @param in_w    input width
 * @param in_h    input height
 * @param out_c   number of filters / output channels (> 0)
 * @param f_size  side length of the square filter window (> 0)
 * @param stride  step between filter applications (> 0)
 * @param padd    padding counted on both sides of each axis
 * @return the new layer; release it with conv_free().
 */
conv_layer* conv_alloc(int in_c, int in_w, int in_h, int out_c, int f_size, int stride, int padd) {
    /* stride is used as a divisor below; zero would be undefined behaviour. */
    assert(stride > 0);
    assert(in_c > 0 && out_c > 0 && f_size > 0);
    /* The padded input must hold at least one filter window, otherwise the
       computed output dimensions become zero or negative. */
    assert(in_w + 2 * padd >= f_size);
    assert(in_h + 2 * padd >= f_size);

    /* NOTE(review): the result of aalloc() is used unchecked, as in the
       original; presumably aalloc() never returns NULL -- confirm. */
    conv_layer *layer = aalloc(sizeof(*layer));

    layer->in_c = in_c;
    layer->in_w = in_w;
    layer->in_h = in_h;
    layer->out_c = out_c;
    layer->f_size = f_size;
    layer->stride = stride;
    layer->padd = padd;

    layer->col_w = ((in_w - f_size + (2 * padd)) / stride) + 1;
    layer->col_h = ((in_h - f_size + (2 * padd)) / stride) + 1;

    layer->in_dim = in_w * in_h * in_c;
    layer->out_dim = layer->col_w * layer->col_h * out_c;

    /* One column of the unfolded input holds a single receptive field. */
    layer->col_c = in_c * f_size * f_size;
    layer->col_dim = layer->col_c * layer->col_w * layer->col_h;

    /* One row per filter, one column per receptive-field element. */
    layer->filters = matrix_alloc(out_c, layer->col_c);
    /* Filled in by conv_forward(). */
    layer->input_col = NULL;

    /* He-style scale sqrt(2 / fan_in). For a convolution the fan-in of one
       output unit is the receptive-field size (col_c), not the size of the
       whole input volume.
       NOTE(review): assumes randomize(m, mean, stddev) -- confirm. */
    randomize(layer->filters, 0.0f, sqrtf(2.0f / (float)layer->col_c));
    return layer;
}
/**
 * Release a layer obtained from conv_alloc() together with the matrices it
 * holds. Passing NULL is a no-op, mirroring free().
 *
 * @param layer layer to destroy, or NULL
 */
void conv_free(conv_layer *layer) {
    if (layer == NULL) {
        return;
    }
    matrix_free(layer->filters);
    /* input_col stays NULL until the first conv_forward(); conv_forward()
       itself passes that NULL to matrix_free(), so the same is done here. */
    matrix_free(layer->input_col);
    free(layer);
}
/**
 * Forward pass over a batch.
 *
 * Each row of raw_input is read as one sample of in_dim floats. The sample is
 * unfolded by iam2cool() into the layer's column buffer and then combined
 * with the filter matrix through gemm() to fill the matching output row.
 *
 * The column buffer from the previous call is freed and replaced; the new one
 * stays in layer->input_col, where conv_backward() reads it.
 *
 * @param layer     layer to evaluate
 * @param raw_input batch with one sample per row
 * @return newly allocated matrix of raw_input->rows x layer->out_dim
 */
matrix* conv_forward(conv_layer *layer, matrix *raw_input) {
    /* Discard whatever the previous forward pass left behind. */
    matrix_free(layer->input_col);

    const int batch = raw_input->rows;
    matrix *result = matrix_alloc(batch, layer->out_dim);
    layer->input_col = matrix_alloc(batch, layer->col_dim);

    const int positions = layer->col_w * layer->col_h;

    /* Cursors that advance by one sample per iteration. */
    float *sample = raw_input->data;
    float *columns = layer->input_col->data;
    float *target = result->data;

    for (int n = 0; n < batch; ++n) {
        iam2cool(sample, layer->in_c, layer->in_w, layer->in_h, layer->f_size, layer->stride,
                 layer->padd, layer->col_w, layer->col_h, layer->col_c, columns);

        /* filters (rows x columns) combined with the unfolded sample; the
           0.0f presumably tells gemm() to overwrite the output -- confirm. */
        gemm(false, false, layer->filters->rows, positions, layer->filters->columns,
             layer->filters->data, columns, 0.0f, target);

        sample += layer->in_dim;
        columns += layer->col_dim;
        target += layer->out_dim;
    }

    return result;
}
/**
 * Backward pass over a batch: updates the filters in place and returns the
 * gradient with respect to the layer input.
 *
 * Uses the column buffer stored by the preceding conv_forward() call, so
 * conv_forward() must have been run on the same batch first.
 *
 * @param layer  layer to update
 * @param dout   gradient with respect to the layer output, one sample per
 *               row; must have as many rows as the cached forward batch
 * @param l_rate learning rate; the filter step is scaled by
 *               -l_rate / batch_size
 * @return newly allocated matrix of batch_size x layer->in_dim
 */
matrix* conv_backward(conv_layer *layer, matrix *dout, float l_rate) {
    /* Without a prior forward pass there are no cached columns to read. */
    assert(layer->input_col != NULL);
    assert(dout->rows == layer->input_col->rows);

    int batch_size = layer->input_col->rows;
    int spatial = layer->col_w * layer->col_h;

    /* Scratch buffer for one sample's column gradient, reused every pass. */
    float *dcol = aalloc(sizeof(*dcol) * (size_t)layer->col_dim);

    /* NOTE(review): the 1.0f passed to the first gemm() below appears to
       accumulate into dfilters, which presumes matrix_alloc() returns
       zero-filled storage -- confirm. */
    matrix *dfilters = matrix_alloc(layer->filters->rows, layer->filters->columns);
    /* NOTE(review): if cool2ami() adds into its output instead of
       overwriting it, dinput relies on the same zero-fill -- confirm. */
    matrix *dinput = matrix_alloc(batch_size, layer->in_dim);

    for (int i = 0; i < batch_size; i++) {
        float *dout_row = dout->data + i * layer->out_dim;
        float *col_row = layer->input_col->data + i * layer->col_dim;
        float *din_row = dinput->data + i * layer->in_dim;

        /* Filter gradient: this sample's contribution is added to dfilters. */
        gemm(false, true, layer->out_c, layer->col_c, spatial, dout_row, col_row, 1.0f, dfilters->data);

        /* Column gradient for this sample, written into the scratch buffer. */
        gemm(true, false, layer->filters->columns, spatial, layer->filters->rows,
             layer->filters->data, dout_row, 0.0f, dcol);

        /* Map the column gradient back onto the input layout. */
        cool2ami(dcol, layer->in_c, layer->in_w, layer->in_h, layer->f_size, layer->stride,
                 layer->padd, layer->col_w, layer->col_h, layer->col_c, din_row);
    }

    /* An empty batch would make the scale -l_rate / 0 (infinite) and turn the
       filters into NaN/inf; there is nothing to apply in that case. */
    if (batch_size > 0) {
        apply_sum(layer->filters, dfilters, -l_rate / (float)batch_size);
    }

    free(dcol);
    matrix_free(dfilters);
    return dinput;
}