Skip to content

Commit

Permalink
Tracing support added for MongoDB database
Browse files Browse the repository at this point in the history
  • Loading branch information
Umang01-hash authored Oct 4, 2024
1 parent 6b307b5 commit 47069d0
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 40 deletions.
117 changes: 90 additions & 27 deletions pkg/gofr/datasource/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"time"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"go.mongodb.org/mongo-driver/bson"
Expand Down Expand Up @@ -49,7 +50,7 @@ i.e. by default observability features gets initialised when used with GoFr.
// client := New(config)
// client.UseLogger(loggerInstance)
// client.UseMetrics(metricsInstance)
// client.Connect()
// client.Connect().
func New(c Config) *Client {
return &Client{config: c}
}
Expand Down Expand Up @@ -101,28 +102,36 @@ func (c *Client) Connect() {

// InsertOne inserts a single document into the specified collection.
func (c *Client) InsertOne(ctx context.Context, collection string, document interface{}) (interface{}, error) {
defer c.sendOperationStats(&QueryLog{Query: "insertOne", Collection: collection, Filter: document}, time.Now())
tracerCtx, span := c.addTrace(ctx, "insertOne", collection)

return c.Database.Collection(collection).InsertOne(ctx, document)
result, err := c.Database.Collection(collection).InsertOne(tracerCtx, document)

defer c.sendOperationStats(&QueryLog{Query: "insertOne", Collection: collection, Filter: document}, time.Now(),
"insert", span)

return result, err
}

// InsertMany inserts multiple documents into the specified collection.
func (c *Client) InsertMany(ctx context.Context, collection string, documents []interface{}) ([]interface{}, error) {
defer c.sendOperationStats(&QueryLog{Query: "insertMany", Collection: collection, Filter: documents}, time.Now())
tracerCtx, span := c.addTrace(ctx, "insertMany", collection)

res, err := c.Database.Collection(collection).InsertMany(ctx, documents)
res, err := c.Database.Collection(collection).InsertMany(tracerCtx, documents)
if err != nil {
return nil, err
}

defer c.sendOperationStats(&QueryLog{Query: "insertMany", Collection: collection, Filter: documents}, time.Now(),
"insertMany", span)

return res.InsertedIDs, nil
}

// Find retrieves documents from the specified collection based on the provided filter and binds response to result.
func (c *Client) Find(ctx context.Context, collection string, filter, results interface{}) error {
defer c.sendOperationStats(&QueryLog{Query: "find", Collection: collection, Filter: filter}, time.Now())
tracerCtx, span := c.addTrace(ctx, "find", collection)

cur, err := c.Database.Collection(collection).Find(ctx, filter)
cur, err := c.Database.Collection(collection).Find(tracerCtx, filter)
if err != nil {
return err
}
Expand All @@ -133,94 +142,129 @@ func (c *Client) Find(ctx context.Context, collection string, filter, results in
return err
}

defer c.sendOperationStats(&QueryLog{Query: "find", Collection: collection, Filter: filter}, time.Now(), "find",
span)

return nil
}

// FindOne retrieves a single document from the specified collection based on the provided filter and binds response to result.
func (c *Client) FindOne(ctx context.Context, collection string, filter, result interface{}) error {
defer c.sendOperationStats(&QueryLog{Query: "findOne", Collection: collection, Filter: filter}, time.Now())
tracerCtx, span := c.addTrace(ctx, "findOne", collection)

b, err := c.Database.Collection(collection).FindOne(ctx, filter).Raw()
b, err := c.Database.Collection(collection).FindOne(tracerCtx, filter).Raw()
if err != nil {
return err
}

defer c.sendOperationStats(&QueryLog{Query: "findOne", Collection: collection, Filter: filter}, time.Now(),
"findOne", span)

return bson.Unmarshal(b, result)
}

// UpdateByID updates a document in the specified collection by its ID.
func (c *Client) UpdateByID(ctx context.Context, collection string, id, update interface{}) (int64, error) {
defer c.sendOperationStats(&QueryLog{Query: "updateByID", Collection: collection, ID: id, Update: update}, time.Now())
tracerCtx, span := c.addTrace(ctx, "updateByID", collection)

res, err := c.Database.Collection(collection).UpdateByID(tracerCtx, id, update)

res, err := c.Database.Collection(collection).UpdateByID(ctx, id, update)
defer c.sendOperationStats(&QueryLog{Query: "updateByID", Collection: collection, ID: id, Update: update}, time.Now(),
"updateByID", span)

return res.ModifiedCount, err
}

// UpdateOne updates a single document in the specified collection based on the provided filter.
func (c *Client) UpdateOne(ctx context.Context, collection string, filter, update interface{}) error {
defer c.sendOperationStats(&QueryLog{Query: "updateOne", Collection: collection, Filter: filter, Update: update}, time.Now())
tracerCtx, span := c.addTrace(ctx, "updateOne", collection)

_, err := c.Database.Collection(collection).UpdateOne(tracerCtx, filter, update)

_, err := c.Database.Collection(collection).UpdateOne(ctx, filter, update)
defer c.sendOperationStats(&QueryLog{Query: "updateOne", Collection: collection, Filter: filter, Update: update},
time.Now(), "updateOne", span)

return err
}

// UpdateMany updates multiple documents in the specified collection based on the provided filter.
func (c *Client) UpdateMany(ctx context.Context, collection string, filter, update interface{}) (int64, error) {
defer c.sendOperationStats(&QueryLog{Query: "updateMany", Collection: collection, Filter: filter, Update: update}, time.Now())
tracerCtx, span := c.addTrace(ctx, "updateMany", collection)

res, err := c.Database.Collection(collection).UpdateMany(ctx, filter, update)
res, err := c.Database.Collection(collection).UpdateMany(tracerCtx, filter, update)

defer c.sendOperationStats(&QueryLog{Query: "updateMany", Collection: collection, Filter: filter, Update: update}, time.Now(),
"updateMany", span)

return res.ModifiedCount, err
}

// CountDocuments counts the number of documents in the specified collection based on the provided filter.
func (c *Client) CountDocuments(ctx context.Context, collection string, filter interface{}) (int64, error) {
defer c.sendOperationStats(&QueryLog{Query: "countDocuments", Collection: collection, Filter: filter}, time.Now())
tracerCtx, span := c.addTrace(ctx, "countDocuments", collection)

result, err := c.Database.Collection(collection).CountDocuments(tracerCtx, filter)

defer c.sendOperationStats(&QueryLog{Query: "countDocuments", Collection: collection, Filter: filter}, time.Now(),
"countDocuments", span)

return c.Database.Collection(collection).CountDocuments(ctx, filter)
return result, err
}

// DeleteOne deletes a single document from the specified collection based on the provided filter.
func (c *Client) DeleteOne(ctx context.Context, collection string, filter interface{}) (int64, error) {
defer c.sendOperationStats(&QueryLog{Query: "deleteOne", Collection: collection, Filter: filter}, time.Now())
tracerCtx, span := c.addTrace(ctx, "deleteOne", collection)

res, err := c.Database.Collection(collection).DeleteOne(ctx, filter)
res, err := c.Database.Collection(collection).DeleteOne(tracerCtx, filter)
if err != nil {
return 0, err
}

defer c.sendOperationStats(&QueryLog{Query: "deleteOne", Collection: collection, Filter: filter}, time.Now(),
"deleteOne", span)

return res.DeletedCount, nil
}

// DeleteMany deletes multiple documents from the specified collection based on the provided filter.
func (c *Client) DeleteMany(ctx context.Context, collection string, filter interface{}) (int64, error) {
defer c.sendOperationStats(&QueryLog{Query: "deleteMany", Collection: collection, Filter: filter}, time.Now())
tracerCtx, span := c.addTrace(ctx, "deleteMany", collection)

res, err := c.Database.Collection(collection).DeleteMany(ctx, filter)
res, err := c.Database.Collection(collection).DeleteMany(tracerCtx, filter)
if err != nil {
return 0, err
}

defer c.sendOperationStats(&QueryLog{Query: "deleteMany", Collection: collection, Filter: filter}, time.Now(),
"deleteMany", span)

return res.DeletedCount, nil
}

// Drop drops the specified collection from the database.
func (c *Client) Drop(ctx context.Context, collection string) error {
defer c.sendOperationStats(&QueryLog{Query: "drop", Collection: collection}, time.Now())
tracerCtx, span := c.addTrace(ctx, "drop", collection)

err := c.Database.Collection(collection).Drop(tracerCtx)

return c.Database.Collection(collection).Drop(ctx)
defer c.sendOperationStats(&QueryLog{Query: "drop", Collection: collection}, time.Now(), "drop", span)

return err
}

// CreateCollection creates the specified collection in the database.
func (c *Client) CreateCollection(ctx context.Context, name string) error {
defer c.sendOperationStats(&QueryLog{Query: "createCollection", Collection: name}, time.Now())
tracerCtx, span := c.addTrace(ctx, "createCollection", name)

err := c.Database.CreateCollection(tracerCtx, name)

defer c.sendOperationStats(&QueryLog{Query: "createCollection", Collection: name}, time.Now(), "createCollection",
span)

return c.Database.CreateCollection(ctx, name)
return err
}

func (c *Client) sendOperationStats(ql *QueryLog, startTime time.Time) {
func (c *Client) sendOperationStats(ql *QueryLog, startTime time.Time, method string, span trace.Span) {
duration := time.Since(startTime).Milliseconds()

ql.Duration = duration
Expand All @@ -229,6 +273,11 @@ func (c *Client) sendOperationStats(ql *QueryLog, startTime time.Time) {

c.metrics.RecordHistogram(context.Background(), "app_mongo_stats", float64(duration), "hostname", c.uri,
"database", c.database, "type", ql.Query)

if span != nil {
defer span.End()
span.SetAttributes(attribute.Int64(fmt.Sprintf("mongo.%v.duration", method), duration))
}
}

type Health struct {
Expand Down Expand Up @@ -258,7 +307,7 @@ func (c *Client) HealthCheck(ctx context.Context) (any, error) {
}

func (c *Client) StartSession() (interface{}, error) {
defer c.sendOperationStats(&QueryLog{Query: "startSession"}, time.Now())
defer c.sendOperationStats(&QueryLog{Query: "startSession"}, time.Now(), "", nil)

s, err := c.Client().StartSession()
ses := &session{s}
Expand All @@ -280,3 +329,17 @@ type Transaction interface {
CommitTransaction(context.Context) error
EndSession(context.Context)
}

func (c *Client) addTrace(ctx context.Context, method, collection string) (context.Context, trace.Span) {
if c.tracer != nil {
contextWithTrace, span := c.tracer.Start(ctx, fmt.Sprintf("mongodb-%v", method))

span.SetAttributes(
attribute.String("mongo.collection", collection),
)

return contextWithTrace, span
}

return ctx, nil
}
29 changes: 16 additions & 13 deletions pkg/gofr/datasource/mongo/mongo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.opentelemetry.io/otel"
"go.uber.org/mock/gomock"
)

Expand Down Expand Up @@ -61,7 +62,7 @@ func Test_InsertCommands(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(4)
Expand All @@ -79,7 +80,7 @@ func Test_InsertCommands(t *testing.T) {
resp, err := cl.InsertOne(context.Background(), mt.Coll.Name(), doc)

assert.NotNil(t, resp)
assert.Nil(t, err)
assert.NoError(t, err)
})

mt.Run("insertOneError", func(mt *mtest.T) {
Expand Down Expand Up @@ -137,7 +138,7 @@ func Test_CreateCollection(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any())
Expand Down Expand Up @@ -166,7 +167,7 @@ func Test_FindMultipleCommands(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(3)
Expand Down Expand Up @@ -195,7 +196,7 @@ func Test_FindMultipleCommands(t *testing.T) {

err := cl.Find(context.Background(), mt.Coll.Name(), bson.D{{}}, &foundDocuments)

assert.Nil(t, err, "Unexpected error during Find operation")
assert.NoError(t, err, "Unexpected error during Find operation")
})

mt.Run("FindCursorError", func(mt *mtest.T) {
Expand Down Expand Up @@ -240,7 +241,7 @@ func Test_FindOneCommands(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(2)
Expand Down Expand Up @@ -275,7 +276,7 @@ func Test_FindOneCommands(t *testing.T) {
err := cl.FindOne(context.Background(), mt.Coll.Name(), bson.D{{}}, &foundDocuments)

assert.Equal(t, expectedUser.Name, foundDocuments.Name)
assert.Nil(t, err)
assert.NoError(t, err)
})

mt.Run("FindOneError", func(mt *mtest.T) {
Expand Down Expand Up @@ -307,7 +308,7 @@ func Test_UpdateCommands(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(3)
Expand Down Expand Up @@ -358,7 +359,7 @@ func Test_CountDocuments(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any())
Expand All @@ -379,7 +380,8 @@ func Test_CountDocuments(t *testing.T) {
_, err := indexView.CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{{Key: "x", Value: 1}},
})
require.Nil(mt, err, "CreateOne error for index: %v", err)

assert.NoError(mt, err, "CreateOne error for index: %v", err)

resp, err := cl.CountDocuments(context.Background(), mt.Coll.Name(), bson.D{{Key: "name", Value: "test"}})

Expand All @@ -398,7 +400,7 @@ func Test_DeleteCommands(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(4)
Expand Down Expand Up @@ -466,7 +468,7 @@ func Test_Drop(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname",
gomock.Any(), "database", gomock.Any(), "type", gomock.Any())
Expand Down Expand Up @@ -495,7 +497,7 @@ func TestClient_StartSession(t *testing.T) {
metrics := NewMockMetrics(ctrl)
logger := NewMockLogger(ctrl)

cl := Client{metrics: metrics}
cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")}

// Set up the mock expectation for the metrics recording
metrics.EXPECT().RecordHistogram(gomock.Any(), "app_mongo_stats", gomock.Any(), "hostname",
Expand All @@ -513,6 +515,7 @@ func TestClient_StartSession(t *testing.T) {

// Call the StartSession method
sess, err := cl.StartSession()

ses, ok := sess.(Transaction)
if ok {
err = ses.StartTransaction()
Expand Down
Loading

0 comments on commit 47069d0

Please sign in to comment.