-
Notifications
You must be signed in to change notification settings - Fork 0
/
cuda.cuh
43 lines (24 loc) · 889 Bytes
/
cuda.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/*
 * Guard the type definition: Matrix is a typedef of an anonymous struct, so a
 * second definition in the same translation unit would be a distinct,
 * conflicting type. Function prototypes may be repeated legally, so guarding
 * the definition is enough to make re-including this header safe.
 */
#ifndef CUDA_CUH_MATRIX_DEFINED
#define CUDA_CUH_MATRIX_DEFINED

/**
 * Dense matrix of doubles.
 *
 * rows - number of rows.
 * cols - number of columns.
 * data - element buffer. NOTE(review): presumably rows * cols contiguous
 *        elements, likely row-major. Whether it points to host or device
 *        memory is not visible here. Confirm against the implementation.
 */
typedef struct {
    int rows;
    int cols;
    double* data;
} Matrix;

#endif /* CUDA_CUH_MATRIX_DEFINED */
/*
 * Matrix API. Functions returning Matrix* presumably hand a newly allocated
 * matrix to the caller, to be released with matrix_free. NOTE(review):
 * confirm ownership and whether any result aliases an argument.
 */

/* --- Construction and destruction --- */
Matrix *matrix_create(int rows, int cols);
/* Presumably allocates a matrix with the same rows/cols as m. Confirm. */
Matrix *matrix_from_shape(Matrix *m);
/* NOTE(review): unclear whether `data` is copied or adopted. Confirm. */
Matrix *matrix_from_data(int rows, int cols, double *data);
Matrix *matrix_rand(int rows, int cols);
/* Presumably len rows of one-hot encoded labels taken from Y. Confirm width. */
Matrix *matrix_one_hot(const double *Y, int len);
void matrix_free(Matrix *m);

/* --- Structural operations --- */
Matrix *matrix_dot(Matrix *a, Matrix *b);
Matrix *matrix_transpose(Matrix *m);
/* NOTE(review): whether the [start, end) column range is half-open is not shown. */
Matrix *matrix_cols(Matrix *m, int start, int end);

/* --- Element-wise matrix-matrix operations --- */
Matrix *matrix_add(Matrix *a, Matrix *b);
Matrix *matrix_sub(Matrix *a, Matrix *b);
Matrix *matrix_mul(Matrix *a, Matrix *b);

/* --- Matrix-scalar operations --- */
Matrix *matrix_divf(Matrix *a, double f);
Matrix *matrix_subf(Matrix *m, double f);
Matrix *matrix_mulf(Matrix *m, double f);

/* --- Reductions and activations --- */
double matrix_sum(Matrix *m);
Matrix *matrix_relu(Matrix *m);
Matrix *matrix_softmax(Matrix *m);

/* --- Diagnostics --- */
/* NOTE(review): y and x presumably cap the printed rows/cols. Confirm. */
void matrix_print(char *label, Matrix *m, int y, int x);