Skip to content

Commit

Permalink
Merge pull request #362 from girder/deepssm-ui
Browse files Browse the repository at this point in the history
DeepSSM UI
  • Loading branch information
JakeWags authored Apr 12, 2024
2 parents 94ed16e + 45ab369 commit 9664733
Show file tree
Hide file tree
Showing 23 changed files with 878 additions and 243 deletions.
67 changes: 53 additions & 14 deletions shapeworks_cloud/core/deepssm_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from pathlib import Path
from tempfile import TemporaryDirectory

import DataAugmentationUtils
import DeepSSMUtils
from celery import shared_task
from django.conf import settings
from django.contrib.auth.models import User
from rest_framework.authtoken.models import Token
import shapeworks as sw

from shapeworks_cloud.core import models
from swcc.api import swcc_session
Expand All @@ -18,6 +15,9 @@


def run_prep(params, project, project_file, progress):
import DeepSSMUtils
import shapeworks as sw

# //////////////////////////////////////////////
# /// STEP 1: Create Split
# //////////////////////////////////////////////
Expand All @@ -31,6 +31,7 @@ def run_prep(params, project, project_file, progress):
# /// STEP 2: Groom Training Shapes
# /////////////////////////////////////////////////////////////////
project_params = project.get_parameters('groom')
# alignment should always be set to ICP
project_params.set('alignment_method', 'Iterative Closest Point')
project_params.set('alignment_enabled', 'true')
project.set_parameters('groom', project_params)
Expand All @@ -41,13 +42,6 @@ def run_prep(params, project, project_file, progress):
# /////////////////////////////////////////////////////////////////
# /// STEP 3: Optimize Training Particles
# /////////////////////////////////////////////////////////////////

# set num_particles to 16 and iterations_per_split to 1
project_params = project.get_parameters('optimize')
project_params.set('number_of_particles', '16')
project_params.set('iterations_per_split', '1')
project.set_parameters('optimize', project_params)

DeepSSMUtils.optimize_training_particles(project)
project.save(project_file)
progress.update_percentage(12)
Expand Down Expand Up @@ -96,12 +90,14 @@ def run_prep(params, project, project_file, progress):


def run_augmentation(params, project, download_dir, progress):
import DataAugmentationUtils
import DeepSSMUtils

# /////////////////////////////////////////////////////////////////
# /// STEP 7: Augment Data
# /////////////////////////////////////////////////////////////////
num_samples = int(params['aug_num_samples'])
percent_variability = float(params['percent_variability']) / 100.0
# aug_sampler_type to lowecase
percent_variability = float(params['percent_variability'])
aug_sampler_type = params['aug_sampler_type'].lower()

num_dims = 0 # set to 0 to allow for percent variability to be used
Expand All @@ -127,6 +123,8 @@ def run_augmentation(params, project, download_dir, progress):


def run_training(params, project, download_dir, aug_dims, progress):
import DeepSSMUtils

batch_size = int(params['train_batch_size'])

# /////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -176,6 +174,8 @@ def run_training(params, project, download_dir, aug_dims, progress):


def run_testing(params, project, download_dir, progress):
import DeepSSMUtils

test_indices = DeepSSMUtils.get_split_indices(project, 'test')

# /////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -219,6 +219,8 @@ def run_deepssm_command(
post_command_function,
progress_id,
):
import shapeworks as sw

user = User.objects.get(id=user_id)
progress = models.TaskProgress.objects.get(id=progress_id)
token, _created = Token.objects.get_or_create(user=user)
Expand Down Expand Up @@ -256,6 +258,21 @@ def run_deepssm_command(

sw_project.load(sw_project_file)

groom_params = sw_project.get_parameters('groom')

# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
groom_params.set(key, value)

sw_project.set_parameters('groom', groom_params)

optimize_params = sw_project.get_parameters('optimize')
# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
optimize_params.set(key, value)

sw_project.set_parameters('optimize', optimize_params)

os.chdir(sw_project.get_project_path())
run_prep(form_data, sw_project, sw_project_file, progress)

Expand All @@ -265,6 +282,9 @@ def run_deepssm_command(
result_data['augmentation'] = {
'total_data_csv': download_dir + '/deepssm/augmentation/TotalData.csv',
'violin_plot': download_dir + '/deepssm/augmentation/violin.png',
'generated_meshes': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Meshes/'
),
'generated_images': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Images/'
),
Expand Down Expand Up @@ -293,6 +313,8 @@ def run_deepssm_command(

run_testing(form_data, sw_project, download_dir, progress)

subjects = sw_project.get_subjects()

result_data['testing'] = {
'world_predictions': os.listdir(
download_dir + '/deepssm/model/test_predictions/world_predictions/'
Expand All @@ -301,6 +323,7 @@ def run_deepssm_command(
download_dir + '/deepssm/model/test_predictions/local_predictions/'
),
'test_distances': download_dir + '/deepssm/test_distances.csv',
'test_split_subjects': subjects,
}

os.chdir('../../')
Expand Down Expand Up @@ -352,6 +375,15 @@ def post_command_function(project, download_dir, result_data, project_filename):
),
)
aug_pair.mesh.save(
result_data['augmentation']['generated_meshes'][i],
open(
download_dir
+ '/deepssm/augmentation/Generated-Meshes/'
+ result_data['augmentation']['generated_meshes'][i],
'rb',
),
)
aug_pair.image.save(
result_data['augmentation']['generated_images'][i],
open(
download_dir
Expand Down Expand Up @@ -402,10 +434,15 @@ def post_command_function(project, download_dir, result_data, project_filename):
file1 = predictions.pop()
filename = file1.split('.')[0]

# filename here represents the SUBJECT INDEX OF THE TEST SPLIT
subject_name = result_data['testing']['test_split_subjects'][
int(filename)
].get_display_name()

test_pair = models.DeepSSMTestingData.objects.create(
project=project,
image_type='world' if predictions == world_predictions else 'local',
image_id=filename,
image_id=subject_name,
)

predictions_path = (
Expand Down Expand Up @@ -457,9 +494,11 @@ def post_command_function(project, download_dir, result_data, project_filename):
for images in [train_images, val_and_test_images]:
for image in images:
image_type = 'train' if images == train_images else 'val_and_test'

train_image = models.DeepSSMTrainingImage.objects.create(
project=project,
validation=True if image_type == 'val_and_test' else False,
index=image.split('.')[0],
)
train_image.image.save(
image,
Expand Down Expand Up @@ -514,7 +553,7 @@ def post_command_function(project, download_dir, result_data, project_filename):
),
)

training_pair.vtk.save(
training_pair.mesh.save(
vtk_file,
open(
download_dir + '/deepssm/model/examples/' + vtk_file,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.2.25 on 2024-04-08 18:19
# Generated by Django 3.2.25 on 2024-04-11 16:23

from django.db import migrations, models
import django.db.models.deletion
Expand All @@ -23,7 +23,7 @@ class Migration(migrations.Migration):
),
('particles', s3_file_field.fields.S3FileField()),
('scalar', s3_file_field.fields.S3FileField()),
('vtk', s3_file_field.fields.S3FileField()),
('mesh', s3_file_field.fields.S3FileField()),
('index', models.CharField(max_length=255)),
('example_type', models.CharField(max_length=255)),
('validation', models.BooleanField(default=False)),
Expand All @@ -47,6 +47,7 @@ class Migration(migrations.Migration):
),
),
('image', s3_file_field.fields.S3FileField()),
('index', models.CharField(max_length=255)),
('validation', models.BooleanField(default=False)),
(
'project',
Expand All @@ -68,7 +69,7 @@ class Migration(migrations.Migration):
),
),
('image_type', models.CharField(max_length=255)),
('image_id', models.IntegerField()),
('image_id', models.CharField(max_length=255)),
('mesh', s3_file_field.fields.S3FileField()),
('particles', s3_file_field.fields.S3FileField()),
(
Expand Down Expand Up @@ -116,6 +117,7 @@ class Migration(migrations.Migration):
),
),
('sample_num', models.IntegerField()),
('image', s3_file_field.fields.S3FileField()),
('mesh', s3_file_field.fields.S3FileField()),
('particles', s3_file_field.fields.S3FileField()),
(
Expand Down
8 changes: 5 additions & 3 deletions shapeworks_cloud/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class DeepSSMTestingData(models.Model):
Project, on_delete=models.CASCADE, related_name='deepssm_testing_data'
)
image_type = models.CharField(max_length=255)
image_id = models.IntegerField()
image_id = models.CharField(max_length=255)
mesh = S3FileField()
particles = S3FileField()

Expand All @@ -212,8 +212,8 @@ class DeepSSMTrainingPair(models.Model):
)
particles = S3FileField() # .particles
scalar = S3FileField() # .scalar
vtk = S3FileField() # .vtk
index = models.CharField(max_length=255) # subject
mesh = S3FileField() # .vtk
index = models.CharField(max_length=255) # index
example_type = models.CharField(max_length=255) # best, median, worst
validation = models.BooleanField(default=False)

Expand All @@ -223,12 +223,14 @@ class DeepSSMTrainingImage(models.Model):
Project, on_delete=models.CASCADE, related_name='deepssm_training_images'
)
image = S3FileField()
index = models.CharField(max_length=255) # index
validation = models.BooleanField(default=False)


class DeepSSMAugPair(models.Model):
project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='deepssm_aug_pair')
sample_num = models.IntegerField()
image = S3FileField()
mesh = S3FileField()
particles = S3FileField()

Expand Down
8 changes: 3 additions & 5 deletions shapeworks_cloud/core/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from rest_framework.viewsets import GenericViewSet

from . import filters, models, serializers
from .deepssm_tasks import deepssm_run
from .tasks import analyze, groom, optimize

DB_WRITE_ACCESS_LOG_FILE = Path(gettempdir(), 'logging', 'db_write_access.log')
Expand Down Expand Up @@ -50,8 +51,8 @@ def save_thumbnail_image(target, encoded_thumbnail):


class Pagination(PageNumberPagination):
page_size = 25
max_page_size = 100
page_size = 100
max_page_size = 200
page_size_query_param = 'page_size'


Expand Down Expand Up @@ -517,9 +518,6 @@ def analyze(self, request, **kwargs):
methods=['POST'],
)
def deepssm_run(self, request, **kwargs):
# lazy import; requires conda shapeworks env activation
from .deepssm_tasks import deepssm_run

project = self.get_object()
form_data = request.data
form_data = {k: str(v) for k, v in form_data.items()}
Expand Down
10 changes: 7 additions & 3 deletions shapeworks_cloud/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class DeepSSMTrainingPairSerializer(serializers.ModelSerializer):
validation = serializers.BooleanField()
particles = S3FileSerializerField()
scalar = S3FileSerializerField()
vtk = S3FileSerializerField()
mesh = S3FileSerializerField()
index = serializers.CharField(max_length=255)

class Meta:
Expand All @@ -87,6 +87,7 @@ class Meta:
class DeepSSMTrainingImageSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
image = S3FileSerializerField()
index = serializers.CharField(max_length=255)
validation = serializers.BooleanField()

class Meta:
Expand All @@ -98,6 +99,7 @@ class DeepSSMAugPairSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
sample_num = serializers.IntegerField()
mesh = S3FileSerializerField()
image = S3FileSerializerField()
particles = S3FileSerializerField()

class Meta:
Expand All @@ -122,7 +124,7 @@ class Meta:
class DeepSSMTestingDataReadSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
image_type = serializers.CharField(max_length=255)
image_id = serializers.IntegerField()
image_id = serializers.CharField(max_length=255)
mesh = S3FileSerializerField()
particles = S3FileSerializerField()

Expand All @@ -137,7 +139,7 @@ class DeepSSMTrainingPairReadSerializer(serializers.ModelSerializer):
validation = serializers.BooleanField()
particles = S3FileSerializerField()
scalar = S3FileSerializerField()
vtk = S3FileSerializerField()
mesh = S3FileSerializerField()
index = serializers.CharField(max_length=255)

class Meta:
Expand All @@ -148,6 +150,7 @@ class Meta:
class DeepSSMTrainingImageReadSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
image = S3FileSerializerField()
index = serializers.CharField(max_length=255)
validation = serializers.BooleanField()

class Meta:
Expand All @@ -158,6 +161,7 @@ class Meta:
class DeepSSMAugPairReadSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
mesh = S3FileSerializerField()
image = S3FileSerializerField()
particles = S3FileSerializerField()
sample_num = serializers.IntegerField()

Expand Down
Loading

0 comments on commit 9664733

Please sign in to comment.