Finish attention mechanisms (#77)
* wip

* wip

* wip

* wip

* wip

* Update avg_pooling.rs

Signed-off-by: Rui Campos <mail@ruicampos.org>

* Update dev-env-aws.yml

Signed-off-by: Rui Campos <mail@ruicampos.org>

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* refactor

* refactor

---------

Signed-off-by: Rui Campos <mail@ruicampos.org>
RuiFilipeCampos authored Mar 11, 2024
1 parent be02679 commit 6a48598
Showing 12 changed files with 482 additions and 199 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/dev-env-aws.yml
@@ -77,6 +77,7 @@ jobs:
deploy-vscode:
name: Deploy development environment
timeout-minutes: 1500
needs:
- start-runner
runs-on: ${{ needs.start-runner.outputs.label }}
@@ -125,8 +126,7 @@ jobs:
curl -Lk 'https://code.visualstudio.com/sha/download?build=stable&os=cli-alpine-x64' --output vscode_cli.tar.gz
tar -xf vscode_cli.tar.gz
- name: Serve VSCode
continue-on-error: true
timeout-minutes: 999999999999
continue-on-error: true
run: ./code tunnel

stop-runner:
61 changes: 0 additions & 61 deletions pipelines/capacity.py

This file was deleted.

File renamed without changes.
46 changes: 46 additions & 0 deletions src/attention/avg_pooling.rs
@@ -0,0 +1,46 @@

use tch::nn::Module;
use tch::Tensor;



/// implicit zero paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0
const DEFAULT_PADDING: i64 = 0;

/// when True, will use ceil instead of floor in the formula to compute the output shape. Default: False
const DEFAULT_CEIL_MODE: bool = false;

/// when True, will include the zero-padding in the averaging calculation. Default: True
const DEFAULT_COUNT_INCLUDE_PAD: bool = true;

/// if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None
const DEFAULT_DIVISOR_OVERRIDE: core::option::Option<i64> = None;


/// Average pooling layer
#[derive(Debug)]
pub struct AvgPooling {
    /// Size of the pooling region. Can be a single number or a tuple (kH, kW)
    kernel_size: i64,
}


impl AvgPooling {
    pub fn new(kernel_size: i64) -> Self {
        AvgPooling { kernel_size }
    }
}

impl Module for AvgPooling {
    fn forward(&self, x_bcd: &Tensor) -> Tensor {
        x_bcd.avg_pool2d(
            self.kernel_size,
            // stride of the pooling operation. Can be a single number or a tuple (sH, sW). Default: kernel_size
            self.kernel_size,
            DEFAULT_PADDING,
            DEFAULT_CEIL_MODE,
            DEFAULT_COUNT_INCLUDE_PAD,
            DEFAULT_DIVISOR_OVERRIDE,
        )
    }
}
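
For orientation, here is a minimal usage sketch of the new AvgPooling module. The import path, helper function, and tensor shapes are illustrative assumptions, not part of the commit; it only relies on the constructor and Module impl shown above.

use tch::nn::Module;
use tch::{Device, Kind, Tensor};

// Assumes AvgPooling is in scope, e.g. use crate::attention::avg_pooling::AvgPooling;
fn avg_pooling_example() {
    // A hypothetical batch of 2 single-channel 8x8 feature maps.
    let x = Tensor::randn(&[2, 1, 8, 8], (Kind::Float, Device::Cpu));
    // 2x2 pooling regions; the module reuses kernel_size as the stride.
    let pool = AvgPooling::new(2);
    let y = pool.forward(&x);
    // Averaging non-overlapping 2x2 regions halves the spatial dimensions.
    assert_eq!(y.size(), vec![2, 1, 4, 4]);
}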
21 changes: 21 additions & 0 deletions src/attention/identity.rs
@@ -0,0 +1,21 @@

use tch::nn;
use tch::Tensor;



#[derive(Debug)]
pub struct Identity { }

impl Identity {
    pub fn new() -> Self {
        Identity { }
    }
}

impl nn::Module for Identity {
    fn forward(&self, x_bcd: &Tensor) -> Tensor {
        x_bcd.g_mul_scalar(1)
    }
}
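
The module above is a passthrough: multiplying by the scalar 1 returns a tensor with the same values, so Identity can stand in wherever an nn::Module attention block is expected. A minimal sketch, with the import path and helper function assumed rather than taken from the commit:

use tch::nn::Module;
use tch::{Device, Kind, Tensor};

// Assumes Identity is in scope, e.g. use crate::attention::identity::Identity;
fn identity_example() {
    let identity = Identity::new();
    let x = Tensor::randn(&[2, 8, 16], (Kind::Float, Device::Cpu));
    let y = identity.forward(&x);
    // Shape and values are unchanged; the scalar multiply only allocates a copy.
    assert_eq!(y.size(), x.size());
}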

1 change: 1 addition & 0 deletions src/attention/metric_tensor.rs
@@ -0,0 +1 @@
//! To be implemented with CUDA kernels
7 changes: 3 additions & 4 deletions src/attention/mod.rs
@@ -1,5 +1,4 @@
pub mod quadratic_form;




pub mod scaled_dot_product;
pub mod identity;
pub mod avg_pooling;
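
With this change, src/attention exposes several interchangeable mechanisms behind the nn::Module trait. The helper below is purely illustrative (none of these names exist in the commit) and sketches how a caller could swap mechanisms at construction time:

use tch::nn::Module;

// Illustrative only: pick an attention mechanism by name.
// Assumes Identity and AvgPooling are in scope from the modules declared above.
fn build_attention(kind: &str, kernel_size: i64) -> Box<dyn Module> {
    match kind {
        "identity" => Box::new(Identity::new()),
        "avg_pooling" => Box::new(AvgPooling::new(kernel_size)),
        other => panic!("unknown attention mechanism: {other}"),
    }
}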
100 changes: 63 additions & 37 deletions src/attention/quadratic_form.rs
@@ -1,71 +1,93 @@

use tch::nn;
use tch::Tensor;


pub fn generate_init() -> nn::Init {
nn::Init::Randn { mean: 0., stdev: 1. }
}

#[derive(Debug)]
pub struct QuadraticAttention {
projections_1ndq: Tensor,
metric_tensors_1nqq: Tensor,
adapter_1pd: Tensor,
sqrt_q: f64,
cp: (i64, i64),
}


/// Performs self attention N times using the quadratic form $xW_nx^T$, where $W_n$ is a learnable matrix.
/// This is an early version of the metric self attention, where $W$ is forced to have the properties of a metric tensor.
/// https://arxiv.org/abs/2111.11418 - evidence that many attention mechanisms may offer similar performance
pub fn quadratic_self_attention_module(
vs_path: &nn::Path,
n: i64,
d: i64,
q: i64,
c: i64,
) -> impl nn::Module {

assert!(d % n == 0, "Embeddings dimension must be divisible by the requested number of heads.");
debug_assert_eq!(n*q, d);

let projections_1ndq = vs_path.var("projections_1ndq", &[1, n, d, q], generate_init());
let metric_tensors_1nqq = vs_path.var("metric_tensors_1nqq", &[1, n, q, q], generate_init());
let mixer_1dd = vs_path.var("mixer_1dd", &[1, d, d], generate_init());

debug_assert_eq!(projections_1ndq.size(), vec![1, n, d, q]);
debug_assert_eq!(metric_tensors_1nqq.size(), vec![1, n, q, q]);
debug_assert_eq!(mixer_1dd.size(), vec![1, d, d]);

let sqrt_q = f64::sqrt(q as f64);
impl QuadraticAttention {
pub fn new(
vs_path: &nn::Path,
number_of_heads: i64,
embedding_dimension: i64,
latent_dimension: i64,
sequence_length: i64,
) -> Self {

let n = number_of_heads;
let d = embedding_dimension;
let c = sequence_length;
let q = latent_dimension;
let p = latent_dimension*number_of_heads;

let projections_1ndq = vs_path.var("projections_1ndq", &[1, n, d, q], generate_init());
let metric_tensors_1nqq = vs_path.var("metric_tensors_1nqq", &[1, n, q, q], generate_init());
let adapter_1pd = vs_path.var("adapter_1pd", &[1, p, d], generate_init());

let sqrt_q = f64::sqrt(q as f64);
QuadraticAttention {
projections_1ndq,
metric_tensors_1nqq,
adapter_1pd,
sqrt_q,
cp: (c, p)
}
}
}


// Implement the nn::Module trait for QuadraticAttention.
impl nn::Module for QuadraticAttention {
fn forward(&self, x_bcd: &Tensor) -> Tensor {

nn::func(move |x_bcd: &tch::Tensor| {

let b = x_bcd.size()[0];
assert_eq!(x_bcd.size(), vec![b, c, d]);

// assert_eq!(x_bcd.size(), vec![b, c, d]);

// Apply n projections to the input
let x_b1cd = &x_bcd.unsqueeze(1);
let x_bncq = &x_b1cd.matmul(&projections_1ndq);
debug_assert_eq!(x_bncq.size(), vec![b, n, c, q]);

let x_bncq = &x_b1cd.matmul(&self.projections_1ndq);
// debug_assert_eq!(x_bncq.size(), vec![b, n, c, q]);

// Use n custom dot products to generate n score tables
let x_bnqc = &x_bncq.transpose(-1, -2);
let dotproducts_bncc = &x_bncq.matmul(&metric_tensors_1nqq.matmul(x_bnqc));
debug_assert!(dotproducts_bncc.size() == vec![b, n, c, c]);
let dotproducts_bncc = &x_bncq.matmul(&self.metric_tensors_1nqq.matmul(x_bnqc));
// debug_assert!(dotproducts_bncc.size() == vec![b, n, c, c]);

// From scaled dot product attention introduced in https://arxiv.org/abs/1706.03762
let scaled_dotproducts_bncc = &dotproducts_bncc.divide_scalar(sqrt_q);
let scaled_dotproducts_bncc = &dotproducts_bncc.divide_scalar(self.sqrt_q);

let softmaxed_scaled_dotproducts_bncc = &scaled_dotproducts_bncc.softmax(-1, tch::kind::Kind::Float);
let y_bnqc = &x_bncq.transpose(-1, -2).matmul(softmaxed_scaled_dotproducts_bncc);
debug_assert!(y_bnqc.size() == vec![b, n, q, c]);
// debug_assert!(y_bnqc.size() == vec![b, n, q, c]);

let y_bcd = &y_bnqc.reshape(x_bcd.size());
debug_assert!(y_bcd.size() == vec![b, c, d]);
let y_bcp = &y_bnqc.reshape(&[b, self.cp.0, self.cp.1]);
// debug_assert!(y_bcp.size() == vec![b, c, p]);

y_bcd.matmul(&mixer_1dd)
})
y_bcp.matmul(&self.adapter_1pd)
}
}






/*
#[cfg(test)]
mod tests {
use super::*;
@@ -94,4 +116,8 @@ mod tests {
}
}
}
*/
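
A minimal usage sketch for the reworked QuadraticAttention follows. The hyperparameters, helper function, and import path are hypothetical; here p = n·q happens to equal d, though the adapter_1pd projection maps the concatenated heads back to d in any case.

use tch::nn::{Module, VarStore};
use tch::{Device, Kind, Tensor};

// Assumes QuadraticAttention is in scope, e.g. use crate::attention::quadratic_form::QuadraticAttention;
fn quadratic_attention_example() {
    let vs = VarStore::new(Device::Cpu);
    // 4 heads, 64-dim embeddings, 16-dim latents, 128-token context (illustrative values).
    let (n, d, q, c) = (4, 64, 16, 128);
    let attention = QuadraticAttention::new(&vs.root(), n, d, q, c);

    // A batch of 2 sequences: shape (batch, context, embedding) = (2, c, d).
    let x_bcd = Tensor::randn(&[2, c, d], (Kind::Float, Device::Cpu));
    let y_bcd = attention.forward(&x_bcd);

    // The head outputs are reshaped to (b, c, p) with p = n * q, then the adapter maps p back to d.
    assert_eq!(y_bcd.size(), vec![2, c, d]);
}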


