diff --git a/internal/libtor/enabled.go b/internal/libtor/enabled.go
index 89298fcd9c..5b2588b8ef 100644
--- a/internal/libtor/enabled.go
+++ b/internal/libtor/enabled.go
@@ -84,6 +84,8 @@ import (
 	"os"
 	"runtime"
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/cretz/bine/process"
 	"github.com/ooni/probe-cli/v3/internal/netxlite"
@@ -163,6 +165,10 @@ func (p *torProcess) Wait() (err error) {
 	return
 }
 
+// ErrConcurrentCalls indicates there have been concurrent libtor calls, which
+// would lead to memory corruption inside of libtor.a.
+var ErrConcurrentCalls = errors.New("libtor: another thread is already running tor")
+
 // ErrTooManyArguments indicates that p.args contains too many arguments
 var ErrTooManyArguments = errors.New("libtor: too many arguments")
 
@@ -172,6 +178,9 @@ var ErrCannotCreateControlSocket = errors.New("libtor: cannot create a control s
 // ErrNonzeroExitCode indicates that tor returned a nonzero exit code
 var ErrNonzeroExitCode = errors.New("libtor: command completed with nonzero exit code")
 
+// concurrentCalls prevents concurrent libtor.a calls.
+var concurrentCalls = &atomic.Int64{}
+
 // runtor runs tor until completion and ensures that tor exits when
 // the given ctx is cancelled or its deadline expires.
 func (p *torProcess) runtor(ctx context.Context, cc net.Conn, args ...string) {
@@ -191,6 +200,13 @@ func (p *torProcess) runtor(ctx context.Context, cc net.Conn, args ...string) {
 		return
 	}
 
+	// make sure we're not going to have actual concurrent calls.
+	if !concurrentCalls.CompareAndSwap(0, 1) {
+		p.startErr <- ErrConcurrentCalls // nonblocking channel
+		return
+	}
+	defer concurrentCalls.Store(0)
+
 	// Note: when writing this code I was wondering whether I needed to
 	// use unsafe.Pointer to track pointers that matter to C code. Reading
 	// this message[1] has been useful to understand that the most likely
@@ -284,7 +300,10 @@ func (p *torProcess) runtor(ctx context.Context, cc net.Conn, args ...string) {
 	if !p.simulateNonzeroExitCode {
 		code = C.tor_run_main(config)
 	} else {
+		// when simulating nonzero exit code we also want to sleep for a bit
+		// of time, to make sure we're able to see overlapped runs.
 		code = 1
+		time.Sleep(time.Second)
 	}
 	if code != 0 {
 		p.waitErr <- fmt.Errorf("%w: %d", ErrNonzeroExitCode, code) // nonblocking channel
diff --git a/internal/libtor/enabled_test.go b/internal/libtor/enabled_test.go
index d749bc0d55..c6c940426c 100644
--- a/internal/libtor/enabled_test.go
+++ b/internal/libtor/enabled_test.go
@@ -160,7 +160,7 @@ func TestContextCanceledWhileTorIsRunning(t *testing.T) {
 		t.Fatal("expected to see true here")
 	}
 
-	process, err := creator.New(ctx)
+	process, err := creator.New(ctx, "SocksPort", "auto")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -212,7 +212,7 @@ func TestControlConnectionExplicitlyClosed(t *testing.T) {
 		t.Fatal("expected to see true here")
 	}
 
-	process, err := creator.New(ctx)
+	process, err := creator.New(ctx, "SocksPort", "auto")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -253,3 +253,99 @@ func TestControlConnectionExplicitlyClosed(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+// This test ensures that we cannot make concurrent calls to the library.
+func TestConcurrentCalls(t *testing.T) {
+	// we need to simulate non zero exit code here such that we're not
+	// actually hitting into the real tor library; by doing this we
+	// make the test faster and reduce the risk of triggering the
+	// https://github.com/ooni/probe/issues/2406 bug caused by the
+	// fact we're invoking tor multiple times.
+
+	// runResult is what each background run reports to the test goroutine.
+	//
+	// Background goroutines must not call t.Fatal, because FailNow is only
+	// valid from the goroutine running the test function. So they report
+	// what they observed and the test goroutine performs all assertions.
+	type runResult struct {
+		setupErr error // failure that occurred before calling process.Start
+		startErr error // value returned by process.Start
+		waitErr  error // value returned by process.Wait (only set when Start succeeded)
+	}
+
+	run := func(resch chan<- runResult) {
+		ctx := context.Background()
+
+		creator, good := MaybeCreator()
+		if !good {
+			resch <- runResult{setupErr: errors.New("expected to see true here")}
+			return
+		}
+
+		process, err := creator.New(ctx)
+		if err != nil {
+			resch <- runResult{setupErr: err}
+			return
+		}
+		process.(*torProcess).simulateNonzeroExitCode = true // don't actually run tor
+
+		cconn, err := process.EmbeddedControlConn()
+		if err != nil {
+			resch <- runResult{setupErr: err}
+			return
+		}
+		defer cconn.Close()
+
+		// we expect a process to either start successfully or fail because
+		// there are concurrent calls ongoing
+		if err := process.Start(); err != nil {
+			resch <- runResult{startErr: err}
+			return
+		}
+
+		// the process that starts should complain about a nonzero
+		// exit code because it's configured in this way
+		resch <- runResult{waitErr: process.Wait()}
+	}
+
+	// attempt to create N=5 parallel instances
+	//
+	// what we would expect to see is that just one instance
+	// is able to start while the other four instances fail instead
+	// during their startup phase because of concurrency
+	const concurrentRuns = 5
+	results := make(chan runResult, concurrentRuns) // buffered so that senders never block
+	for idx := 0; idx < concurrentRuns; idx++ {
+		go run(results)
+	}
+	var (
+		countGood          int
+		countConcurrentErr int
+	)
+	// each goroutine sends exactly one result, so receiving concurrentRuns
+	// results also means every run has reported before the test returns
+	for idx := 0; idx < concurrentRuns; idx++ {
+		res := <-results
+		if res.setupErr != nil {
+			t.Fatal(res.setupErr)
+		}
+		t.Log("seen this error coming from process.Start", res.startErr)
+		if errors.Is(res.startErr, ErrConcurrentCalls) {
+			countConcurrentErr++
+			continue
+		}
+		if res.startErr != nil {
+			t.Fatal("unexpected error", res.startErr)
+		}
+		countGood++
+		if !errors.Is(res.waitErr, ErrNonzeroExitCode) {
+			t.Fatal("unexpected err", res.waitErr)
+		}
+	}
+	if countGood != 1 {
+		t.Fatal("expected countGood == 1, got", countGood)
+	}
+	if countConcurrentErr != 4 {
+		t.Fatal("expected countConcurrentErr == 4, got", countConcurrentErr)
+	}
+}