diff --git a/cmd/tracegen/model.go b/cmd/tracegen/model.go index 8cf4fae..85e02ac 100644 --- a/cmd/tracegen/model.go +++ b/cmd/tracegen/model.go @@ -166,22 +166,57 @@ func (b *tracingMethodBuilder) Build() ast.Decl { }, }) - // If the first parameter is context, add tracing call. + // If the parameters have a context.Context or http.Request + // add tracing, based on the first-found element. + // // ctx, span := trace.StartSpan(ctx, "github.com/pkg.Component.Method") // defer span.End() - if len(b.methodConfig.MethodParams) > 0 { - p1 := b.methodConfig.MethodParams[0] - if sel, ok := p1.Type.(*ast.SelectorExpr); ok { + // + // or + // + // // presuming that is of type *http.Request + // ctx := .Context() + // ctx, span := trace.StartSpan(ctx, "github.com/pkg.Component.Method") + // defer span.End() + ctxFound := false + for _, param := range b.methodConfig.MethodParams { + if ctxFound { + break + } + switch sel := param.Type.(type) { + // context.Context + case *ast.SelectorExpr: if sel.Sel.String() == "Context" { if id, ok := sel.X.(*ast.Ident); ok && id.String() == b.contextPackageAlias { b.method.AddStatement( newTraceMethodInvocation(b.tracePackageAlias, - p1.Names[0].Name, b.fullMethodName)) + param.Names[0].Name, b.fullMethodName)) b.method.AddStatement(newEndSpanStmt()) - + ctxFound = true + } + } + // *http.Request + case *ast.StarExpr: + if concreteType, ok := sel.X.(*ast.SelectorExpr); ok { + // TODO: Make sure that the request is from 'net/http'. + if concreteType.Sel.String() == "Request" { + ctxName, ctxStmt := contextFromRequest(param.Names[0].Name) + b.method.AddStatement(ctxStmt) + b.method.AddStatement( + newTraceMethodInvocation( + b.tracePackageAlias, + ctxName, + b.fullMethodName, + ), + ) + b.method.AddStatement(newEndSpanStmt()) + b.method.AddStatement(requestWithContext(param.Names[0].Name, ctxName)) + ctxFound = true } } + default: + continue } } @@ -197,6 +232,70 @@ func (b *tracingMethodBuilder) Build() ast.Decl { return b.method.Build() } +// contextFromRequest returns a new statement that defines +// a new "ctx" variable in the scope, with its value set to +// .Context(). +// +// The function also returns the name of the variable. +// +// i.e. +// +// // presuming that argName is of type *http.Request. +// ctx := .Context() +// +func contextFromRequest(argName string) (string, ast.Stmt) { + callExpr := &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(argName), + Sel: ast.NewIdent("Context"), + }, + } + + const varName = "ctx" + vars := []ast.Expr{ + ast.NewIdent(varName), + } + + return varName, &ast.AssignStmt{ + Lhs: vars, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + callExpr, + }, + } +} + +// requestWithContext creates a re-assignment statement for +// reqName with the context set to ctxName. +// +// i.e: +// +// = .WithContext() +func requestWithContext(reqName, ctxName string) ast.Stmt { + callExpr := &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(reqName), + Sel: ast.NewIdent("WithContext"), + }, + Args: []ast.Expr{ + ast.NewIdent(ctxName), + }, + } + + const varName = "ctx" + vars := []ast.Expr{ + ast.NewIdent(reqName), + } + + return &ast.AssignStmt{ + Lhs: vars, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ + callExpr, + }, + } +} + func newEndSpanStmt() ast.Stmt { callExpr := &ast.CallExpr{ Fun: &ast.SelectorExpr{