diff --git a/cmd/api/src/api/tools/PG_MIGRATE.md b/cmd/api/src/api/tools/PG_MIGRATE.md new file mode 100644 index 000000000..ff0d70e0f --- /dev/null +++ b/cmd/api/src/api/tools/PG_MIGRATE.md @@ -0,0 +1,20 @@ +## Migrating Graph Data from Neo4j to Postgres + +### Endpoints +| Endpoint | HTTP Request | Usage | Expected Response | +| --- | --- | --- | --- | +| `/pg-migration/status/` | `GET` | Returns a status indicating whether the migrator is currently running. | **Status:** `200 OK`

{
  "state": "idle" \| "migrating" \| "canceling"
}
| +| `/pg-migration/start/` | `PUT` | Kicks off the migration process from neo4j to postgres. | **Status:** `202 Accepted` | +| `/pg-migration/cancel/` | `PUT` | Cancels the currently running migration. | **Status:** `202 Accepted` | +| `/graph-db/switch/pg/` | `PUT` | Switches the current graph database driver to postgres. | **Status:** `200 OK` | +| `/graph-db/switch/ne04j/` | `PUT` | Switches the current graph database driver to ne04j. | **Status:** `200 OK` | + +### Running a Migration +1. Confirm the migration status is currently "idle" before running a migration with the `/pg-migration/status/` endpoint. The migration will run in the same direction regardless of the currently selected graph driver. +2. Start the migration process using the `/pg-migration/start/` endpoint. Since the migration occurs asynchronously, you will want to monitor the API logs to see information regarding the currently running migration. + - When the migration starts, there should be a log with the message `"Dispatching live migration from Neo4j to PostgreSQL"` + - Upon completion, you should see the message `"Migration to PostgreSQL completed successfully"` + - Any errors that occur during the migration process will also surface here + - You can also poll the `/pg-migration/status/` endpoint and wait for an `"idle"` status to indicate the migration has completed + - An in-progess migration can be cancelled with the `pg-migration/cancel/` endpoint and run again at any time +3. Once you are ready to switch over to the postgres graph driver, you can use the `/graph-db/switch/pg/` endpoint. \ No newline at end of file diff --git a/cmd/api/src/api/tools/pg.go b/cmd/api/src/api/tools/pg.go index fbd597d9a..45dc020f0 100644 --- a/cmd/api/src/api/tools/pg.go +++ b/cmd/api/src/api/tools/pg.go @@ -36,9 +36,9 @@ import ( type MigratorState string const ( - stateIdle MigratorState = "idle" - stateMigrating MigratorState = "migrating" - stateCanceling MigratorState = "canceling" + StateIdle MigratorState = "idle" + StateMigrating MigratorState = "migrating" + StateCanceling MigratorState = "canceling" ) func migrateTypes(ctx context.Context, neoDB, pgDB graph.Database) error { @@ -187,21 +187,21 @@ func migrateEdges(ctx context.Context, neoDB, pgDB graph.Database, nodeIDMapping type PGMigrator struct { graphSchema graph.Schema graphDBSwitch *graph.DatabaseSwitch - serverCtx context.Context + ServerCtx context.Context migrationCancelFunc func() - state MigratorState + State MigratorState lock *sync.Mutex - cfg config.Configuration + Cfg config.Configuration } func NewPGMigrator(serverCtx context.Context, cfg config.Configuration, graphSchema graph.Schema, graphDBSwitch *graph.DatabaseSwitch) *PGMigrator { return &PGMigrator{ graphSchema: graphSchema, graphDBSwitch: graphDBSwitch, - serverCtx: serverCtx, - state: stateIdle, + ServerCtx: serverCtx, + State: StateIdle, lock: &sync.Mutex{}, - cfg: cfg, + Cfg: cfg, } } @@ -212,31 +212,28 @@ func (s *PGMigrator) advanceState(next MigratorState, validTransitions ...Migrat isValid := false for _, validTransition := range validTransitions { - if s.state == validTransition { + if s.State == validTransition { isValid = true break } } if !isValid { - return fmt.Errorf("migrator state is %s but expected one of: %v", s.state, validTransitions) + return fmt.Errorf("migrator state is %s but expected one of: %v", s.State, validTransitions) } - s.state = next + s.State = next return nil } func (s *PGMigrator) SwitchPostgreSQL(response http.ResponseWriter, request *http.Request) { - if pgDB, err := dawgs.Open(s.serverCtx, pg.DriverName, dawgs.Config{ - GraphQueryMemoryLimit: size.Gibibyte, - DriverCfg: s.cfg.Database.PostgreSQLConnectionString(), - }); err != nil { + if pgDB, err := s.OpenPostgresGraphConnection(); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed connecting to PostgreSQL: %w", err), }, http.StatusInternalServerError, response) } else if err := pgDB.AssertSchema(request.Context(), s.graphSchema); err != nil { log.Errorf("Unable to assert graph schema in PostgreSQL: %v", err) - } else if err := SetGraphDriver(request.Context(), s.cfg, pg.DriverName); err != nil { + } else if err := SetGraphDriver(request.Context(), s.Cfg, pg.DriverName); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed updating graph database driver preferences: %w", err), }, http.StatusInternalServerError, response) @@ -249,14 +246,11 @@ func (s *PGMigrator) SwitchPostgreSQL(response http.ResponseWriter, request *htt } func (s *PGMigrator) SwitchNeo4j(response http.ResponseWriter, request *http.Request) { - if neo4jDB, err := dawgs.Open(s.serverCtx, neo4j.DriverName, dawgs.Config{ - GraphQueryMemoryLimit: size.Gibibyte, - DriverCfg: s.cfg.Neo4J.Neo4jConnectionString(), - }); err != nil { + if neo4jDB, err := s.OpenNeo4jGraphConnection(); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed connecting to Neo4j: %w", err), }, http.StatusInternalServerError, response) - } else if err := SetGraphDriver(request.Context(), s.cfg, neo4j.DriverName); err != nil { + } else if err := SetGraphDriver(request.Context(), s.Cfg, neo4j.DriverName); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed updating graph database driver preferences: %w", err), }, http.StatusInternalServerError, response) @@ -268,23 +262,17 @@ func (s *PGMigrator) SwitchNeo4j(response http.ResponseWriter, request *http.Req } } -func (s *PGMigrator) startMigration() error { - if err := s.advanceState(stateMigrating, stateIdle); err != nil { +func (s *PGMigrator) StartMigration() error { + if err := s.advanceState(StateMigrating, StateIdle); err != nil { return fmt.Errorf("database migration state error: %w", err) - } else if neo4jDB, err := dawgs.Open(s.serverCtx, neo4j.DriverName, dawgs.Config{ - GraphQueryMemoryLimit: size.Gibibyte, - DriverCfg: s.cfg.Neo4J.Neo4jConnectionString(), - }); err != nil { + } else if neo4jDB, err := s.OpenNeo4jGraphConnection(); err != nil { return fmt.Errorf("failed connecting to Neo4j: %w", err) - } else if pgDB, err := dawgs.Open(s.serverCtx, pg.DriverName, dawgs.Config{ - GraphQueryMemoryLimit: size.Gibibyte, - DriverCfg: s.cfg.Database.PostgreSQLConnectionString(), - }); err != nil { + } else if pgDB, err := s.OpenPostgresGraphConnection(); err != nil { return fmt.Errorf("failed connecting to PostgreSQL: %w", err) } else { log.Infof("Dispatching live migration from Neo4j to PostgreSQL") - migrationCtx, migrationCancelFunc := context.WithCancel(s.serverCtx) + migrationCtx, migrationCancelFunc := context.WithCancel(s.ServerCtx) s.migrationCancelFunc = migrationCancelFunc go func(ctx context.Context) { @@ -304,7 +292,7 @@ func (s *PGMigrator) startMigration() error { log.Infof("Migration to PostgreSQL completed successfully") } - if err := s.advanceState(stateIdle, stateMigrating, stateCanceling); err != nil { + if err := s.advanceState(StateIdle, StateMigrating, StateCanceling); err != nil { log.Errorf("Database migration state management error: %v", err) } }(migrationCtx) @@ -314,7 +302,7 @@ func (s *PGMigrator) startMigration() error { } func (s *PGMigrator) MigrationStart(response http.ResponseWriter, request *http.Request) { - if err := s.startMigration(); err != nil { + if err := s.StartMigration(); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": err.Error(), }, http.StatusInternalServerError, response) @@ -323,8 +311,8 @@ func (s *PGMigrator) MigrationStart(response http.ResponseWriter, request *http. } } -func (s *PGMigrator) cancelMigration() error { - if err := s.advanceState(stateCanceling, stateMigrating); err != nil { +func (s *PGMigrator) CancelMigration() error { + if err := s.advanceState(StateCanceling, StateMigrating); err != nil { return err } @@ -334,7 +322,7 @@ func (s *PGMigrator) cancelMigration() error { } func (s *PGMigrator) MigrationCancel(response http.ResponseWriter, request *http.Request) { - if err := s.cancelMigration(); err != nil { + if err := s.CancelMigration(); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": err.Error(), }, http.StatusInternalServerError, response) @@ -345,6 +333,20 @@ func (s *PGMigrator) MigrationCancel(response http.ResponseWriter, request *http func (s *PGMigrator) MigrationStatus(response http.ResponseWriter, request *http.Request) { api.WriteJSONResponse(request.Context(), map[string]any{ - "state": s.state, + "state": s.State, }, http.StatusOK, response) } + +func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) { + return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{ + GraphQueryMemoryLimit: size.Gibibyte, + DriverCfg: s.Cfg.Database.PostgreSQLConnectionString(), + }) +} + +func (s *PGMigrator) OpenNeo4jGraphConnection() (graph.Database, error) { + return dawgs.Open(s.ServerCtx, neo4j.DriverName, dawgs.Config{ + GraphQueryMemoryLimit: size.Gibibyte, + DriverCfg: s.Cfg.Neo4J.Neo4jConnectionString(), + }) +} diff --git a/cmd/api/src/api/tools/pg_test.go b/cmd/api/src/api/tools/pg_test.go new file mode 100644 index 000000000..ea178fb11 --- /dev/null +++ b/cmd/api/src/api/tools/pg_test.go @@ -0,0 +1,229 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration +// +build integration + +package tools_test + +import ( + "context" + "net/http" + "net/http/httptest" + "slices" + "testing" + "time" + + "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/specterops/bloodhound/dawgs/drivers/pg" + pg_query "github.com/specterops/bloodhound/dawgs/drivers/pg/query" + "github.com/specterops/bloodhound/dawgs/graph" + graph_mocks "github.com/specterops/bloodhound/dawgs/graph/mocks" + "github.com/specterops/bloodhound/dawgs/ops" + "github.com/specterops/bloodhound/dawgs/query" + "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/graphschema/common" + "github.com/specterops/bloodhound/src/api/tools" + "github.com/specterops/bloodhound/src/test/integration" + "github.com/specterops/bloodhound/src/test/integration/utils" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestSwitchPostgreSQL(t *testing.T) { + var ( + mockCtrl = gomock.NewController(t) + graphDB = graph_mocks.NewMockDatabase(mockCtrl) + request = httptest.NewRequest(http.MethodPut, "/graph-db/switch/pg", nil) + recorder = httptest.NewRecorder() + ctx = request.Context() + migrator = setupTestMigrator(t, ctx, graphDB) + ) + + // lookup creates the database_switch table if needed + driver, err := tools.LookupGraphDriver(migrator.ServerCtx, migrator.Cfg) + require.Nil(t, err) + + if driver != neo4j.DriverName { + err = tools.SetGraphDriver(migrator.ServerCtx, migrator.Cfg, neo4j.DriverName) + require.Nil(t, err) + } + + migrator.SwitchPostgreSQL(recorder, request) + + response := recorder.Result() + defer response.Body.Close() + + require.Equal(t, http.StatusOK, response.StatusCode) + + driver, err = tools.LookupGraphDriver(migrator.ServerCtx, migrator.Cfg) + require.Nil(t, err) + require.Equal(t, pg.DriverName, driver) +} + +func TestSwitchNeo4j(t *testing.T) { + var ( + mockCtrl = gomock.NewController(t) + graphDB = graph_mocks.NewMockDatabase(mockCtrl) + request = httptest.NewRequest(http.MethodPut, "/graph-db/switch/neo4j", nil) + recorder = httptest.NewRecorder() + ctx = request.Context() + migrator = setupTestMigrator(t, ctx, graphDB) + ) + + driver, err := tools.LookupGraphDriver(migrator.ServerCtx, migrator.Cfg) + require.Nil(t, err) + + if driver != pg.DriverName { + err = tools.SetGraphDriver(migrator.ServerCtx, migrator.Cfg, pg.DriverName) + require.Nil(t, err) + } + + migrator.SwitchNeo4j(recorder, request) + + response := recorder.Result() + defer response.Body.Close() + + require.Equal(t, http.StatusOK, response.StatusCode) + + driver, err = tools.LookupGraphDriver(migrator.ServerCtx, migrator.Cfg) + require.Nil(t, err) + require.Equal(t, neo4j.DriverName, driver) +} + +func TestPGMigrator(t *testing.T) { + var ( + schema = graphschema.DefaultGraphSchema() + testContext = integration.NewGraphTestContext(t, schema) + ) + + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { + harness.DBMigrateHarness.Setup(testContext) + return nil + }, func(harness integration.HarnessDetails, neo4jDB graph.Database) { + var ( + migrator = setupTestMigrator(t, testContext.Context(), neo4jDB) + testID = harness.DBMigrateHarness.TestID.String() + sourceNodeKinds graph.Kinds + sourceEdgeKinds graph.Kinds + sourceNodes []*graph.Node + sourceEdges []*graph.Relationship + ) + + pgDB, err := migrator.OpenPostgresGraphConnection() + require.Nil(t, err) + + // clear out nodes to avoid conflict when running the test multiple times + err = pgDB.WriteTransaction(testContext.Context(), func(tx graph.Transaction) error { + return tx.Nodes().Delete() + }) + require.Nil(t, err) + + err = migrator.StartMigration() + require.Nil(t, err) + + // wait until migration status returns to "idle" + for { + if migrator.State == tools.StateMigrating { + time.Sleep(time.Second / 10) + } else if migrator.State == tools.StateIdle { + break + } else { + t.Fatalf("Encountered invalid migration status: %s", migrator.State) + } + } + + // query nodes/relationships in neo4j + err = neo4jDB.ReadTransaction(testContext.Context(), func(tx graph.Transaction) error { + sourceNodes, err = ops.FetchNodes(tx.Nodes()) + require.Nil(t, err) + + sourceEdges, err = ops.FetchRelationships(tx.Relationships()) + require.Nil(t, err) + + return nil + }) + require.Nil(t, err) + + // grab source kinds + // NOTE: the call to db.labels() in our migrator returns all possible node kinds in neo4j, while db.relationshipTypes() + // returns just those edge kinds that have an associated edge in the db, so that is the behavior we are testing here + sourceNodeKinds = schema.DefaultGraph.Nodes + + for _, edge := range sourceEdges { + if !slices.Contains(sourceEdgeKinds, edge.Kind) { + sourceEdgeKinds = append(sourceEdgeKinds, edge.Kind) + } + } + + // confirm that all the data from neo4j made it to pg + err = pgDB.ReadTransaction(testContext.Context(), func(tx graph.Transaction) error { + + // check nodes + for _, sourceNode := range sourceNodes { + id, err := sourceNode.Properties.Get(testID).String() + require.Nil(t, err) + + if targetNode, err := tx.Nodes().Filterf(func() graph.Criteria { + return query.Equals(query.NodeProperty(testID), id) + }).First(); err != nil { + t.Fatalf("Could not find migrated node with '%s' == %s", testID, id) + } else { + require.Equal(t, sourceNode.Kinds, targetNode.Kinds) + require.Equal(t, sourceNode.Properties.Get(common.Name.String()), targetNode.Properties.Get(common.Name.String())) + require.Equal(t, sourceNode.Properties.Get(common.ObjectID.String()), targetNode.Properties.Get(common.ObjectID.String())) + } + } + + // check edges + for _, sourceEdge := range sourceEdges { + id, err := sourceEdge.Properties.Get(testID).String() + require.Nil(t, err) + + if targetRel, err := tx.Relationships().Filterf(func() graph.Criteria { + return query.Equals(query.RelationshipProperty(testID), id) + }).First(); err != nil { + t.Fatalf("Could not find migrated relationship with '%s' == %s", testID, id) + } else { + require.Equal(t, sourceEdge.Kind, targetRel.Kind) + } + } + + // check kinds + targetKinds, err := pg_query.On(tx).SelectKinds() + require.Nil(t, err) + + for _, kind := range append(sourceNodeKinds, sourceEdgeKinds...) { + require.NotNil(t, targetKinds[kind]) + } + + return nil + }) + require.Nil(t, err) + }) +} + +func setupTestMigrator(t *testing.T, ctx context.Context, graphDB graph.Database) *tools.PGMigrator { + var ( + schema = graphschema.DefaultGraphSchema() + dbSwitch = graph.NewDatabaseSwitch(ctx, graphDB) + ) + + cfg, err := utils.LoadIntegrationTestConfig() + require.Nil(t, err) + + return tools.NewPGMigrator(ctx, cfg, schema, dbSwitch) +} diff --git a/cmd/api/src/test/integration/harnesses.go b/cmd/api/src/test/integration/harnesses.go index 932732655..106943bad 100644 --- a/cmd/api/src/test/integration/harnesses.go +++ b/cmd/api/src/test/integration/harnesses.go @@ -6425,6 +6425,48 @@ func (s *ESC4ECA) Setup(graphTestContext *GraphTestContext) { graphTestContext.NewRelationship(s.Computer7, s.CertTemplate7, ad.GenericAll) } +// Use this to set our custom test property in the migration harness +type Property string + +func (s Property) String() string { + return string(s) +} + +type DBMigrateHarness struct { + Group1 *graph.Node + Computer1 *graph.Node + User1 *graph.Node + GenericAll1 *graph.Relationship + HasSession1 *graph.Relationship + MemberOf1 *graph.Relationship + TestID Property +} + +func (s *DBMigrateHarness) Setup(graphTestContext *GraphTestContext) { + sid := RandomDomainSID() + s.TestID = "testing_id" + + s.Group1 = graphTestContext.NewActiveDirectoryGroup("Group1", sid) + s.Computer1 = graphTestContext.NewActiveDirectoryComputer("Computer1", sid) + s.User1 = graphTestContext.NewActiveDirectoryUser("User1", sid, false) + s.Group1.Properties.Set(s.TestID.String(), RandomObjectID(graphTestContext.testCtx)) + s.Computer1.Properties.Set(s.TestID.String(), RandomObjectID(graphTestContext.testCtx)) + s.User1.Properties.Set(s.TestID.String(), RandomObjectID(graphTestContext.testCtx)) + graphTestContext.UpdateNode(s.Group1) + graphTestContext.UpdateNode(s.Computer1) + graphTestContext.UpdateNode(s.User1) + + s.GenericAll1 = graphTestContext.NewRelationship(s.Group1, s.Computer1, ad.GenericAll, graph.AsProperties(graph.PropertyMap{ + s.TestID: RandomObjectID(graphTestContext.testCtx), + })) + s.HasSession1 = graphTestContext.NewRelationship(s.Computer1, s.User1, ad.HasSession, graph.AsProperties(graph.PropertyMap{ + s.TestID: RandomObjectID(graphTestContext.testCtx), + })) + s.MemberOf1 = graphTestContext.NewRelationship(s.User1, s.Group1, ad.MemberOf, graph.AsProperties(graph.PropertyMap{ + s.TestID: RandomObjectID(graphTestContext.testCtx), + })) +} + type ESC13Harness1 struct { CertTemplate1 *graph.Node CertTemplate2 *graph.Node @@ -8585,6 +8627,7 @@ type HarnessDetails struct { ESC4Template3 ESC4Template3 ESC4Template4 ESC4Template4 ESC4ECA ESC4ECA + DBMigrateHarness DBMigrateHarness ESC13Harness1 ESC13Harness1 ESC13Harness2 ESC13Harness2 ESC13HarnessECA ESC13HarnessECA