From 6fdfb0ca5d40266b1376676858a5c6a8e5a9b616 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Tue, 19 Apr 2022 01:54:06 +0200 Subject: [PATCH] feat: added SMARTLoss implementation --- .gitignore | 1 + setup.py | 29 ++++++++++++++++++ smart_pytorch/smart_pytorch.py | 56 ++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 .gitignore create mode 100644 setup.py create mode 100644 smart_pytorch/smart_pytorch.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a4833b3 --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup, find_packages + +setup( + name = 'smart-pytorch', + packages = find_packages(exclude=[]), + version = '0.0.1', + license='MIT', + description = 'SMART Fine-Tuning - Pytorch', + author = 'Flavio Schneider', + author_email = 'archinetai@protonmail.com', + url = 'https://github.com/archinetai/smart-pytorch', + keywords = [ + 'artificial intelligence', + 'deep learning', + 'fine-tuning', + 'pre-trained', + ], + install_requires=[ + 'torch>=1.6', + 'data-science-types>=0.2' + ], + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.6', + ], +) \ No newline at end of file diff --git a/smart_pytorch/smart_pytorch.py b/smart_pytorch/smart_pytorch.py new file mode 100644 index 0000000..e33ce64 --- /dev/null +++ b/smart_pytorch/smart_pytorch.py @@ -0,0 +1,56 @@ +from typing import List, Union, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +def inf_norm(x): + return torch.norm(x, p=float('inf'), dim=-1, keepdim=True) + +def to_list(x): + return x if isinstance(x, list) else [x] + +class SMARTLoss(nn.Module): + + def __init__( + self, + eval_fn: Callable, + loss_fn: Union[Callable, List[Callable]], + norm_fn: Callable = inf_norm, + num_steps: int = 1, + step_size: float = 1e-3, + epsilon: float = 1e-6, + noise_var: float = 1e-5 + ) -> None: + super().__init__() + self.eval_fn = eval_fn + self.loss_fn = to_list(loss_fn) + self.norm_fn = norm_fn + self.num_steps = num_steps + self.step_size = step_size + self.epsilon = epsilon + self.noise_var = noise_var + + def forward(self, embed: Tensor, state: Union[Tensor, List[Tensor]]) -> Tensor: + states = to_list(state) + noise = torch.randn_like(embed, requires_grad=True) * self.noise_var + + for i in range(self.num_steps + 2): + # Compute perturbed states + embed_perturbed = embed + noise + states_perturbed = to_list(self.eval_fn(embed_perturbed)) + loss = 0 + # Compute perturbation loss over all states + for j in range(len(states)): + loss += self.loss_fn[j](states_perturbed[j], states[j].detach()) + if i == self.num_steps + 1: + return loss + # Compute noise gradient + noise_gradient = torch.autograd.grad(loss, noise)[0] + # Move noise towards gradient to change state as much as possible + step = noise + self.step_size * noise_gradient + step_norm = self.norm_fn(step) + noise = step / (step_norm + self.epsilon) + # Reset noise gradients for next step + noise = noise.detach().requires_grad_() \ No newline at end of file