Skip to content

Commit

Permalink
Merge pull request #127 from twitter/bradm/modeling_groups
Browse files Browse the repository at this point in the history
Group modeling and other assorted updates
  • Loading branch information
bradmiller authored Jul 28, 2023
2 parents 796e172 + fed8763 commit b8b6f8a
Show file tree
Hide file tree
Showing 8 changed files with 595 additions and 84 deletions.
44 changes: 44 additions & 0 deletions sourcecode/scoring/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

maxTrainError = 0.09

coreFlipPct = 0.15
expansionFlipPct = 0.19
maxReruns = 5

Expand All @@ -41,6 +42,7 @@
summaryKey = "summary"
authorTopNotHelpfulTagValues = "authorTopNotHelpfulTagValues"
modelingPopulationKey = "modelingPopulation"
modelingGroupKey = "modelingGroup"

# TSV Values
notHelpfulValueTsv = "NOT_HELPFUL"
Expand Down Expand Up @@ -127,6 +129,14 @@ def rater_factor_key(i):
coverageRatingStatusKey = "coverageRatingStatus"
coverageNoteInterceptMaxKey = "coverageNoteInterceptMax"
coverageNoteInterceptMinKey = "coverageNoteInterceptMin"
# Group Model
groupNoteInterceptKey = "groupNoteIntercept"
groupNoteFactor1Key = "groupNoteFactor1"
groupRatingStatusKey = "groupRatingStatus"
groupNoteInterceptMaxKey = "groupNoteInterceptMax"
groupNoteInterceptMinKey = "groupNoteInterceptMin"
groupRaterInterceptKey = "groupRaterIntercept"
groupRaterFactor1Key = "groupRaterFactor1"

# Ids and Indexes
noteIdKey = "noteId"
Expand Down Expand Up @@ -364,6 +374,16 @@ def rater_factor_key(i):
userEnrollmentTSVTypes = [dtype for (_, dtype) in userEnrollmentTSVColumnsAndTypes]
userEnrollmentTSVTypeMapping = {col: dtype for (col, dtype) in userEnrollmentTSVColumnsAndTypes}

# TODO: delete expanded user enrollment definition once modeling group is fully rolled out
userEnrollmentExpandedTSVColumnsAndTypes = userEnrollmentTSVColumnsAndTypes + [
(modelingGroupKey, np.float64)
]
userEnrollmentExpandedTSVColumns = [col for (col, _) in userEnrollmentExpandedTSVColumnsAndTypes]
userEnrollmentExpandedTSVTypes = [dtype for (_, dtype) in userEnrollmentExpandedTSVColumnsAndTypes]
userEnrollmentExpandedTSVTypeMapping = {
col: dtype for (col, dtype) in userEnrollmentExpandedTSVColumnsAndTypes
}

noteInterceptMaxKey = "internalNoteIntercept_max"
noteInterceptMinKey = "internalNoteIntercept_min"
noteParameterUncertaintyTSVMainColumnsAndTypes = [
Expand Down Expand Up @@ -421,6 +441,16 @@ def rater_factor_key(i):
+ incorrectFilterColumns
)

deprecatedNoteModelOutputColumns = frozenset(
{
coverageNoteInterceptKey,
coverageNoteFactor1Key,
coverageRatingStatusKey,
coverageNoteInterceptMinKey,
coverageNoteInterceptMaxKey,
}
)

noteModelOutputTSVColumnsAndTypes = [
(noteIdKey, np.int64),
(coreNoteInterceptKey, np.double),
Expand Down Expand Up @@ -451,9 +481,20 @@ def rater_factor_key(i):
(expansionNoteInterceptMaxKey, np.double),
(coverageNoteInterceptMinKey, np.double),
(coverageNoteInterceptMaxKey, np.double),
(groupNoteInterceptKey, np.double),
(groupNoteFactor1Key, np.double),
(groupRatingStatusKey, np.str),
(groupNoteInterceptMaxKey, np.double),
(groupNoteInterceptMinKey, np.double),
(modelingGroupKey, np.float64),
]
noteModelOutputTSVColumns = [col for (col, dtype) in noteModelOutputTSVColumnsAndTypes]
noteModelOutputTSVTypeMapping = {col: dtype for (col, dtype) in noteModelOutputTSVColumnsAndTypes}
deprecatedNoteModelOutputTSVColumnsAndTypes = [
(col, dtype)
for (col, dtype) in noteModelOutputTSVColumnsAndTypes
if col in deprecatedNoteModelOutputColumns
]

raterModelOutputTSVColumnsAndTypes = [
(raterParticipantIdKey, np.int64),
Expand Down Expand Up @@ -481,6 +522,9 @@ def rater_factor_key(i):
(isEmergingWriterKey, np.bool_),
(aggregateRatingReceivedTotal, pd.Int64Dtype()),
(timestampOfLastEarnOut, np.double),
(groupRaterInterceptKey, np.double),
(groupRaterFactor1Key, np.double),
(modelingGroupKey, np.float64),
]
raterModelOutputTSVColumns = [col for (col, dtype) in raterModelOutputTSVColumnsAndTypes]
raterModelOutputTSVTypeMapping = {col: dtype for (col, dtype) in raterModelOutputTSVColumnsAndTypes}
4 changes: 3 additions & 1 deletion sourcecode/scoring/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ class Scorers(Enum):
"""Exhaustive list of all scorers to simplify setting enabled/disabled scorers."""

MFCoreScorer = auto()
MFCoverageScorer = auto()
MFExpansionScorer = auto()
# Note that the MFGroupScorer value controls whether *all* group scorers are instantiated,
# not just a single MFGroupScorer instance.
MFGroupScorer = auto()


def scorers_from_csv(csv: str) -> Set[Scorers]:
Expand Down
5 changes: 5 additions & 0 deletions sourcecode/scoring/mf_base_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def __init__(
self._weightedTotalVotes = weightedTotalVotes
self._mfRanker = matrix_factorization.MatrixFactorization()

def get_crh_threshold(self) -> float:
"""Return CRH threshold for general scoring logic."""
return self._crhThreshold

def get_scored_notes_cols(self) -> List[str]:
"""Returns a list of columns which should be present in the scoredNotes output."""
return [
Expand Down Expand Up @@ -160,6 +164,7 @@ def _score_notes_and_users(
userScores pd.DataFrame: one row per user containing a column for each helpfulness score.
"""
if self._seed is not None:
print(f"seeding with {self._seed}")
torch.manual_seed(self._seed)

# Removes ratings where either (1) the note did not receive enough ratings, or
Expand Down
Loading

0 comments on commit b8b6f8a

Please sign in to comment.