Skip to content

Commit

Permalink
Fix tests. Don't update ability estimate for the unvalidated Cat. Han…
Browse files Browse the repository at this point in the history
…dle unvalidated remaining items separately
  • Loading branch information
richford committed Sep 30, 2024
1 parent 5c684a8 commit 5d0ff50
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 71 deletions.
183 changes: 142 additions & 41 deletions src/__tests__/clowder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ describe('Clowder Class', () => {

it('initializes with provided cats and corpora', () => {
expect(Object.keys(clowder.cats)).toContain('cat1');
expect(Object.keys(clowder.cats)).toContain('unvalidated'); // Ensure 'unvalidated' cat is present
expect(clowder.remainingItems).toHaveLength(5);
expect(clowder.corpus).toHaveLength(5);
expect(clowder.seenItems).toHaveLength(0);
Expand Down Expand Up @@ -92,7 +91,7 @@ describe('Clowder Class', () => {

it('throws an error when updating ability estimates for an invalid cat', () => {
expect(() => clowder.updateAbilityEstimates(['invalidCatName'], createStimulus('1'), [0])).toThrowError(
'Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.',
'Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.',
);
});

Expand All @@ -109,7 +108,6 @@ describe('Clowder Class', () => {
const expected = {
cat1: clowder.cats['cat1'][property as keyof Cat],
cat2: clowder.cats['cat2'][property as keyof Cat],
unvalidated: clowder.cats['unvalidated'][property as keyof Cat],
};
expect(clowder[property as keyof Clowder]).toEqual(expected);
});
Expand Down Expand Up @@ -138,7 +136,7 @@ describe('Clowder Class', () => {
catToSelect: 'cat1',
catsToUpdate: ['invalidCatName', 'cat2'],
});
}).toThrow('Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.');
}).toThrow('Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.');
});

it('updates seen and remaining items', () => {
Expand Down Expand Up @@ -207,6 +205,110 @@ describe('Clowder Class', () => {
expect(['0', '2']).toContain(nextItem?.id);
});

it('should select an unvalidated item if catToSelect is "unvalidated"', () => {
const clowderInput: ClowderInput = {
cats: {
cat1: { method: 'MLE', theta: 0.5 },
},
corpus: [
createMultiZetaStimulus('0', [createZetaCatMap([])]),
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('2', [createZetaCatMap([])]),
createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]),
],
};

const clowder = new Clowder(clowderInput);

const nDraws = 50;
// Simulate sDraws unvalidated items being selected
// eslint-disable-next-line @typescript-eslint/no-unused-vars
for (const _ of Array(nDraws).fill(0)) {
const nextItem = clowder.updateCatAndGetNextItem({
catToSelect: 'unvalidated',
});

expect(['0', '2']).toContain(nextItem?.id);
}
});

it('should not update cats with items that do not have parameters for that cat', () => {
const clowderInput: ClowderInput = {
cats: {
cat1: { method: 'MLE', theta: 0.5 },
cat2: { method: 'MLE', theta: 0.5 },
},
corpus: [
createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]),
createMultiZetaStimulus('3', [createZetaCatMap(['cat2'])]),
],
};

const clowder = new Clowder(clowderInput);

clowder.updateCatAndGetNextItem({
catsToUpdate: ['cat1', 'cat2'],
items: clowder.corpus,
answers: [1, 1, 1, 1],
catToSelect: 'unvalidated',
});

expect(clowder.nItems.cat1).toBe(2);
expect(clowder.nItems.cat2).toBe(2);
});

it('should not update any cats if only unvalidated items have been seen', () => {
const clowderInput: ClowderInput = {
cats: {
cat1: { method: 'MLE', theta: 0.5 },
},
corpus: [
createMultiZetaStimulus('0', [createZetaCatMap([])]),
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('2', [createZetaCatMap([])]),
createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]),
],
};

const clowder = new Clowder(clowderInput);

clowder.updateCatAndGetNextItem({
catsToUpdate: ['cat1'],
items: [clowder.corpus[0], clowder.corpus[2]],
answers: [1, 1],
catToSelect: 'unvalidated',
});

expect(clowder.nItems.cat1).toBe(0);
});

it('should return undefined for next item if catToSelect = "unvalidated" and no unvalidated items remain', () => {
const clowderInput: ClowderInput = {
cats: {
cat1: { method: 'MLE', theta: 0.5 },
},
corpus: [
createMultiZetaStimulus('0', [createZetaCatMap([])]),
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('2', [createZetaCatMap([])]),
createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]),
],
};

const clowder = new Clowder(clowderInput);

const nextItem = clowder.updateCatAndGetNextItem({
catsToUpdate: ['cat1'],
items: [clowder.corpus[0], clowder.corpus[2]],
answers: [1, 1],
catToSelect: 'unvalidated',
});

expect(nextItem).toBeUndefined();
});

it('should correctly update ability estimates during the updateCatAndGetNextItem method', () => {
const originalTheta = clowder.cats.cat1.theta;
clowder.updateCatAndGetNextItem({
Expand Down Expand Up @@ -249,15 +351,12 @@ describe('Clowder Class', () => {
});

it('should return undefined if no more items remain', () => {
clowder.updateCatAndGetNextItem({
const nextItem = clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
items: clowder.remainingItems,
answers: [1, 0, 1, 1, 0], // Exhaust all items
});

const nextItem = clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
});
expect(nextItem).toBeUndefined();
});

Expand Down Expand Up @@ -341,8 +440,8 @@ describe('Clowder Early Stopping', () => {
cats: { cat1: { method: 'MLE', theta: 0.5 } },
corpus: [
createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), // This item should trigger early stopping
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), // This item should trigger early stopping
createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]),
],
earlyStopping,
});
Expand All @@ -353,17 +452,13 @@ describe('Clowder Early Stopping', () => {
items: [clowder.corpus[0]],
answers: [1],
});
clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
catsToUpdate: ['cat1'],
items: [clowder.corpus[1]],
answers: [1],
});

expect(clowder.earlyStopping?.earlyStop).toBe(false);

const nextItem = clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
catsToUpdate: ['cat1'],
items: [clowder.corpus[2]],
items: [clowder.corpus[1]],
answers: [1],
});

Expand All @@ -378,35 +473,41 @@ describe('Clowder Early Stopping', () => {
tolerance: { cat1: 0.01 },
});

clowder = new Clowder({
cats: { cat1: { method: 'MLE', theta: 0.5 } },
corpus: [
createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]),
createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]),
],
earlyStopping,
const zetaMap = createZetaCatMap(['cat1'], {
a: 6,
b: 6,
c: 0,
d: 1,
});

// First update
clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
catsToUpdate: ['cat1'],
items: [clowder.corpus[0]],
answers: [1],
});
// pringing results
console.log('SE Measurements:', clowder.earlyStopping?.seMeasurementThreshold, clowder.cats.cat1);
const corpus = [
createMultiZetaStimulus('0', [zetaMap]),
createMultiZetaStimulus('1', [zetaMap]),
createMultiZetaStimulus('2', [zetaMap]), // Here the SE measurement drops below threshold
createMultiZetaStimulus('3', [zetaMap]), // And here, early stopping should be triggered because it has been below threshold for 2 items
];

const nextItem = clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
catsToUpdate: ['cat1'],
items: [clowder.corpus[1]],
answers: [1],
clowder = new Clowder({
cats: { cat1: { method: 'MLE', theta: 0.5 } },
corpus,
earlyStopping,
});

console.log('Early Stop Triggered:', clowder.earlyStopping?.earlyStop);
for (const item of corpus) {
const nextItem = clowder.updateCatAndGetNextItem({
catToSelect: 'cat1',
catsToUpdate: ['cat1'],
items: [item],
answers: [1],
});

expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold
expect(nextItem).toBe(undefined); // No further items should be selected
if (item.id === '3') {
expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold
expect(nextItem).toBe(undefined); // No further items should be selected
} else {
expect(clowder.earlyStopping?.earlyStop).toBe(false);
expect(nextItem).toBeDefined();
}
}
});
});
78 changes: 48 additions & 30 deletions src/clowder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ export class Clowder {
* @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name.
*/
constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) {
// TODO: Add some imput validation to both the cats and the corpus to make sure that "unvalidated" is not used as a cat name.
// If so, throw an error saying that "unvalidated" is a reserved name and may not be used.
// TODO: Also add a test of this behavior.
this._cats = {
..._mapValues(cats, (catInput) => new Cat(catInput)),
unvalidated: new Cat(), // Add 'unvalidated' cat
unvalidated: new Cat({ itemSelect: 'random', randomSeed }), // Add 'unvalidated' cat
};
this._seenItems = [];
checkNoDuplicateCatNames(corpus);
Expand All @@ -74,20 +77,22 @@ export class Clowder {
* Throw an error if the Cat name is not found.
*
* @param {string} catName - The name of the Cat instance to validate.
* @param {boolean} allowUnvalidated - Whether to allow the reserved 'unvalidated' name.
*
* @throws {Error} - Throws an error if the provided Cat name is not found among the existing Cat instances.
*/
private _validateCatName(catName: string): void {
if (!Object.prototype.hasOwnProperty.call(this._cats, catName)) {
throw new Error(`Invalid Cat name. Expected one of ${Object.keys(this._cats).join(', ')}. Received ${catName}.`);
private _validateCatName(catName: string, allowUnvalidated = false): void {
const allowedCats = allowUnvalidated ? this._cats : this.cats;
if (!Object.prototype.hasOwnProperty.call(allowedCats, catName)) {
throw new Error(`Invalid Cat name. Expected one of ${Object.keys(allowedCats).join(', ')}. Received ${catName}.`);
}
}

/**
* The named Cat instances that this Clowder manages.
*/
public get cats() {
return this._cats;
return _omit(this._cats, ['unvalidated']);
}

/**
Expand Down Expand Up @@ -162,7 +167,7 @@ export class Clowder {
*/
public updateAbilityEstimates(catNames: string[], zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method?: string) {
catNames.forEach((catName) => {
this._validateCatName(catName);
this._validateCatName(catName, false);
});
for (const catName of catNames) {
this.cats[catName].updateAbilityEstimate(zeta, answer, method);
Expand Down Expand Up @@ -216,10 +221,14 @@ export class Clowder {
itemSelect?: string;
randomlySelectUnvalidated?: boolean;
}): Stimulus | undefined {
this._validateCatName(catToSelect);
// +----------+
// ----------| Update |----------|
// +----------+

this._validateCatName(catToSelect, true);
catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate];
catsToUpdate.forEach((cat) => {
this._validateCatName(cat);
this._validateCatName(cat, false);
});

// Convert items and answers to arrays
Expand Down Expand Up @@ -257,28 +266,24 @@ export class Clowder {
// retrieve only the item parameters that apply to this cat.
stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)),
);
const zetasAndAnswersForCat = itemsAndAnswersForCat
.map(([stim, _answer]) => {
const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName));
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined
})
.filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values

// Unzip the zetas and answers, making sure the zetas array contains only Zeta types
const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]];

// Now, pass the filtered zetas and answers to the cat's updateAbilityEstimate method
this.cats[catName].updateAbilityEstimate(zetas, answers, method);
}

// Assign items with no valid parameters to the 'unvalidated' cat
const unvalidatedItemsAndAnswers = itemsAndAnswers.filter(
([stim]) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0),
);
if (unvalidatedItemsAndAnswers.length > 0) {
const [zetas, answers] = _unzip(unvalidatedItemsAndAnswers) as [Zeta[], (0 | 1)[]];
this.cats['unvalidated'].updateAbilityEstimate(zetas, answers, method);
if (itemsAndAnswersForCat.length > 0) {
const zetasAndAnswersForCat = itemsAndAnswersForCat
.map(([stim, _answer]) => {
const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) =>
zeta.cats.includes(catName),
);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined
})
.filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values

// Unzip the zetas and answers, making sure the zetas array contains only Zeta types
const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]];

// Now, pass the filtered zetas and answers to the cat's updateAbilityEstimate method
this.cats[catName].updateAbilityEstimate(zetas, answers, method);
}
}

if (this._earlyStopping) {
Expand All @@ -292,9 +297,22 @@ export class Clowder {
// ----------| Select |----------|
// +----------+

if (catToSelect === 'unvalidated') {
// Assign items with no valid parameters to the 'unvalidated' cat
const unvalidatedRemainingItems = this._remainingItems.filter(
(stim) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0),
);

if (unvalidatedRemainingItems.length === 0) {
return undefined;
} else {
const randInt = Math.floor(this._rng() * unvalidatedRemainingItems.length);
return unvalidatedRemainingItems[randInt];
}
}

// Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`.
// We inspect the remaining items and find ones that have zeta parameters for `catToSelect`

const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect);

// The cat expects an array of Stimulus objects, with the zeta parameters
Expand Down

0 comments on commit 5d0ff50

Please sign in to comment.