
Updated C wrapper wrt. Torch v1.10 #61

Merged · 12 commits · Aug 8, 2024

Conversation

@stemann (Collaborator) commented Aug 11, 2023

Updates the C wrapper based on ocaml-torch @ 0.14 - matching Torch v1.10 (current JLL-build)

Contributes to #54 - follow-up for #56

Notable included changes:

  • Updated C wrapper generator based on ocaml-torch @ 0.14 - matching Torch v1.10.
  • Updated C wrapper (generated part)
  • Updated C wrapper (manual part), torch_api.{cpp, h}
    • Re-implemented change of return type of functions to status code etc.
  • Renamed C wrapper from doeye_caml to torch_c_api
  • Restored Julia-specific additions
  • Conditional CUDA build - C wrapper can be built without CUDA.
  • Added GitHub Actions workflow for building the C wrapper (replacing earlier Buildkite steps)
  • Dev. container: Added support for CUDA and CUDNN

The last two changes could be moved to a separate PR (to reduce the number of changes in this PR).

To-do:

  • Clean up the automatically generated part (remove comments, fix indentation)
  • Adapt the manual part, torch_api.{cpp, h}
  • The commit "Changed torch_api.cpp to reduce diff" should be removed before merging: it exists only to make the review diff smaller (a bunch of indentation changes etc.).

@zsz00 commented Oct 21, 2023

What's the status now?

@stemann (Collaborator, Author) commented Oct 22, 2023

It's been idle for a while, but the status is summarised by these comments:

@stemann (Collaborator, Author) commented Nov 11, 2023

See #54 for status.

@stemann force-pushed the stemann/ocaml_torch_1.10 branch 19 times, most recently from 6f18889 to beb24f2 on November 17, 2023 10:38
A. Non-void, non-* methods.

Search/replace:

1. torch_api.h: `^(?!void)(\w+) (at.+)\(` -> `int $2($1 *,`
2. torch_api.h: `, \);$` -> `);`
3. torch_api.cpp: `^(?!void)(\w+) (at.+)\(` -> `int $2($1 *out__,`
4. torch_api.cpp: `, \) \{$` -> `) {`
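As a sketch of step A.1, the transformation can be reproduced with Python's `re` module (Python writes backreferences as `\1`/`\2` in the replacement where the editor patterns above use `$1`/`$2`). The declaration below is illustrative, not a verbatim line from torch_api.h:

```python
import re

# Hypothetical torch_api.h declaration (illustrative).
decl = "tensor at_float_vec(double *values, int value_len, int type);"

# Step A.1: the non-void return type moves into a leading out-parameter,
# and the function now returns an int status code.
pattern = r"^(?!void)(\w+) (at.+)\("
result = re.sub(pattern, r"int \2(\1 *,", decl)
print(result)
# -> int at_float_vec(tensor *,double *values, int value_len, int type);

# The negative lookahead leaves void functions untouched (handled in step B):
void_decl = "void at_print(tensor);"
print(re.sub(pattern, r"int \2(\1 *,", void_decl))  # unchanged
```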

B. void-methods.

Search/replace:

1. torch_api.{h,cpp}: `^void (at.+)\(` -> `int $1(`
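The same step B.1 transformation, sketched in Python on a hypothetical void declaration (the editor's `$1` corresponds to Python's `\1`):

```python
import re

# Hypothetical void declaration (illustrative, not the actual header line).
decl = "void at_save(tensor, char *filename);"

# Step B.1: void functions simply switch to returning an int status code;
# no out-parameter is needed.
result = re.sub(r"^void (at.+)\(", r"int \1(", decl)
print(result)  # -> int at_save(tensor, char *filename);
```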

C. *-methods.

Search/replace:

1. torch_api.h: `^(\w+ \*)(at.+)\(` -> `int $2($1*,`
2. torch_api.cpp: `^(\w+ \*)(at.+)\(` -> `int $2($1*out__,`
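Step C.1 sketched in Python: a pointer return type becomes a pointer-to-pointer out-parameter. The example declaration mirrors the `at_to_string` entry in the recap below:

```python
import re

# Hypothetical pointer-returning declaration (illustrative).
decl = "char *at_to_string(tensor, int line_size);"

# Step C.1: the pointer return type becomes a pointer-to-pointer
# out-parameter, and the function returns an int status code.
result = re.sub(r"^(\w+ \*)(at.+)\(", r"int \2(\1*,", decl)
print(result)
# -> int at_to_string(char **,tensor, int line_size);
```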

D. Implemented return status code

Replaced
```
^(\s*)  return new (.+)
\s*\)
\s*return nullptr;
```
with
```
$1  out__[0] = new $2
$1  return 0;
$1)
$1return 1;
```
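The multi-line replacement in step D can be sketched in Python with `re.MULTILINE`. The function body below is a hypothetical snippet in the style of the generated torch_api.cpp (the `PROTECT` macro wraps a try/catch); `t` and the `torch::Tensor` expression are illustrative:

```python
import re

# Hypothetical generated body (illustrative, not actual torch_api.cpp code).
body = (
    "  PROTECT(\n"
    "    return new torch::Tensor(t->sum());\n"
    "  )\n"
    "  return nullptr;\n"
)

# Step D: `return new ...` becomes an assignment to the out-parameter plus
# status codes (0 on success inside PROTECT, 1 on the error path).
pattern = r"^(\s*)  return new (.+)\s*\)\s*return nullptr;"
repl = r"\1  out__[0] = new \2\n\1  return 0;\n\1)\n\1return 1;"
result = re.sub(pattern, repl, body, flags=re.MULTILINE)
print(result)
```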

E. Implemented return status code

1. Replaced
```
^(\s*)PROTECT\(return new (.+)\)
\s*return nullptr;
```
with
```
$1PROTECT(
$1  out__[0] = new $2
$1  return 0;
$1)
$1return 1;
```

F. Implemented return status code

1. Replaced
```
^(\s*)PROTECT\(return (.+)\)
```
with
```
$1PROTECT(
$1  out__[0] = $2
$1  return 0;
$1)
```
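Step F handles the one-line `PROTECT(return ...)` form; a Python sketch on a hypothetical body (`t->dim()` is illustrative):

```python
import re

# Hypothetical one-line PROTECT body (illustrative).
body = "  PROTECT(return t->dim();)\n"

# Step F: expand the one-liner into a PROTECT block that assigns the
# out-parameter and returns a 0 status code on success.
pattern = r"^(\s*)PROTECT\(return (.+)\)"
repl = r"\1PROTECT(\n\1  out__[0] = \2\n\1  return 0;\n\1)"
result = re.sub(pattern, repl, body, flags=re.MULTILINE)
print(result)
```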

G. Implemented return status code

1. Replaced
```
^(\s*)  return (.+)
\s*\)
\s*return nullptr;
```
with
```
$1  out__[0] = $2
$1  return 0;
$1)
$1return 1;
```

H. Restored error handling

Handled caml_failwith by search/replace: Replaced:
```
$
^(\s*)  caml_failwith\((.+)
```
with:
```
 {
$1  myerr = strdup($2
$1  return 1;
$1}
```
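The step H replacement spans a line boundary: the leading `$` matches the end of the preceding line so that ` {` can be appended there, and the `caml_failwith` call (the OCaml-era error mechanism) becomes an assignment to the wrapper's error slot `myerr` plus a 1 status code. A Python sketch on a hypothetical guard:

```python
import re

# Hypothetical guard using caml_failwith (illustrative).
body = (
    "  if (indexes_len != t->dim())\n"
    '    caml_failwith("invalid number of indexes");\n'
)

pattern = r"$\n^(\s*)  caml_failwith\((.+)"
repl = r" {\n\1  myerr = strdup(\2\n\1  return 1;\n\1}"
result = re.sub(pattern, repl, body, flags=re.MULTILINE)
print(result)
```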

I. Replaced
```
^(\s+)PROTECT\(
    return (.+)
  \)
```
with
```
$1PROTECT(
$1  out__[0] = $2
$1  return 0;
$1)
```

J. Manually implemented return status code

K. Changed return code from -1 to 1 to reduce diff

L. Fixed a couple of warnings
@stemann force-pushed the stemann/ocaml_torch_1.10 branch 2 times, most recently from 090f09e to aa302ad on April 29, 2024 10:28
Also, made CUDA build optional.

int at_empty_cache();
int at_no_grad(int flag);
int at_sync();
int at_from_blob(tensor *, void *data, int64_t *dims, int ndims, int64_t *strides, int nstrides, int dev);
Also:
* Dev. container: Updated for Torch 1.10.2
* Added /build to .gitignore
@stemann force-pushed the stemann/ocaml_torch_1.10 branch 12 times, most recently from 7f5dbfa to aa10911 on April 29, 2024 12:13
@DhairyaLGandhi (Member)
Perhaps it would make sense to start merging some of the excellent changes here?

@stemann (Collaborator, Author) commented Apr 29, 2024

Yes! :-)

Please give torch_api.{cpp,h} a thorough review - I have made some changes in an effort to make things a bit more consistent - e.g. wrt. always returning an int status code.

Edit: I'll go over it as well and try to make a recap of the changes.

@stemann stemann marked this pull request as ready for review April 29, 2024 12:42
@stemann (Collaborator, Author) commented Apr 29, 2024

This is the current main diff which covers the hand-written part (torch_api.{cpp,h}): https://github.com/FluxML/Torch.jl/compare/7828132d..ece96546

@stemann (Collaborator, Author) commented Apr 29, 2024

The overall aim was to update for Torch v1.10.2 - but also to make it easier to apply a diff of changes for subsequent version updates...

@stemann (Collaborator, Author) commented Apr 29, 2024

Recap:

General

  • Removed argument names from function definitions in torch_api.h.
  • Definitions in torch_api.h were updated to match the implementation in torch_api.cpp - e.g. wrt. return value in the first argument, and return code as return value.

Modified function definitions

int at_float_vec(double *values, int value_len, int type);
int at_int_vec(int64_t *values, int value_len, int type);
int at_grad_set_enabled(int);
int at_int64_value_at_indexes(double *i, tensor, int *indexes, int indexes_len);
tensor at_load(char *filename);
int ato_adam(optimizer *, double learning_rate,
                   double beta1,
                   double beta2,
                   double weight_decay);
int atm_load(char *, module *);

Added function definitions

int at_is_sparse(int *, tensor);
int at_device(int *, tensor);
int at_stride(tensor, int *);
int at_autocast_clear_cache();
int at_autocast_decrement_nesting(int *);
int at_autocast_increment_nesting(int *);
int at_autocast_is_enabled(int *);
int at_autocast_set_enabled(int *, int b);
int at_to_string(char **, tensor, int line_size);
int at_get_num_threads(int *);
int at_set_num_threads(int n_threads);
int ati_none(ivalue *);
int ati_bool(ivalue *, int);
int ati_string(ivalue *, char *);
int ati_tuple(ivalue *, ivalue *, int);
int ati_generic_list(ivalue *, ivalue *, int);
int ati_generic_dict(ivalue *, ivalue *, int);
int ati_int_list(ivalue *, int64_t *, int);
int ati_double_list(ivalue *, double *, int);
int ati_bool_list(ivalue *, char *, int);
int ati_string_list(ivalue *, char **, int);
int ati_tensor_list(ivalue *, tensor *, int);
int ati_to_string(char **, ivalue);
int ati_to_bool(int *, ivalue);
int ati_length(int *, ivalue);
int ati_to_generic_list(ivalue, ivalue *, int);
int ati_to_generic_dict(ivalue, ivalue *, int);
int ati_to_int_list(ivalue, int64_t *, int);
int ati_to_double_list(ivalue, double *, int);
int ati_to_bool_list(ivalue, char *, int);
int ati_to_tensor_list(ivalue, tensor *, int);

@stemann (Collaborator, Author) commented Jul 9, 2024

@DhairyaLGandhi Do you know if anyone is available for reviewing these changes?

I'm at JuliaCon, FYI.

@ToucheSir merged commit d1711d7 into FluxML:master on Aug 8, 2024 (4 checks passed)