Solution for the Challenge
There are several possible approaches to implementing a solution for reasoning simultaneously. Two of the methods that have been considered are splitting the embedding dimensions for the hidden states of both problems and using attention masks to control the attention across the problems.
However, these two methods have been excluded for the following reasons:
- Dividing the embedding dimensions contradicts the nature of the base language model, `GPT2`, and prevents it from fully leveraging the knowledge it has acquired, resulting in suboptimal performance.
- Utilizing attention masks involves using two `<eos>` tokens for generation and attention masks to separate the two problems. However, this approach is essentially equivalent to treating the two multiplications as separate examples, effectively doubling the batch size. Moreover, it consumes additional computational resources and memory.
Finally, I chose to retain the paradigm of the original Implicit Chain of Thought (CoT) model but modify the distribution of the extracted teacher minds within the teacher CoT hidden states, so that the model can do simultaneous reasoning "implicitly".
The solution involves processing a concatenated sequence of the two problems and generating the answers sequentially. However, it extracts the teacher's minds alternately from the CoT hidden states of two diagonals. The extracted minds are illustrated in the provided figure. The first `<eos>` token in the first layer and the second `<eos>` token in the last layer are selected as the hidden states. Additionally, the hidden states from the two diagonals of the CoT for Multiplication 1 (M1) and Multiplication 2 (M2) are chosen alternately. Although candidate states are not extracted, they are shown in the figure so that readers can see the diagonal more easily.
The "Concatenate" variant is designed to receive the diagonal states of both multiplication CoTs. In every layer except the first and last, the first half of the dimensions of the M1 state and the second half of the dimensions of the M2 state are concatenated and extracted as the teacher's minds.
A simple but effective multi-modal fusion method is to sum the features of two modalities. The "Sum" variant follows this idea and sums the two diagonals of the two CoTs, so that the model remembers the reasoning process of both problems.
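The three ways of combining the two diagonals described above (alternating, concatenating halves, and summing) can be sketched as follows. This is a minimal illustration on plain Python lists, not the repository's code: the function names, the straight-line rule for picking the diagonal position per layer, and the choice of even layers for M1 are all assumptions, and the special handling of the two `<eos>` states in the first and last layers is omitted.

```python
from typing import List

Vector = List[float]
# states[layer][position] is the hidden-state vector of one CoT token at one layer.
LayerStates = List[List[Vector]]


def diagonal_positions(num_layers: int, cot_length: int) -> List[int]:
    """Pick one CoT position per layer along a straight diagonal (assumed rule)."""
    if num_layers == 1:
        return [0]
    step = (cot_length - 1) / (num_layers - 1)
    return [round(layer * step) for layer in range(num_layers)]


def diagonal(states: LayerStates) -> List[Vector]:
    """Return the diagonal hidden state of every layer."""
    positions = diagonal_positions(len(states), len(states[0]))
    return [states[layer][pos] for layer, pos in enumerate(positions)]


def alternating(d1: List[Vector], d2: List[Vector]) -> List[Vector]:
    """Even layers take the M1 state, odd layers the M2 state (assumed parity)."""
    return [a if layer % 2 == 0 else b for layer, (a, b) in enumerate(zip(d1, d2))]


def concatenate(d1: List[Vector], d2: List[Vector]) -> List[Vector]:
    """First half of the M1 dimensions followed by the second half of the M2 dimensions."""
    return [a[: len(a) // 2] + b[len(b) // 2:] for a, b in zip(d1, d2)]


def summed(d1: List[Vector], d2: List[Vector]) -> List[Vector]:
    """Element-wise sum of the M1 and M2 states in every layer."""
    return [[x + y for x, y in zip(a, b)] for a, b in zip(d1, d2)]
```

All three functions return one vector per layer with the original hidden size, so the downstream emulator and student would not need a different input shape under this sketch.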
The basic task is accomplished through the following three steps:
- Creation of the double 2x2 multiplication dataset. Because the train set size in the paper is too large for two $2 \times 2$ multiplications, I also generated a smaller dataset with 10k examples for the train set, and 2k for validation and test. The code for this step can be found at `src/scripts/generate_new_data_small.py`.
- Modification of the distribution of extracted teacher minds.
- Fine-tuning of `GPT2` on the new dataset with the revised distribution.

Besides the above methods, an experiment in which no modification is made to the distribution of extracted teacher minds is also conducted for comparison (reported as "Vanilla" in the tables below). Although it does not reason simultaneously, it is more compatible with the transformer architecture and is expected to perform better than the above methods.
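The dataset-creation step can be illustrated with a short sketch. It is not the content of `src/scripts/generate_new_data_small.py`: the function names, the dictionary layout, the choice of operands in the range 10 to 99, and the reading of "2k for validation and test" as 2k examples each are all assumptions made for illustration.

```python
import random
from typing import Dict, List


def make_example(rng: random.Random) -> Dict[str, list]:
    """Draw two independent two-digit by two-digit multiplications."""
    a, b, c, d = (rng.randint(10, 99) for _ in range(4))
    return {"problems": [(a, b), (c, d)], "answers": [a * b, c * d]}


def make_splits(n_train: int = 10_000, n_val: int = 2_000, n_test: int = 2_000,
                seed: int = 0) -> Dict[str, List[Dict[str, list]]]:
    """Build train/valid/test splits from one seeded generator for reproducibility."""
    rng = random.Random(seed)
    sizes = {"train": n_train, "valid": n_val, "test": n_test}
    return {name: [make_example(rng) for _ in range(n)] for name, n in sizes.items()}
```

A real script would additionally serialize each example, including its chain of thought, in the text format expected by the training code.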
Add `--subset diagonal_double` to the command line at emulator training and use the same commands as the original code for the other steps. Set `DIAGONAL_DOUBLE_VARIANT_CONCAT` or `DIAGONAL_DOUBLE_VARIANT_SUM` in the environment variables to choose the variant.
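As a rough sketch of how such environment variables could select a variant, the snippet below maps them to a variant name. The function `select_variant` is hypothetical, and two behaviours are assumptions rather than facts about the repository: that any non-empty value enables a variable, and that the alternating extraction is used when neither variable is set.

```python
import os
from typing import Mapping, Optional


def select_variant(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the teacher-mind variant implied by the environment (assumed rules)."""
    env = os.environ if env is None else env
    if env.get("DIAGONAL_DOUBLE_VARIANT_CONCAT"):
        return "concatenate"
    if env.get("DIAGONAL_DOUBLE_VARIANT_SUM"):
        return "sum"
    # Assumption: with neither variable set, the alternating extraction is used.
    return "alternating"
```

Passing a plain dictionary instead of `os.environ` makes the selection easy to check without touching the real environment.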
The table below presents the accuracy of the above methods trained with 800k train samples. The training epochs "3, 4, 6, 1" indicate the number of epochs for training the teacher, emulator, student, and coupled model, respectively.
Method\Epochs | 3, 4, 6, 1 | 3, 2, 3, 1 |
---|---|---|
Vanilla | 100% | 100% |
Sum | 100% | 100% |
Concatenate | 99.7% | 98.3% |
Alternating | 97.9% | 99.8% |
Because their results are too close, I do not think we can conclude that one method is better than the other.
Method\Epochs | 20 | 50 |
---|---|---|
Vanilla | 70.4% | 91.0% |
Sum | 63.1% | 80.9% |
Concatenate | 56.1% | 57.4% |
Alternating | 53.6% | 51.8% |
When the train set is reduced to 10k examples and the validation set to 2k, more training epochs are required to fine-tune a `GPT2` model. Therefore, I trained for 20 epochs in all training stages, repeated 4 times with the average accuracy reported, and for 50 epochs in all training stages. The results are shown in the table above. As expected, `Vanilla` performs the best at both 20 and 50 epochs, and "Sum" follows, because the former is the most compatible and natural for the transformer architecture. The `Sum` method requires the model to remember the reasoning process of both problems, which is more difficult than the `Vanilla` method. The "Concatenate" method disrupts the natural dimension of the model, so it is not expected to perform well. The "Alternating" method is the worst because the alternating pattern is too hard for the model to learn.
Regarding the three steps in the inference process, it seems impossible to perform beam search in mind-reading. As for the generation process, it is originally autoregressive, which means it is hard to explore beam searches other than the original beam search. Therefore, I think "using hidden states to imitate beam search" refers to the mixture of components during thought emulation.
In thought emulation, the model uses the mixture model, which is technically part of the emulator, to process the "Mind token" between the transformer layers of the language model. In this mixing process, one of the mixture components is chosen and fused with the intermediate hidden states. Because this step involves deciding which component is best, beam search can be applied to it.
Because the mixture of components is designed for "Multi Reasoning Pathways" tasks, I only applied beam search to the `GSM8K` dataset and didn't apply it to the multiplication datasets. The emulation of beam search between the transformer layers is implemented in the commit `f31bae`. I will provide the results and analysis in the next section. Specifically, I select the beams with the highest accumulated log probabilities of the component paths and pass them to the next transformer layer. In the mixture process, I embed the component one-hot vectors and fuse them with the hidden states separately using an MLP, as stated in the paper. Furthermore, I kept the batching during inference to maintain parallelism and inference speed.
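The selection rule described above, keeping the component paths with the highest accumulated log probabilities after each layer, can be sketched in a few lines. This is an illustration only and not the code from the commit: `beam_search_components` and `toy_log_probs` are hypothetical names, the scoring callback stands in for the emulator's mixture model, and batching and the MLP fusion are left out.

```python
import math
from typing import Callable, List, Sequence, Tuple

Path = Tuple[int, ...]


def beam_search_components(
    num_layers: int,
    beam_size: int,
    log_probs_fn: Callable[[int, Path], Sequence[float]],
) -> List[Tuple[Path, float]]:
    """Keep the beam_size component paths with the highest accumulated log probability."""
    beams: List[Tuple[Path, float]] = [((), 0.0)]
    for layer in range(num_layers):
        candidates = []
        for path, score in beams:
            # Extend every surviving path by each possible component at this layer.
            for component, log_p in enumerate(log_probs_fn(layer, path)):
                candidates.append((path + (component,), score + log_p))
        candidates.sort(key=lambda item: item[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


def toy_log_probs(layer: int, path: Path) -> List[float]:
    """Made-up scores where the greedy first choice leads to a worse full path."""
    if layer == 0:
        return [math.log(0.6), math.log(0.4)]
    if path == (0,):
        return [math.log(0.55), math.log(0.45)]
    return [math.log(0.9), math.log(0.1)]


greedy_path = beam_search_components(2, 1, toy_log_probs)[0][0]
beam_path = beam_search_components(2, 2, toy_log_probs)[0][0]
```

In this toy setting a beam of size 1 commits to component 0 in the first layer, while a beam of size 2 recovers the higher-scoring path that starts with component 1. Beam search can only differ from greedy selection when the per-layer choices are uncertain in this way.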
To enable beam search, set "EMULATOR_BEAM_SIZE" in environment variables. The command line for inference is as follows:
export EMULATOR_BEAM_SIZE=5
#export BEAM_SEARCH_SOFTMAX_TEMPERATURE=1
#export PRINT_ICOT=true
python src/generate.py --batch_size 1 --test_path data/gsm8k/test.txt --student_path models/gsm8k/gpt2/student --emulator_path models/gsm8k/gpt2/emulator
Despite implementing beam search, the performance of the model did not show clear improvement. I ran a grid search over beam width and softmax temperature and found the performance to be stable.
Here are the detailed results:
Beam Width \ Softmax Temperature | 0.05 | 1.0 | 3.0 | 10.0 | 30.0 |
---|---|---|---|---|---|
None | 0.20 | 0.20 | 0.20 | 0.20 | 0.20 |
1 | 0.20 | 0.20 | 0.20 | 0.20 | 0.20 |
5 | 0.20 | 0.20 | 0.20 | 0.20 | 0.20 |
15 | 0.20 | 0.20 | 0.20 | 0.20 | 0.20 |
The results show that the performance of the model is almost unaffected by the softmax temperature and the beam size, so beam search degrades to greedy search. I think the reason is that the decision of choosing the best component is "sure" and not heavily influenced by the softmax temperature. To verify this idea, I printed the tokens of the best components in three different experiment settings. The results are shown below:
beam size: 1, softmax temperature: 0.05 (default)
<|endoftext|> / 560 Ġ250 ĠFlood 750 - = FactoryReloaded Ġballpark Vers Ġgrun
<|endoftext|> / 560 ĠPentagon ĠFlood 750 * Ġingestion Ott ĠVI Ġediting .;
<|endoftext|> / 560 Ġ250 ĠFlood 750 + Ġingestion ĠRaiders Ġbadges Ġdelinquent ĠSail
<|endoftext|> / 560 Ġ250 ĠFlood 750 + = oling arian <|endoftext|> .;
<|endoftext|> / 560 Ġ250 ĠFlood 750 + = ĠRaiders Ġbadges Ġdelinquent paper
beam size: 5, softmax temperature: 0.05
<|endoftext|> / 560 Ġ250 ĠFlood 750 - = FactoryReloaded Ġballpark Vers Ġgrun
<|endoftext|> / 560 ĠPentagon ĠFlood 750 * Ġingestion Ott ĠVI Ġediting .;
<|endoftext|> / 560 Ġ250 ĠFlood 750 + Ġingestion ĠRaiders Ġbadges Ġdelinquent ĠSail
<|endoftext|> / 560 Ġ250 ĠFlood 750 + = oling arian <|endoftext|> .;
<|endoftext|> / 560 Ġ250 ĠFlood 750 + = ĠRaiders Ġbadges Ġdelinquent paper
beam size: 5, softmax temperature: 10
<|endoftext|> / 560 Ġ250 ĠFlood 750 - = FactoryReloaded Ġballpark Vers Ġgrun
<|endoftext|> / 560 ĠPentagon ĠFlood 750 * Ġingestion Ott ĠVI Ġediting .;
<|endoftext|> / 560 Ġ250 ĠFlood 750 + Ġingestion ĠRaiders Ġbadges Ġdelinquent ĠSail
<|endoftext|> / 560 Ġ250 ĠFlood 750 + = oling arian <|endoftext|> .;
<|endoftext|> / 560 Ġ250 ĠFlood 750 + = ĠRaiders Ġbadges Ġdelinquent paper
Their corresponding inputs, CoTs, and outputs are:
Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?||<<16-3-4=9>> <<9*2=18>> #### 18
A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?||<<2/2=1>> <<2+1=3>> #### 3
Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?||<<80000+50000=130000>> <<80000*1.5=120000>> <<120000+80000=200000>> <<200000-130000=70000>> #### 70000
James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?||<<3*3=9>> <<9*60=540>> #### 540
Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?||<<3*20=60>> <<60-15-25=20>> #### 20
The printed tokens show that the chosen components are identical under all three settings, which supports my assumption.
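A small sketch of why a confident ("sure") component choice is insensitive to the softmax temperature. It assumes the temperature divides the logits before the softmax, which may differ from how `BEAM_SEARCH_SOFTMAX_TEMPERATURE` is applied in the code, and the logits are made up for illustration.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits divided by a positive temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


# Hypothetical logits with one clearly dominant component.
confident_logits = [50.0, 0.0, -5.0]
temperatures = [0.05, 1.0, 3.0, 10.0, 30.0]
chosen = [argmax(softmax(confident_logits, t)) for t in temperatures]
top_probs = [max(softmax(confident_logits, t)) for t in temperatures]
```

Across the whole temperature grid the same component is chosen and it keeps most of the probability mass, so alternative paths rarely accumulate enough log probability to overtake the greedy one.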