diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 37bcaec..bcbbd7c 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -1,229 +1,206 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import { Cat } from '../index'; -import { zetaKeyMap, Stimulus, ZetaImplicit, ZetaExplicit } from '../type'; +import { Stimulus } from '../type'; import seedrandom from 'seedrandom'; -import _mapKeys from 'lodash/mapKeys'; - -// Convert ZetaImplicit to ZetaExplicit -const convertZetaImplicitToExplicit = (zeta: ZetaImplicit): ZetaExplicit => { - const explicitZeta = _mapKeys(zeta, (value, key) => { - return zetaKeyMap[key as keyof typeof zetaKeyMap]; - }) as ZetaExplicit; - - return { - discrimination: explicitZeta.discrimination, - difficulty: explicitZeta.difficulty, - guessing: explicitZeta.guessing, - slipping: explicitZeta.slipping, - }; -}; - -describe('Cat', () => { - let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; - let rng = seedrandom(); - - beforeEach(() => { - cat1 = new Cat(); - cat1.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 2.225, b: -1.885, c: 0.21, d: 1 }), - convertZetaImplicitToExplicit({ a: 1.174, b: -2.411, c: 0.212, d: 1 }), - convertZetaImplicitToExplicit({ a: 2.104, b: -2.439, c: 0.192, d: 1 }), - ], - [1, 0, 1], - ); - - cat2 = new Cat(); - cat2.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 1, b: -0.447, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 2.869, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.469, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.576, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.43, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.607, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 0.529, c: 0.5, d: 1 }), - ], - [0, 1, 0, 1, 1, 1, 1], - ); - cat3 = new Cat({ nStartItems: 0 }); - const randomSeed = 'test'; - rng = seedrandom(randomSeed); - cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); - cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); // ask - - cat6 = new Cat(); - cat6.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), - ], - [0, 0], - ); - - cat7 = new Cat({ method: 'eap' }); - cat7.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), - ], - [0, 0], - ); - - cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); - }); - - const s1: Stimulus = { difficulty: 0.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'looking' }; - const s2: Stimulus = { difficulty: 3.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'opaque' }; - const s3: Stimulus = { difficulty: 2, guessing: 0.5, discrimination: 1, slipping: 1, word: 'right' }; - const s4: Stimulus = { difficulty: -2.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'yes' }; - const s5: Stimulus = { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; - const stimuli = [s1, s2, s3, s4, s5]; - - it('constructs an adaptive test', () => { - expect(cat1.method).toBe('mle'); - expect(cat1.itemSelect).toBe('mfi'); - }); - - it('correctly updates ability estimate', () => { - expect(cat1.theta).toBeCloseTo(-1.642307, 1); - }); - - it('correctly updates ability estimate', () => { - expect(cat2.theta).toBeCloseTo(-1.272, 1); - }); - - it('correctly updates standard error of mean of ability estimate', () => { - expect(cat2.seMeasurement).toBeCloseTo(1.71, 1); - }); - - it('correctly counts number of items', () => { - expect(cat2.nItems).toEqual(7); - }); - - it('correctly updates answers', () => { - expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); - }); - - it('correctly updates zetas', () => { - expect(cat2.zetas).toEqual([ - convertZetaImplicitToExplicit({ a: 1, b: -0.447, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 2.869, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.469, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.576, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.43, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.607, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 0.529, c: 0.5, d: 1 }), - ]); - }); - - it('correctly suggests the next item (closest method)', () => { - const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; - const received = cat1.findNextItem(stimuli, 'closest'); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (mfi method)', () => { - const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat3.findNextItem(stimuli, 'MFI'); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (middle method)', () => { - const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat5.findNextItem(stimuli); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (fixed method)', () => { - expect(cat8.itemSelect).toBe('fixed'); - const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; - const received = cat8.findNextItem(stimuli); - expect(received).toEqual(expected); - }); +import { convertZeta } from '../utils'; - it('correctly suggests the next item (random method)', () => { - let received; - const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); // ask - let index = Math.floor(rng() * stimuliSorted.length); - received = cat4.findNextItem(stimuliSorted); - expect(received.nextStimulus).toEqual(stimuliSorted[index]); - - for (let i = 0; i < 3; i++) { - const remainingStimuli = received.remainingStimuli; - index = Math.floor(rng() * remainingStimuli.length); - received = cat4.findNextItem(remainingStimuli); - expect(received.nextStimulus).toEqual(remainingStimuli[index]); - } - }); - - it('correctly updates ability estimate through MLE', () => { - expect(cat6.theta).toBeCloseTo(-6.0, 1); - }); +for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) { + describe('Cat with explicit zeta', () => { + let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; + let rng = seedrandom(); - it('correctly updates ability estimate through EAP', () => { - expect(cat7.theta).toBeCloseTo(0.25, 1); - }); - - it('should throw an error if zeta and answers do not have matching length', () => { - try { - cat7.updateAbilityEstimate( + beforeEach(() => { + cat1 = new Cat(); + cat1.updateAbilityEstimate( [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), + convertZeta({ a: 2.225, b: -1.885, c: 0.21, d: 1 }, format), + convertZeta({ a: 1.174, b: -2.411, c: 0.212, d: 1 }, format), + convertZeta({ a: 2.104, b: -2.439, c: 0.192, d: 1 }, format), ], - [0, 0, 0], + [1, 0, 1], ); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw an error if method is invalid', () => { - try { - new Cat({ method: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - try { - cat7.updateAbilityEstimate( + cat2 = new Cat(); + cat2.updateAbilityEstimate( [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), + convertZeta({ a: 1, b: -0.447, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 2.869, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.469, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.576, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.43, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.607, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 0.529, c: 0.5, d: 1 }, format), ], + [0, 1, 0, 1, 1, 1, 1], + ); + cat3 = new Cat({ nStartItems: 0 }); + const randomSeed = 'test'; + rng = seedrandom(randomSeed); + cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); + cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); // ask + + cat6 = new Cat(); + cat6.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], [0, 0], - 'coolMethod', ); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw an error if itemSelect is invalid', () => { - try { - new Cat({ itemSelect: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - - try { - cat7.findNextItem(stimuli, 'coolMethod'); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - it('should throw an error if startSelect is invalid', () => { - try { - new Cat({ startSelect: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); + cat7 = new Cat({ method: 'eap' }); + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + ); - it('should return undefined if there are no input items', () => { - const cat10 = new Cat(); - const { nextStimulus } = cat10.findNextItem([]); - expect(nextStimulus).toBeUndefined(); - }); -}); + cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); + }); + + const s1: Stimulus = { difficulty: 0.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'looking' }; + const s2: Stimulus = { difficulty: 3.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'opaque' }; + const s3: Stimulus = { difficulty: 2, guessing: 0.5, discrimination: 1, slipping: 1, word: 'right' }; + const s4: Stimulus = { difficulty: -2.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'yes' }; + const s5: Stimulus = { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; + const stimuli = [s1, s2, s3, s4, s5]; + + it('constructs an adaptive test', () => { + expect(cat1.method).toBe('mle'); + expect(cat1.itemSelect).toBe('mfi'); + }); + + it('correctly updates ability estimate', () => { + expect(cat1.theta).toBeCloseTo(-1.642307, 1); + }); + + it('correctly updates ability estimate', () => { + expect(cat2.theta).toBeCloseTo(-1.272, 1); + }); + + it('correctly updates standard error of mean of ability estimate', () => { + expect(cat2.seMeasurement).toBeCloseTo(1.71, 1); + }); + + it('correctly counts number of items', () => { + expect(cat2.nItems).toEqual(7); + }); + + it('correctly updates answers', () => { + expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); + }); + + it('correctly updates zetas', () => { + expect(cat2.zetas).toEqual([ + convertZeta({ a: 1, b: -0.447, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 2.869, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.469, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.576, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.43, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.607, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 0.529, c: 0.5, d: 1 }, format), + ]); + }); + + it('correctly suggests the next item (closest method)', () => { + const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; + const received = cat1.findNextItem(stimuli, 'closest'); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (mfi method)', () => { + const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; + const received = cat3.findNextItem(stimuli, 'MFI'); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (middle method)', () => { + const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; + const received = cat5.findNextItem(stimuli); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (fixed method)', () => { + expect(cat8.itemSelect).toBe('fixed'); + const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; + const received = cat8.findNextItem(stimuli); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (random method)', () => { + let received; + const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!); // ask + let index = Math.floor(rng() * stimuliSorted.length); + received = cat4.findNextItem(stimuliSorted); + expect(received.nextStimulus).toEqual(stimuliSorted[index]); + + for (let i = 0; i < 3; i++) { + const remainingStimuli = received.remainingStimuli; + index = Math.floor(rng() * remainingStimuli.length); + received = cat4.findNextItem(remainingStimuli); + expect(received.nextStimulus).toEqual(remainingStimuli[index]); + } + }); + + it('correctly updates ability estimate through MLE', () => { + expect(cat6.theta).toBeCloseTo(-6.0, 1); + }); + + it('correctly updates ability estimate through EAP', () => { + expect(cat7.theta).toBeCloseTo(0.25, 1); + }); + + it('should throw an error if zeta and answers do not have matching length', () => { + try { + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0, 0], + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if method is invalid', () => { + try { + new Cat({ method: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + try { + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + 'coolMethod', + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if itemSelect is invalid', () => { + try { + new Cat({ itemSelect: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + try { + cat7.findNextItem(stimuli, 'coolMethod'); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if startSelect is invalid', () => { + try { + new Cat({ startSelect: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should return undefined if there are no input items', () => { + const cat10 = new Cat(); + const { nextStimulus } = cat10.findNextItem([]); + expect(nextStimulus).toBeUndefined(); + }); + }); +} diff --git a/src/clowder.ts b/src/clowder.ts index 97ee537..60367c2 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,13 +1,10 @@ import { Cat, CatInput } from './index'; -import { Stimulus, Zeta } from './type'; +import { MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; import _cloneDeep from 'lodash/cloneDeep'; import _mapValues from 'lodash/mapValues'; +import _unzip from 'lodash/unzip'; import _zip from 'lodash/zip'; - -interface Corpora { - validated: Stimulus[]; - unvalidated: Stimulus[]; -} +import { validateCorpora } from './utils'; export interface ClowderInput { // An object containing Cat configurations for each Cat instance. @@ -15,13 +12,13 @@ export interface ClowderInput { [name: string]: CatInput; }; // An object containing arrays of stimuli for each corpus. - corpora: Corpora; + corpora: MultiZetaStimulus[]; } export class Clowder { private cats: { [name: string]: Cat }; - private corpora: Corpora; - public remainingItems: Corpora; + private corpora: MultiZetaStimulus[]; + public remainingItems: MultiZetaStimulus[]; public seenItems: Stimulus[]; /** @@ -29,8 +26,10 @@ export class Clowder { * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. */ constructor({ cats, corpora }: ClowderInput) { + // TODO: Need to pass in numItemsRequired so that we know when to stop providing new items. this.cats = _mapValues(cats, (catInput) => new Cat(catInput)); this.seenItems = []; + validateCorpora(corpora); this.corpora = corpora; this.remainingItems = _cloneDeep(corpora); } @@ -73,23 +72,23 @@ export class Clowder { /** * Updates the ability estimates for the specified `catsToUpdate` and selects the next stimulus for the `catToSelect`. * This function processes previous items and answers, updates internal state, and selects the next stimulus - * based on the current state of validated and unvalidated stimuli. + * based on the remaining stimuli and `catToSelect`. * - * @param {Object} params - The parameters for updating the Cat instance and selecting the next stimulus. - * @param {string} params.catToSelect - The Cat instance to use for selecting the next stimulus. - * @param {string | string[]} [params.catsToUpdate=[]] - A single Cat or array of Cats for which to update ability estimates. - * @param {Stimulus[]} [params.previousItems=[]] - An array of previously presented stimuli. - * @param {(0 | 1) | (0 | 1)[]} [params.previousAnswers=[]] - An array of answers (0 or 1) corresponding to `previousItems`. - * @param {string} [params.method] - Optional method for updating ability estimates (if applicable). + * @param {Object} input - The parameters for updating the Cat instance and selecting the next stimulus. + * @param {string} input.catToSelect - The Cat instance to use for selecting the next stimulus. + * @param {string | string[]} [input.catsToUpdate=[]] - A single Cat or array of Cats for which to update ability estimates. + * @param {Stimulus[]} [input.items=[]] - An array of previously presented stimuli. + * @param {(0 | 1) | (0 | 1)[]} [input.answers=[]] - An array of answers (0 or 1) corresponding to `items`. + * @param {string} [input.method] - Optional method for updating ability estimates (if applicable). * * @returns {Stimulus | undefined} - The next stimulus to present, or `undefined` if no further validated stimuli are available. * - * @throws {Error} If `previousItems` and `previousAnswers` lengths do not match. - * @throws {Error} If any `previousItems` are not found in the Clowder's corpora (validated or unvalidated). + * @throws {Error} If `items` and `answers` lengths do not match. + * @throws {Error} If any `items` are not found in the Clowder's corpora (validated or unvalidated). * * The function operates in several steps: * 1. Validates the `catToSelect` and `catsToUpdate`. - * 2. Ensures `previousItems` and `previousAnswers` arrays are properly formatted. + * 2. Ensures `items` and `answers` arrays are properly formatted. * 3. Updates the internal list of seen items. * 4. Updates the ability estimates for the `catsToUpdate`. * 5. Selects the next stimulus for `catToSelect`, considering validated and unvalidated stimuli. @@ -97,69 +96,60 @@ export class Clowder { public updateCatAndGetNextItem({ catToSelect, catsToUpdate = [], - previousItems = [], - previousAnswers = [], + items = [], + answers = [], method, }: { catToSelect: string; catsToUpdate?: string | string[]; - previousItems: Stimulus[]; - previousAnswers: (0 | 1) | (0 | 1)[]; + items: MultiZetaStimulus[]; + answers: (0 | 1) | (0 | 1)[]; method?: string; }): Stimulus | undefined { + // Validate all cat names this._validateCatName(catToSelect); - catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; catsToUpdate.forEach((cat) => { this._validateCatName(cat); }); - previousItems = Array.isArray(previousItems) ? previousItems : [previousItems]; - previousAnswers = Array.isArray(previousAnswers) ? previousAnswers : [previousAnswers]; + // Convert items and answers to arrays + items = Array.isArray(items) ? items : [items]; + answers = Array.isArray(answers) ? answers : [answers]; - if (previousItems.length !== previousAnswers.length) { + // Ensure that the lengths of items and answers match + if (items.length !== answers.length) { throw new Error('Previous items and answers must have the same length.'); } // Update the seenItems with the provided previous items - this.seenItems.push(...previousItems); - - const itemsAndAnswers = _zip(previousItems, previousAnswers) as [Stimulus, 0 | 1][]; - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const validatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => this.corpora.validated.includes(item)); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const unvalidatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => - this.corpora.unvalidated.includes(item), - ); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const invalidItems = itemsAndAnswers.filter(([item, _answer]) => { - return !this.corpora.validated.includes(item) && !this.corpora.unvalidated.includes(item); - }); - - if (!invalidItems) { - throw new Error( - `The following previous items provided are not in this Clowder's corpora:\n${JSON.stringify( - invalidItems, - null, - 2, - )} ${invalidItems}`, - ); + this.seenItems.push(...items); + + // Remove the seenItems from the remainingItems + this.remainingItems = this.remainingItems.filter((stim) => !items.includes(stim)); + + const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; + + // Update the ability estimate for all cats + for (const catName of catsToUpdate) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim, _answer]) => { + const allCats = stim.zetas.reduce((acc: string[], { cats }: { cats: string }) => { + return [...acc, ...cats]; + }, []); + return allCats.includes(catName); + }); + + const zetasAndAnswersForCat = itemsAndAnswersForCat.map(([stim, _answer]) => { + const { zetas } = stim; + const zetaForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); + return [zetaForCat.zeta, _answer]; + }); + + // Extract the cat to update ability estimate + const [zetas, answers] = _unzip(zetasAndAnswersForCat); + this.cats[catName].updateAbilityEstimate(zetas, answers, method); } - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const validatedStimuli = validatedItemsAndAnswers.map(([stim, _]) => stim); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const unvalidatedStimuli = unvalidatedItemsAndAnswers.map(([stim, _]) => stim); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const validatedAnswers = validatedItemsAndAnswers.map(([_, answer]) => answer); - - // Remove previous items from the remainingItems - this.remainingItems.validated = this.remainingItems.validated.filter((item) => !validatedStimuli.includes(item)); - this.remainingItems.unvalidated = this.remainingItems.unvalidated.filter( - (item) => !unvalidatedStimuli.includes(item), - ); - - // Update the ability estimates for the requested Cats - this.updateAbilityEstimates(catsToUpdate, validatedStimuli, validatedAnswers, method); // Use the catForSelect to determine the next stimulus const cat = this.cats[catToSelect]; @@ -189,26 +179,4 @@ export class Clowder { } } } - - /** - * Add a new Cat instance to the Clowder. - * @param {string} catName - Name of the new Cat. - * @param {CatInput} catInput - Configuration for the new Cat instance. - * @param {Stimulus[]} stimuli - The corpus for the new Cat. - */ - public addCat(catName: string, catInput: CatInput) { - if (Object.prototype.hasOwnProperty.call(this.cats, catName)) { - throw new Error(`Cat with the name "${catName}" already exists.`); - } - this.cats[catName] = new Cat(catInput); - } - - /** - * Remove a Cat instance from the Clowder. - * @param {string} catName - The name of the Cat instance to remove. - */ - public removeCat(catName: string) { - this._validateCatName(catName); - delete this.cats[catName]; - } } diff --git a/src/index.ts b/src/index.ts index 5c79286..b0cf123 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,7 +1,15 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import { minimize_Powell } from 'optimization-js'; import { cloneDeep } from 'lodash'; import { Stimulus, Zeta } from './type'; -import { itemResponseFunction, fisherInformation, normal, findClosest } from './utils'; +import { + itemResponseFunction, + fisherInformation, + normal, + findClosest, + validateZetaParams, + fillZetaDefaults, +} from './utils'; import seedrandom from 'seedrandom'; export const abilityPrior = normal(); @@ -26,7 +34,6 @@ export class Cat { public prior: number[][]; private readonly _zetas: Zeta[]; private readonly _resps: (0 | 1)[]; - private _nItems: number; private _theta: number; private _seMeasurement: number; public nStartItems: number; @@ -70,7 +77,6 @@ export class Cat { this._zetas = []; this._resps = []; this._theta = theta; - this._nItems = 0; this._seMeasurement = Number.MAX_VALUE; this.nStartItems = nStartItems; this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); @@ -84,6 +90,9 @@ export class Cat { return this._seMeasurement; } + /** + * Return the number of items that have been observed so far. + */ public get nItems() { return this._resps.length; } @@ -135,6 +144,8 @@ export class Cat { zeta = Array.isArray(zeta) ? zeta : [zeta]; answer = Array.isArray(answer) ? answer : [answer]; + zeta.forEach((z) => validateZetaParams(z, true)); + if (zeta.length !== answer.length) { throw new Error('Unmatched length between answers and item params'); } @@ -209,6 +220,9 @@ export class Cat { } else { arr = stimuli; } + + arr = arr.map((stim) => fillZetaDefaults(stim, 'semantic')); + if (this.nItems < this.nStartItems) { selector = this.startSelect; } @@ -216,7 +230,7 @@ export class Cat { // for mfi, we sort the arr by fisher information in the private function to select the best item, // and then sort by difficulty to return the remainingStimuli // for fixed, we want to keep the corpus order as input - arr.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); + arr.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!); } if (selector === 'middle') { @@ -233,14 +247,10 @@ export class Cat { } } - private selectorMFI(arr: Stimulus[]) { - const stimuliAddFisher = arr.map((element: Stimulus) => ({ - fisherInformation: fisherInformation(this._theta, { - a: element.a || 1, - b: element.difficulty || 0, - c: element.c || 0, - d: element.d || 1, - }), + private selectorMFI(inputStimuli: Stimulus[]) { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); + const stimuliAddFisher = stimuli.map((element: Stimulus) => ({ + fisherInformation: fisherInformation(this._theta, fillZetaDefaults(element, 'symbolic')), ...element, })); @@ -250,7 +260,7 @@ export class Cat { }); return { nextStimulus: stimuliAddFisher[0], - remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty), + remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!), }; } diff --git a/src/type.ts b/src/type.ts index 46e4813..522ccfd 100644 --- a/src/type.ts +++ b/src/type.ts @@ -1,27 +1,36 @@ -export const zetaKeyMap = { - a: 'discrimination', - b: 'difficulty', - c: 'guessing', - d: 'slipping', -}; - -export type ZetaImplicit = { +export type ZetaSymbolic = { + // Symbolic parameter names a: number; // Discrimination (slope of the curve) b: number; // Difficulty (location of the curve) c: number; // Guessing (lower asymptote) d: number; // Slipping (upper asymptote) }; -export type ZetaExplicit = { - discrimination: number; - difficulty: number; - guessing: number; - slipping: number; -}; +export interface Zeta { + // Symbolic parameter names + a?: number; // Discrimination (slope of the curve) + b?: number; // Difficulty (location of the curve) + c?: number; // Guessing (lower asymptote) + d?: number; // Slipping (upper asymptote) + // Semantic parameter names + discrimination?: number; + difficulty?: number; + guessing?: number; + slipping?: number; +} -export type Zeta = ZetaImplicit | ZetaExplicit; +export interface Stimulus extends Zeta { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [key: string]: any; +} + +export type ZetaCatMap = { + cats: string[]; + zeta: Zeta; +}; -export interface Stimulus extends ZetaExplicit { +export interface MultiZetaStimulus { + zetas: ZetaCatMap[]; // eslint-disable-next-line @typescript-eslint/no-explicit-any [key: string]: any; } diff --git a/src/utils.ts b/src/utils.ts index 6982fab..f5eba90 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,5 +1,93 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import bs from 'binary-search'; -import { Stimulus, Zeta, ZetaExplicit, ZetaImplicit } from './type'; +import { MultiZetaStimulus, Stimulus, Zeta, ZetaSymbolic } from './type'; +import _intersection from 'lodash/intersection'; +import _invert from 'lodash/invert'; +import _mapKeys from 'lodash/mapKeys'; + +export const zetaKeyMap = { + a: 'discrimination', + b: 'difficulty', + c: 'guessing', + d: 'slipping', +}; + +export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + const defaultZeta: Zeta = { + a: 1, + b: 0, + c: 0, + d: 1, + }; + + return convertZeta(defaultZeta, desiredFormat); +}; + +export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { + if (zeta.a !== undefined && zeta.discrimination !== undefined) { + throw new Error('This item has both an `a` key and `discrimination` key. Please provide only one.'); + } + + if (zeta.b !== undefined && zeta.difficulty !== undefined) { + throw new Error('This item has both a `b` key and `difficulty` key. Please provide only one.'); + } + + if (zeta.c !== undefined && zeta.guessing !== undefined) { + throw new Error('This item has both a `c` key and `guessing` key. Please provide only one.'); + } + + if (zeta.d !== undefined && zeta.slipping !== undefined) { + throw new Error('This item has both a `d` key and `slipping` key. Please provide only one.'); + } + + if (requireAll) { + if (zeta.a === undefined && zeta.discrimination === undefined) { + throw new Error('This item is missing an `a` or `discrimination` key.'); + } + + if (zeta.b === undefined && zeta.difficulty === undefined) { + throw new Error('This item is missing a `b` or `difficulty` key.'); + } + + if (zeta.c === undefined && zeta.guessing === undefined) { + throw new Error('This item is missing a `c` or `guessing` key.'); + } + + if (zeta.d === undefined && zeta.slipping === undefined) { + throw new Error('This item is missing a `d` or `slipping` key.'); + } + } +}; + +export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + return { + ...defaultZeta(desiredFormat), + ...convertZeta(zeta, desiredFormat), + }; +}; + +export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { + if (!['symbolic', 'semantic'].includes(desiredFormat)) { + throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. Received ${desiredFormat} instead.`); + } + + return _mapKeys(zeta, (value, key) => { + if (desiredFormat === 'symbolic') { + const inverseMap = _invert(zetaKeyMap); + if (key in inverseMap) { + return inverseMap[key]; + } else { + return key; + } + } else { + if (key in zetaKeyMap) { + return zetaKeyMap[key as keyof typeof zetaKeyMap]; + } else { + return key; + } + } + }); +}; /** * calculates the probability that someone with a given ability level theta will answer correctly an item. Uses the 4 parameters logistic model @@ -8,16 +96,8 @@ import { Stimulus, Zeta, ZetaExplicit, ZetaImplicit } from './type'; * @returns {number} the probability */ export const itemResponseFunction = (theta: number, zeta: Zeta) => { - if ((zeta as ZetaImplicit).a) { - const _zeta = zeta as ZetaImplicit; - return _zeta.c + (_zeta.d - _zeta.c) / (1 + Math.exp(-_zeta.a * (theta - _zeta.b))); - } else { - const _zeta = zeta as ZetaExplicit; - return ( - _zeta.guessing + - (_zeta.slipping - _zeta.guessing) / (1 + Math.exp(-_zeta.discrimination * (theta - _zeta.difficulty))) - ); - } + const _zeta = fillZetaDefaults(zeta, 'symbolic') as ZetaSymbolic; + return _zeta.c + (_zeta.d - _zeta.c) / (1 + Math.exp(-_zeta.a * (theta - _zeta.b))); }; /** @@ -27,17 +107,10 @@ export const itemResponseFunction = (theta: number, zeta: Zeta) => { * @returns {number} - the expected value of the observed information */ export const fisherInformation = (theta: number, zeta: Zeta) => { - const p = itemResponseFunction(theta, zeta); + const _zeta = fillZetaDefaults(zeta, 'symbolic') as ZetaSymbolic; + const p = itemResponseFunction(theta, _zeta); const q = 1 - p; - if ((zeta as ZetaImplicit).a) { - const _zeta = zeta as ZetaImplicit; - return Math.pow(_zeta.a, 2) * (q / p) * (Math.pow(p - _zeta.c, 2) / Math.pow(1 - _zeta.c, 2)); - } else { - const _zeta = zeta as ZetaExplicit; - return ( - Math.pow(_zeta.discrimination, 2) * (q / p) * (Math.pow(p - _zeta.guessing, 2) / Math.pow(1 - _zeta.guessing, 2)) - ); - } + return Math.pow(_zeta.a, 2) * (q / p) * (Math.pow(p - _zeta.c, 2) / Math.pow(1 - _zeta.c, 2)); }; /** @@ -67,22 +140,23 @@ export const normal = (mean = 0, stdDev = 1, min = -4, max = 4, stepSize = 0.1) * @remarks * The input array of stimuli must be sorted by difficulty. * - * @param arr Array - an array of stimuli sorted by difficulty + * @param stimuli Array - an array of stimuli sorted by difficulty * @param target number - ability estimate - * @returns {number} the index of arr + * @returns {number} the index of stimuli */ -export const findClosest = (arr: Array, target: number) => { +export const findClosest = (inputStimuli: Array, target: number) => { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); // Let's consider the edge cases first - if (target <= arr[0].difficulty) { + if (target <= stimuli[0].difficulty!) { return 0; - } else if (target >= arr[arr.length - 1].difficulty) { - return arr.length - 1; + } else if (target >= stimuli[stimuli.length - 1].difficulty!) { + return stimuli.length - 1; } const comparitor = (element: Stimulus, needle: number) => { - return element.difficulty - needle; + return element.difficulty! - needle; }; - const indexOfTarget = bs(arr, target, comparitor); + const indexOfTarget = bs(stimuli, target, comparitor); if (indexOfTarget >= 0) { // `bs` returns a positive integer index if it found an exact match. @@ -96,8 +170,8 @@ export const findClosest = (arr: Array, target: number) => { // So we simply compare the differences between the target and the high and // low values, respectively - const lowDiff = Math.abs(arr[lowIndex].difficulty - target); - const highDiff = Math.abs(arr[highIndex].difficulty - target); + const lowDiff = Math.abs(stimuli[lowIndex].difficulty! - target); + const highDiff = Math.abs(stimuli[highIndex].difficulty! - target); if (lowDiff < highDiff) { return lowIndex; @@ -106,3 +180,13 @@ export const findClosest = (arr: Array, target: number) => { } } }; + +export const validateCorpora = (corpus: MultiZetaStimulus[]): void => { + const zetaCatMapsArray = corpus.map((item) => item.zetas); + for (const zetaCatMaps of zetaCatMapsArray) { + const intersection = _intersection(zetaCatMaps); + if (intersection.length > 0) { + throw new Error(`The cat names ${intersection.join(', ')} are present in multiple corpora.`); + } + } +};