diff --git a/backend/experiment/rules/tests/test_categorization.py b/backend/experiment/rules/tests/test_categorization.py index a8557625a..8b0b82a7d 100644 --- a/backend/experiment/rules/tests/test_categorization.py +++ b/backend/experiment/rules/tests/test_categorization.py @@ -131,20 +131,24 @@ def test_plan_experiment_and_phase(self): self.assertEqual(self.session.json_data.get('phase'), 'training-1A') # Test section sequence training phase - if self.session.json_data.get('group') == 'C1': - sections = Section.objects.filter(group="CROSSED", tag__contains="1", song__artist__contains="Training") + if self.session.json_data.get('group') == 'S1': + sections = self.session.playlist.section_set.filter( + group="SAME", tag__contains="1", song__artist__contains="Training") for section in sections: self.assertIn(section.id, self.session.json_data.get('sequence')) - if self.session.json_data.get('group') == 'C2': - sections = Section.objects.filter(group="CROSSED", tag__contains="2", song__artist__contains="Training") + if self.session.json_data.get('group') == 'S2': + sections = self.session.playlist.section_set.filter( + group="SAME", tag__contains="2", song__artist__contains="Training") for section in sections: self.assertIn(section.id, self.session.json_data.get('sequence')) - if self.session.json_data.get('group') == 'S1': - sections = Section.objects.filter(group="SAME", tag__contains="1", song__artist__contains="Training") + if self.session.json_data.get('group') == 'C1': + sections = self.session.playlist.section_set.filter( + group="CROSSED", tag__contains="1", song__artist__contains="Training") for section in sections: self.assertIn(section.id, self.session.json_data.get('sequence')) - if self.session.json_data.get('group') == 'S2': - sections = Section.objects.filter(group="SAME", tag__contains="2", song__artist__contains="Training") + if self.session.json_data.get('group') == 'C2': + sections = self.session.playlist.section_set.filter( + group="CROSSED", tag__contains="2", song__artist__contains="Training") for section in sections: self.assertIn(section.id, self.session.json_data.get('sequence'))