Skip to content

Commit

Permalink
feat: Add basic unit test file for new experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
drikusroor committed Jan 25, 2024
1 parent 51fb737 commit b0b2d45
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions backend/experiment/management/commands/createexperiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ def handle(self, *args, **options):
# Ask for the experiment name
experiment_name = input("What is the name of your experiment? (ex. Musical Preferences): ")

# Create the experiment rule class
self.create_experiment_rule_class(experiment_name)

# Add the new experiment to ./experiment/rules/__init__.py
self.register_experiment_rule(experiment_name, './experiment/rules/__init__.py')

# Create a basic test file for the experiment
self.create_test_file(experiment_name)

def create_experiment_rule_class(self, experiment_name):
# Get the experiment name in different cases
experiment_name_snake_case, experiment_name_snake_case_upper, experiment_name_pascal_case = self.get_experiment_name_cases(experiment_name)

Expand Down Expand Up @@ -40,9 +50,6 @@ def handle(self, *args, **options):

self.stdout.write(self.style.SUCCESS(f"Created {filename} for experiment {experiment_name}"))

# Add the new experiment to ./experiment/rules/__init__.py
self.register_experiment_rule(experiment_name, './experiment/rules/__init__.py')

def register_experiment_rule(self, experiment_name, file_path):

# Get the experiment name in different cases
Expand All @@ -69,6 +76,40 @@ def register_experiment_rule(self, experiment_name, file_path):
with open(file_path, 'w') as file:
file.writelines(lines)

self.stdout.write(self.style.SUCCESS(f"Registered {experiment_name} in {file_path}"))

def create_test_file(self, experiment_name):
# Get the experiment name in different cases
experiment_name_snake_case, experiment_name_snake_case_upper, experiment_name_pascal_case = self.get_experiment_name_cases(experiment_name)

# Create a new file for the experiment class
filename = f"./experiment/rules/tests/test_{experiment_name_snake_case}.py"

# Check if the file already exists
if os.path.isfile(filename):
# Warn the user that the file already exists and ask if they want to overwrite it
self.stdout.write(self.style.WARNING(f"File {filename} already exists"))
overwrite = input(f"Do you want to overwrite it? (y/n): ")

# If the user does not want to overwrite the file, exit the command
if overwrite.lower() != 'y':
self.stdout.write(self.style.WARNING(f"File {filename} was not created"))
return
else:
self.stdout.write(self.style.WARNING(f"File {filename} will be overwritten"))

# Create the file by copying ./experiment/management/commands/templates/experiment.py
with open(filename, 'w') as f:
with open('./experiment/management/commands/templates/test_experiment.py', 'r') as template:
f.write(template.read()
.replace('NewExperiment', experiment_name_pascal_case)
.replace('new_experiment', experiment_name_snake_case)
.replace('NEW_EXPERIMENT', experiment_name_snake_case_upper)
.replace('New Experiment', experiment_name.title())
)

self.stdout.write(self.style.SUCCESS(f"Created {filename} for experiment {experiment_name}"))

def get_experiment_name_cases(self, experiment_name):
# Convert experiment name to snake_case and lowercase every word and replace spaces with underscores
experiment_name_snake_case = experiment_name.lower().replace(' ', '_')
Expand Down

0 comments on commit b0b2d45

Please sign in to comment.