Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Provisioning API fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
smweber committed Nov 1, 2023
1 parent a09cec4 commit 4064dba
Showing 1 changed file with 57 additions and 15 deletions.
72 changes: 57 additions & 15 deletions provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ProvisioningAPI struct {
bridge *SignalBridge
log zerolog.Logger
provisioningChannels []<-chan signalmeow.ProvisioningResponse
provisioningUsers map[string]int
}

func (prov *ProvisioningAPI) Init() {
Expand Down Expand Up @@ -105,6 +106,15 @@ type Response struct {
func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
prov.log.Debug().Msgf("LinkNew from %v", user.MXID)
if _, ok := prov.provisioningUsers[user.MXID.String()]; ok {
prov.log.Warn().Msgf("LinkNew from %v, user already has a pending provisioning request", user.MXID)
jsonResponse(w, http.StatusConflict, Error{
Success: false,
Error: "User already has a pending provisioning request",
ErrCode: "M_CONFLICT",
})
return
}

provChan, err := user.Login()
if err != nil {
Expand All @@ -117,8 +127,9 @@ func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) {
return
}
prov.provisioningChannels = append(prov.provisioningChannels, provChan)
prov.log.Debug().Msgf("LinkNew from %v, waiting for provisioning response", user.MXID)
sessionID := len(prov.provisioningChannels) - 1
prov.provisioningUsers[user.MXID.String()] = sessionID
prov.log.Debug().Msgf("LinkNew from %v, waiting for provisioning response", user.MXID)

select {
case resp := <-provChan:
Expand Down Expand Up @@ -150,11 +161,11 @@ func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) {
})
return
case <-time.After(30 * time.Second):
prov.log.Err(err).Msg("Timeout waiting for provisioning response")
jsonResponse(w, http.StatusInternalServerError, Error{
prov.log.Err(err).Msg("Timeout waiting for provisioning response (new)")
jsonResponse(w, http.StatusGatewayTimeout, Error{
Success: false,
Error: "Timeout waiting for provisioning response",
ErrCode: "M_INTERNAL",
Error: "Timeout waiting for provisioning response (new)",
ErrCode: "M_TIMEOUT",
})
return
}
Expand All @@ -175,6 +186,7 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ
})
return
}

sessionID, err := strconv.Atoi(body.SessionID)
if err != nil {
prov.log.Err(err).Msg("Error decoding JSON body")
Expand All @@ -186,6 +198,15 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ
return
}
prov.log.Debug().Msgf("LinkWaitForScan from %v, session_id: %v", user.MXID, sessionID)
if userSessionID, ok := prov.provisioningUsers[user.MXID.String()]; ok && userSessionID != sessionID {
prov.log.Warn().Msgf("LinkWaitForAccount from %v, session_id %v does not match user's session_id %v", user.MXID, sessionID, userSessionID)
jsonResponse(w, http.StatusBadRequest, Error{
Success: false,
Error: "session_id does not match user's session_id",
ErrCode: "M_BAD_JSON",
})
return
}
respChan := prov.provisioningChannels[sessionID]

select {
Expand Down Expand Up @@ -221,12 +242,12 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ
user.Update()
}
return
case <-time.After(30 * time.Second):
prov.log.Err(err).Msg("Timeout waiting for provisioning response")
jsonResponse(w, http.StatusInternalServerError, Error{
case <-time.After(60 * time.Second):
prov.log.Err(err).Msg("Timeout waiting for provisioning response (scan)")
jsonResponse(w, http.StatusRequestTimeout, Error{
Success: false,
Error: "Timeout waiting for provisioning response",
ErrCode: "M_INTERNAL",
Error: "Timeout waiting for QR code scan",
ErrCode: "M_TIMEOUT",
})
return
}
Expand Down Expand Up @@ -260,6 +281,15 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R
}
deviceName := body.DeviceName
prov.log.Debug().Msgf("LinkWaitForAccount from %v, session_id: %v, device_name: %v", user.MXID, sessionID, deviceName)
if userSessionID, ok := prov.provisioningUsers[user.MXID.String()]; ok && userSessionID != sessionID {
prov.log.Warn().Msgf("LinkWaitForAccount from %v, session_id %v does not match user's session_id %v", user.MXID, sessionID, userSessionID)
jsonResponse(w, http.StatusBadRequest, Error{
Success: false,
Error: "session_id does not match user's session_id",
ErrCode: "M_BAD_JSON",
})
return
}
respChan := prov.provisioningChannels[sessionID]

select {
Expand Down Expand Up @@ -295,18 +325,30 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R
user.Connect()
return
case <-time.After(30 * time.Second):
prov.log.Err(err).Msg("Timeout waiting for provisioning response")
jsonResponse(w, http.StatusInternalServerError, Error{
prov.log.Err(err).Msg("Timeout waiting for provisioning response (account)")
jsonResponse(w, http.StatusGatewayTimeout, Error{
Success: false,
Error: "Timeout waiting for provisioning response",
ErrCode: "M_INTERNAL",
Error: "Timeout waiting for provisioning response (account)",
ErrCode: "M_TIMEOUT",
})
return
}
}

func (prov *ProvisioningAPI) CancelLink(user *User) {
if sessionID, ok := prov.provisioningUsers[user.MXID.String()]; ok {
prov.log.Debug().Msgf("CancelLink called for %v, clearing session %v", user.MXID, sessionID)
prov.provisioningChannels[sessionID] = nil
delete(prov.provisioningUsers, user.MXID.String())
} else {
prov.log.Debug().Msgf("CancelLink called for %v, no session found", user.MXID)
}
}

func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
//user := r.Context().Value("user").(*User)
user := r.Context().Value("user").(*User)
prov.log.Debug().Msgf("Logout called from %v (but not logging out)", user.MXID)
prov.CancelLink(user)

// For now do nothing - we need this API to return 200 to be compatible with
// the old Signal bridge, which needed a call to Logout before allowing LinkNew
Expand Down

0 comments on commit 4064dba

Please sign in to comment.