Skip to content

Commit

Permalink
Add test cases for private project on topicmodel and summary
Browse files Browse the repository at this point in the history
PR changes
Update side_effect function for checking Confidential leads
  • Loading branch information
susilnem committed Jun 14, 2024
1 parent 9866203 commit 39d2fec
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 11 deletions.
3 changes: 3 additions & 0 deletions apps/analysis/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ def _get_entries_qs(analysis_pillar, entry_filters):
).qs

def get_entries_qs(self):
    """Return the entries queryset for this pillar, restricted to public leads.

    A deep copy of the stored filters is taken so the model field itself is
    never mutated when the confidentiality restriction is injected.
    """
    additional_filters = copy.deepcopy(self.additional_filters)
    # NOTE: NLP is using LLM, to avoid data leakage we only pass UNPROTECTED
    additional_filters['lead_confidentialities'] = [Lead.Confidentiality.UNPROTECTED]
    # BUG FIX: previously passed self.additional_filters here, which silently
    # discarded the UNPROTECTED restriction built above.
    return self._get_entries_qs(self.analysis_pillar, additional_filters)


Expand Down
7 changes: 2 additions & 5 deletions apps/analysis/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from entry.serializers import SimpleEntrySerializer
from entry.filter_set import EntryGQFilterSet, EntriesFilterDataInputType

from lead.models import Lead
from .models import (
Analysis,
AnalysisPillar,
Expand Down Expand Up @@ -468,8 +467,6 @@ def validate_analysis_pillar(self, analysis_pillar):
return analysis_pillar

def validate_additional_filters(self, additional_filters):
# NOTE: overlapping the lead_confidentialities filter, only public leads are allowed
additional_filters['lead_confidentialities'] = [Lead.Confidentiality.UNPROTECTED]
filter_set = EntryGQFilterSet(data=additional_filters, request=self.context['request'])
if not filter_set.is_valid():
raise serializers.ValidationError(filter_set.errors)
Expand Down Expand Up @@ -526,11 +523,11 @@ class AnalysisAutomaticSummarySerializer(EntriesCollectionNlpTriggerBaseSerializ
trigger_task_func = trigger_automatic_summary
widget_tags = StringListField()

def validate_project(self, project):
def validate(self, data):
    """Object-level validation: block the trigger when the active project is private.

    The comment elsewhere in this changeset notes the NLP backend uses an LLM,
    so private-project content must not be submitted.
    """
    active_project = self.context['request'].active_project
    if not active_project.is_private:
        return data
    raise serializers.ValidationError('Automatic summary is not allowed for private projects')

class Meta:
model = AutomaticSummary
Expand Down
3 changes: 2 additions & 1 deletion apps/analysis/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from celery import shared_task
from django.db import models
from lead.models import Lead

from utils.files import generate_json_file_for_upload
from deepl_integration.handlers import (
Expand All @@ -12,6 +11,7 @@
AnalyticalStatementGeoHandler,
)

from lead.models import Lead
from entry.models import Entry
from .models import (
TopicModel,
Expand Down Expand Up @@ -56,6 +56,7 @@ def trigger_automatic_summary(_id):
Entry.objects.filter(
project=a_summary.project,
id__in=a_summary.entries_id,
# NOTE: NLP is using LLM, to avoid data leakage we only pass UNPROTECTED
lead__confidentiality=Lead.Confidentiality.UNPROTECTED,
).values('excerpt', entry_id=models.F('id'))
)
Expand Down
58 changes: 53 additions & 5 deletions apps/analysis/tests/test_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,13 @@ def setUp(self):
self.af = AnalysisFrameworkFactory.create()
self.project = ProjectFactory.create(analysis_framework=self.af)
self.another_project = ProjectFactory.create()
self.private_project = ProjectFactory.create(analysis_framework=self.af, is_private=True)
# User with role
self.non_member_user = UserFactory.create()
self.readonly_member_user = UserFactory.create()
self.member_user = UserFactory.create()
self.project.add_member(self.member_user, role=self.project_role_member)
self.private_project.add_member(self.member_user, role=self.project_role_member)

def _check_status(self, obj, status):
obj.refresh_from_db()
Expand All @@ -235,6 +237,15 @@ def test_topic_model(self, trigger_results_mock, RequestHelperMock):
analysis=analysis,
assignee=self.member_user,
)
private_analysis = AnalysisFactory.create(
project=self.private_project,
team_lead=self.member_user,
end_date=datetime.date(2022, 4, 1),
)
private_analysis_pillar = AnalysisPillarFactory.create(
analysis=private_analysis,
assignee=self.member_user,
)

# NOTE: This should be ignored by analysis end_date
lead1 = LeadFactory.create(
Expand Down Expand Up @@ -281,6 +292,15 @@ def _mutation_check(minput, **kwargs):
**kwargs
)

def _private_project_mutation_check(minput, **kwargs):
return self.query_check(
self.TRIGGER_TOPIC_MODEL,
minput=minput,
mnested=['project'],
variables={'projectId': self.private_project.id},
**kwargs
)

def _query_check(_id):
return self.query_check(
self.QUERY_TOPIC_MODEL,
Expand Down Expand Up @@ -319,6 +339,13 @@ def _query_check(_id):
self.force_login(self.member_user)
_mutation_check(minput, okay=False)

# using private_analysis_pillar for private project validation
minput['analysisPillar'] = str(private_analysis_pillar.id)

# --- member user (error since the project is private)
self.force_login(self.member_user)
_private_project_mutation_check(minput, okay=False)

# Valid data
minput['analysisPillar'] = str(analysis_pillar.id)

Expand Down Expand Up @@ -413,9 +440,11 @@ def test_automatic_summary(self, trigger_results_mock, RequestHelperMock):
lead1 = LeadFactory.create(project=self.project)
lead2 = LeadFactory.create(project=self.project)
another_lead = LeadFactory.create(project=self.another_project)
lead3 = LeadFactory.create(project=self.private_project)
lead1_entries = EntryFactory.create_batch(3, analysis_framework=self.af, lead=lead1)
lead2_entries = EntryFactory.create_batch(4, analysis_framework=self.af, lead=lead2)
another_lead_entries = EntryFactory.create_batch(4, analysis_framework=self.af, lead=another_lead)
lead3_entries = EntryFactory.create_batch(2, analysis_framework=self.af, lead=lead3)

def nlp_validator_mock(url, data=None, json=None, **kwargs):
if not json:
Expand All @@ -425,10 +454,15 @@ def nlp_validator_mock(url, data=None, json=None, **kwargs):
payload = self.get_json_media_file(
json['entries_url'].split('http://testserver/media/')[1],
)
# TODO: Need to check the Child fields of data and File payload as well
expected_keys = ['data', 'tags']
if set(payload.keys()) != set(expected_keys):
return mock.MagicMock(status_code=400)

if 'data' in payload and isinstance(payload['data'], list):
entry_ids = [entry['entry_id'] for entry in payload['data']]

# NOTE: Confidential leads entries should not be included
for entry in lead3_entries:
if str(entry.id) in entry_ids:
assert False, 'Confidential entries should not be included'

return mock.MagicMock(status_code=202)

def nlp_fail_mock(*args, **kwargs):
Expand All @@ -445,6 +479,15 @@ def _mutation_check(minput, **kwargs):
**kwargs
)

def _private_project_mutation_check(minput, **kwargs):
return self.query_check(
self.TRIGGER_AUTOMATIC_SUMMARY,
minput=minput,
mnested=['project'],
variables={'projectId': self.private_project.id},
**kwargs
)

def _query_check(_id):
return self.query_check(
self.QUERY_AUTOMATIC_SUMMARY,
Expand Down Expand Up @@ -474,7 +517,8 @@ def _query_check(_id):
for entries in [
lead1_entries,
lead2_entries,
another_lead_entries
another_lead_entries,
lead3_entries,
]
for entry in entries
]
Expand All @@ -483,6 +527,10 @@ def _query_check(_id):
'tag2',
]

# --- member user (error since the project is private)
self.force_login(self.member_user)
response = _private_project_mutation_check(minput, okay=False)

# --- member user (All good)
with self.captureOnCommitCallbacks(execute=True):
response = _mutation_check(minput, okay=True)
Expand Down

0 comments on commit 39d2fec

Please sign in to comment.