diff --git a/mixing/mixpool/mixpool.go b/mixing/mixpool/mixpool.go index 335dbb938..5c1637693 100644 --- a/mixing/mixpool/mixpool.go +++ b/mixing/mixpool/mixpool.go @@ -731,10 +731,13 @@ type Received struct { // such session has any messages currently accepted in the mixpool, the method // immediately errors. // -// If any secrets messages are received for the described session, and r.RSs -// is nil, Receive immediately returns ErrSecretsRevealed. An additional call -// to Receive with a non-nil RSs can be used to receive all of the secrets -// after each peer publishes their own revealed secrets. +// If r.RSs is nil and any secrets messages are received for the described +// session, Receive will count these messages towards the total number of +// expected messages based on the slice capacities (without overcounting if +// the identity has also sent an expected non-RS message) and will return +// ErrSecretsRevealed. An additional call to Receive with a non-nil RSs can +// be used to receive all of the secrets after each peer publishes their own +// revealed secrets. func (p *Pool) Receive(ctx context.Context, r *Received) error { sid := r.Sid var bc *broadcast @@ -784,42 +787,32 @@ Loop: for { // Pool is locked for reads. Count if the total number of // expected messages have been received. - received := 0 + received := make(map[idPubKey]struct{}) + countMsg := func(msg mixing.Message) { + received[*(*idPubKey)(msg.Pub())] = struct{}{} + } for hash := range ses.hashes { - msgtype := p.pool[hash].msgtype + e := p.pool[hash] + msg := e.msg + msgtype := e.msgtype switch { case msgtype == msgtypeKE && r.KEs != nil: - received++ + countMsg(msg) case msgtype == msgtypeCT && r.CTs != nil: - received++ + countMsg(msg) case msgtype == msgtypeSR && r.SRs != nil: - received++ + countMsg(msg) case msgtype == msgtypeDC && r.DCs != nil: - received++ + countMsg(msg) case msgtype == msgtypeCM && r.CMs != nil: - received++ + countMsg(msg) case msgtype == msgtypeFP && r.FPs != nil: - received++ + countMsg(msg) case msgtype == msgtypeRS: - if r.RSs == nil { - // Since initial reporters of secrets - // need to take the blame for - // erroneous blame assignment if no - // issue was detected, we only trigger - // this for RS messages that do not - // reference any other previous RS. - rs := p.pool[hash].msg.(*wire.MsgMixSecrets) - prev := rs.PrevMsgs() - if len(prev) == 0 { - p.mtx.RUnlock() - return ErrSecretsRevealed - } - } else { - received++ - } + countMsg(msg) } } - if received >= expectedMessages { + if len(received) >= expectedMessages { break } @@ -836,6 +829,8 @@ Loop: p.mtx.RLock() } + var err error + // Pool is locked for reads. Collect all of the messages. for hash := range ses.hashes { msg := p.pool[hash].msg @@ -865,14 +860,16 @@ Loop: r.FPs = append(r.FPs, msg) } case *wire.MsgMixSecrets: - if r.RSs != nil { + if r.RSs == nil { + err = ErrSecretsRevealed + } else { r.RSs = append(r.RSs, msg) } } } p.mtx.RUnlock() - return nil + return err } var zeroHash chainhash.Hash