-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip * wip * wip * wip * wip * Update avg_pooling.rs Signed-off-by: Rui Campos <mail@ruicampos.org> * Update dev-env-aws.yml Signed-off-by: Rui Campos <mail@ruicampos.org> * wip * wip * wip * wip * wip * wip * wip * refactor * refactor --------- Signed-off-by: Rui Campos <mail@ruicampos.org>
- Loading branch information
1 parent
be02679
commit 6a48598
Showing
12 changed files
with
482 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
|
||
use tch::nn::Module; | ||
use tch::Tensor; | ||
|
||
|
||
|
||
/// implicit zero paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0 | ||
const DEFAULT_PADDING: i64 = 0; | ||
|
||
/// when True, will use ceil instead of floor in the formula to compute the output shape. Default: False | ||
const DEFAULT_CEIL_MODE: bool = false; | ||
|
||
/// when True, will include the zero-padding in the averaging calculation. Default: True | ||
const DEFAULT_COUNT_INCLUDE_PAD: bool = true; | ||
|
||
/// if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None | ||
const DEFAULT_DIVISOR_OVERRIDE: core::option::Option<i64> = None; | ||
|
||
|
||
/// Average pooling layer | ||
#[derive(Debug)] | ||
pub struct AvgPooling { | ||
/// Size of the pooling region. Can be a single number or a tuple (kH, kW) | ||
kernel_size: i64, | ||
} | ||
|
||
|
||
impl AvgPooling { | ||
pub fn new(kernel_size: i64) -> Self { | ||
AvgPooling { kernel_size } | ||
} | ||
} | ||
|
||
impl Module for AvgPooling { | ||
fn forward(&self, x_bcd: &Tensor) -> Tensor { | ||
x_bcd.avg_pool2d( | ||
self.kernel_size, | ||
// stride of the pooling operation. Can be a single number or a tuple (sH, sW). Default: kernel_size | ||
self.kernel_size, | ||
DEFAULT_PADDING, | ||
DEFAULT_CEIL_MODE, | ||
DEFAULT_COUNT_INCLUDE_PAD, | ||
DEFAULT_DIVISOR_OVERRIDE, | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
use tch::nn; | ||
use tch::Tensor; | ||
|
||
|
||
|
||
#[derive(Debug)] | ||
pub struct Identity { } | ||
|
||
impl Identity { | ||
pub fn new() -> Self { | ||
Identity { } | ||
} | ||
} | ||
|
||
impl nn::Module for Identity { | ||
fn forward(&self, x_bcd: &Tensor) -> Tensor { | ||
x_bcd.g_mul_scalar(1) | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
//! To be implemented with CUDA kernels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
pub mod quadratic_form; | ||
|
||
|
||
|
||
|
||
pub mod scaled_dot_product; | ||
pub mod identity; | ||
pub mod avg_pooling; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.