Skip to content

Commit

Permalink
Task Progress Updates (#375)
Browse files Browse the repository at this point in the history
* Save form data to TaskProgress and add filtered list endpoint

* Improve updates to TaskProgress and abort method  during task

* Improve progress polling in web client

* Refactor forms to use new TaskInfo component for TaskProgress handling

* Read form defaults from project file if available

* Add task messages for deepssm progress and fixed form data

* Add disabled prop for task-info component

* Update bad double qoutes

* Add more initial progress message updates to deepssm task

* remove whitespace

---------

Co-authored-by: Jake Wagoner <jakewagoneredu@gmail.com>
  • Loading branch information
annehaley and JakeWags authored May 8, 2024
1 parent 949be92 commit cae1cfe
Show file tree
Hide file tree
Showing 16 changed files with 778 additions and 681 deletions.
200 changes: 115 additions & 85 deletions shapeworks_cloud/core/deepssm_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def run_prep(params, project, project_file, progress):
# /////////////////////////////////////////////////////////////////
# /// STEP 2: Groom Training Shapes
# /////////////////////////////////////////////////////////////////
progress.update_message('Grooming Training Shapes...')

project_params = project.get_parameters('groom')
# alignment should always be set to ICP
project_params.set('alignment_method', 'Iterative Closest Point')
Expand All @@ -39,23 +41,27 @@ def run_prep(params, project, project_file, progress):
DeepSSMUtils.groom_training_shapes(project)
project.save(project_file)

progress.update_percentage(11)

# /////////////////////////////////////////////////////////////////
# /// STEP 3: Optimize Training Particles
# /////////////////////////////////////////////////////////////////
progress.update_message('Optimizing Training Particles...')
DeepSSMUtils.optimize_training_particles(project)
project.save(project_file)
progress.update_percentage(12)

# /////////////////////////////////////////////////////////////////
# /// STEP 4: Groom Training Images
# /////////////////////////////////////////////////////////////////
print('Grooming training images')
progress.update_message('Grooming Training Images...')
DeepSSMUtils.groom_training_images(project)
project.save(project_file)

# /////////////////////////////////////////////////////////////////
# /// STEP 5: Groom Validation Images
# /////////////////////////////////////////////////////////////////
progress.update_message('Grooming Validation Images...')
val_indices = DeepSSMUtils.get_split_indices(project, 'val')
test_indices = DeepSSMUtils.get_split_indices(project, 'test')
val_test_indices = val_indices + test_indices
Expand All @@ -77,10 +83,12 @@ def run_prep(params, project, project_file, progress):

project.set_parameters('optimize', project_params)

progress.update_message('Grooming Validation Shapes...')
DeepSSMUtils.groom_validation_shapes(project)
project.save(project_file)
progress.update_percentage(17)

progress.update_message('Optimizing Validation Particles...')
optimize = sw.Optimize()
optimize.SetUpOptimize(project)
optimize.Run()
Expand All @@ -102,6 +110,7 @@ def run_augmentation(params, project, download_dir, progress):

num_dims = 0 # set to 0 to allow for percent variability to be used

progress.update_message('Running Data Augmentation...')
embedded_dims = DeepSSMUtils.run_data_augmentation(
project,
num_samples,
Expand All @@ -111,6 +120,7 @@ def run_augmentation(params, project, download_dir, progress):
mixture_num=0,
processes=1, # Thread count
)
progress.update_message('Generating Augmentation visualizations...')
progress.update_percentage(25)

aug_dir = download_dir + '/deepssm/augmentation/'
Expand All @@ -130,6 +140,7 @@ def run_training(params, project, download_dir, aug_dims, progress):
# /////////////////////////////////////////////////////////////////
# /// STEP 8: Create PyTorch loaders from data
# /////////////////////////////////////////////////////////////////
progress.update_message('Preparing Training Data Loaders...')
DeepSSMUtils.prepare_data_loaders(project, batch_size, 'train')
DeepSSMUtils.prepare_data_loaders(project, batch_size, 'val')
progress.update_percentage(35)
Expand Down Expand Up @@ -169,6 +180,7 @@ def run_training(params, project, download_dir, aug_dims, progress):
)
progress.update_percentage(40)

progress.update_message('Training DeepSSM Model...')
DeepSSMUtils.trainDeepSSM(project, config_file)
progress.update_percentage(50)

Expand All @@ -181,12 +193,14 @@ def run_testing(params, project, download_dir, progress):
# /////////////////////////////////////////////////////////////////
# /// STEP 10: Groom Testing Images
# /////////////////////////////////////////////////////////////////
progress.update_message('Grooming Testing Images...')
DeepSSMUtils.groom_val_test_images(project, test_indices)
progress.update_percentage(55)

# /////////////////////////////////////////////////////////////////
# /// STEP 11: Prepare Test Data PyTorch Loaders
# /////////////////////////////////////////////////////////////////
progress.update_message('Preparing Testing Data Loaders...')
batch_size = int(params['train_batch_size'])
DeepSSMUtils.prepare_data_loaders(project, batch_size, 'test')

Expand All @@ -204,8 +218,11 @@ def run_testing(params, project, download_dir, progress):
# /////////////////////////////////////////////////////////////////
# /// STEP 12: Test DeepSSM Model
# /////////////////////////////////////////////////////////////////
progress.update_message('Testing DeepSSM Model...')
DeepSSMUtils.testDeepSSM(config_file)

progress.update_messsage('Processing Test Predictions...')
progress.update_percentage(75)
DeepSSMUtils.process_test_predictions(project, config_file)
progress.update_percentage(90)

Expand All @@ -226,110 +243,123 @@ def run_deepssm_command(
token, _created = Token.objects.get_or_create(user=user)
base_url = settings.API_URL # type: ignore

with TemporaryDirectory() as download_dir:
with swcc_session(base_url=base_url) as session:
# fetch everything we need
session.set_token(token.key)
project = models.Project.objects.get(id=project_id)
print('setting project filename')
project_filename = project.file.name.split('/')[-1]
print('project filename set', project_filename)
swcc_project = SWCCProject.from_id(project.id)
swcc_project.download(download_dir)
print('Downloaded swcc project')

pre_command_function()
progress.update_percentage(10)

if form_data:
# write the form data to the project file
edit_swproj_section(
Path(download_dir, project_filename),
'deepssm',
form_data,
)

result_data = {}
try:
progress.update_message('Initializing task...')
with TemporaryDirectory() as download_dir:
with swcc_session(base_url=base_url) as session:
# fetch everything we need
session.set_token(token.key)
project = models.Project.objects.get(id=project_id)
project_filename = project.file.name.split('/')[-1]
swcc_project = SWCCProject.from_id(project.id)
progress.update_message('Downloading project...')
swcc_project.download(download_dir)

pre_command_function()
progress.update_percentage(10)

progress.update_message('Writing form data to project file...')
if form_data:
# write the form data to the project file
edit_swproj_section(
Path(download_dir, project_filename),
'deepssm',
form_data,
)

# Use shapeworks python project class
sw_project = sw.Project()
result_data = {}

sw_project_file = str(Path(download_dir, project_filename))
# Use shapeworks python project class
sw_project = sw.Project()

sw_project.load(sw_project_file)
sw_project_file = str(Path(download_dir, project_filename))

groom_params = sw_project.get_parameters('groom')
progress.update_message('Loading project file...')
sw_project.load(sw_project_file)

# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
groom_params.set(key, value)
progress.update_message('Copying grooming parameters')
groom_params = sw_project.get_parameters('groom')

sw_project.set_parameters('groom', groom_params)
# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
groom_params.set(key, value)

optimize_params = sw_project.get_parameters('optimize')
# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
optimize_params.set(key, value)
sw_project.set_parameters('groom', groom_params)

sw_project.set_parameters('optimize', optimize_params)
progress.update_message('Copying optimization parameters')
optimize_params = sw_project.get_parameters('optimize')
# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
optimize_params.set(key, value)

os.chdir(sw_project.get_project_path())
run_prep(form_data, sw_project, sw_project_file, progress)
sw_project.set_parameters('optimize', optimize_params)

aug_dims = run_augmentation(form_data, sw_project, download_dir, progress)
progress.update_message('Running DeepSSM on data...')

# result data has paths to files
result_data['augmentation'] = {
'total_data_csv': download_dir + '/deepssm/augmentation/TotalData.csv',
'violin_plot': download_dir + '/deepssm/augmentation/violin.png',
'generated_meshes': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Meshes/'
),
'generated_images': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Images/'
),
'generated_particles': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Particles/'
),
}
os.chdir(sw_project.get_project_path())
run_prep(form_data, sw_project, sw_project_file, progress)

run_training(form_data, sw_project, download_dir, aug_dims, progress)
aug_dims = run_augmentation(form_data, sw_project, download_dir, progress)

training_examples = os.listdir(download_dir + '/deepssm/model/examples/')
# result data has paths to files
result_data['augmentation'] = {
'total_data_csv': download_dir + '/deepssm/augmentation/TotalData.csv',
'violin_plot': download_dir + '/deepssm/augmentation/violin.png',
'generated_meshes': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Meshes/'
),
'generated_images': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Images/'
),
'generated_particles': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Particles/'
),
}

if 'train_' in training_examples:
training_examples.remove('train_')
if 'validation_' in training_examples:
training_examples.remove('validation_')
run_training(form_data, sw_project, download_dir, aug_dims, progress)

result_data['training'] = {
'train_log': download_dir + '/deepssm/model/train_log.csv',
'training_plot': download_dir + '/deepssm/model/training_plot.png',
'training_plot_ft': download_dir + '/deepssm/model/training_plot_ft.png',
'train_examples': training_examples,
'train_images': os.listdir(download_dir + '/deepssm/train_images/'),
'val_and_test_images': os.listdir(download_dir + '/deepssm/val_and_test_images/'),
}
training_examples = os.listdir(download_dir + '/deepssm/model/examples/')

run_testing(form_data, sw_project, download_dir, progress)
if 'train_' in training_examples:
training_examples.remove('train_')
if 'validation_' in training_examples:
training_examples.remove('validation_')

subjects = sw_project.get_subjects()
result_data['training'] = {
'train_log': download_dir + '/deepssm/model/train_log.csv',
'training_plot': download_dir + '/deepssm/model/training_plot.png',
'training_plot_ft': download_dir + '/deepssm/model/training_plot_ft.png',
'train_examples': training_examples,
'train_images': os.listdir(download_dir + '/deepssm/train_images/'),
'val_and_test_images': os.listdir(
download_dir + '/deepssm/val_and_test_images/'
),
}

result_data['testing'] = {
'world_predictions': os.listdir(
download_dir + '/deepssm/model/test_predictions/world_predictions/'
),
'local_predictions': os.listdir(
download_dir + '/deepssm/model/test_predictions/local_predictions/'
),
'test_distances': download_dir + '/deepssm/test_distances.csv',
'test_split_subjects': subjects,
}
run_testing(form_data, sw_project, download_dir, progress)
progress.update_message('Saving Results...')

os.chdir('../../')
subjects = sw_project.get_subjects()

post_command_function(project, download_dir, result_data, project_filename)
progress.update_percentage(100)
result_data['testing'] = {
'world_predictions': os.listdir(
download_dir + '/deepssm/model/test_predictions/world_predictions/'
),
'local_predictions': os.listdir(
download_dir + '/deepssm/model/test_predictions/local_predictions/'
),
'test_distances': download_dir + '/deepssm/test_distances.csv',
'test_split_subjects': subjects,
}

os.chdir('../../')

post_command_function(project, download_dir, result_data, project_filename)
progress.update_percentage(100)
except models.TaskProgress.TaskAbortedError:
print('Task Aborted. Exiting.')
except Exception as e:
progress.update_error(str(e))


@shared_task
Expand Down
8 changes: 8 additions & 0 deletions shapeworks_cloud/core/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ class Meta:
fields = ['project']


class TaskProgressFilter(FilterSet):
project = ModelChoiceFilter(queryset=models.Project.objects.all())

class Meta:
models = models.TaskProgress
fields = ['project']


class DeepSSMTestingDataFilter(FilterSet):
project = ModelChoiceFilter(queryset=models.Project.objects.all())

Expand Down
18 changes: 18 additions & 0 deletions shapeworks_cloud/core/migrations/0041_taskprogress_formdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 3.2.25 on 2024-04-22 16:15

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('core', '0040_deepssm_ui_updates'),
]

operations = [
migrations.AddField(
model_name='taskprogress',
name='form_data',
field=models.JSONField(blank=True, null=True),
),
]
17 changes: 13 additions & 4 deletions shapeworks_cloud/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,21 +356,30 @@ class ReconstructedSample(TimeStampedModel, models.Model):


class TaskProgress(TimeStampedModel, models.Model):
class TaskAbortedError(Exception):
pass

name = models.CharField(max_length=255)
project = models.ForeignKey(Project, on_delete=models.CASCADE, null=True)
error = models.CharField(max_length=255, blank=True)
message = models.CharField(max_length=255, blank=True)
percent_complete = models.IntegerField(default=0)
abort = models.BooleanField(default=False)
form_data = models.JSONField(null=True, blank=True)

def update_percentage(self, percentage):
self.percent_complete = percentage
self.save()

def update_error(self, error):
self.error = error[:255]
self.save()
if self.abort:
raise self.TaskAbortedError()

def update_message(self, message):
self.message = message[:255]
self.save()
if self.abort:
raise self.TaskAbortedError()

def update_error(self, error):
self.error = error[:255]
self.abort = True
self.save()
Loading

0 comments on commit cae1cfe

Please sign in to comment.