Skip to content

Commit

Permalink
Some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
olbrichattila committed Jul 14, 2024
1 parent ba34976 commit c6cd4a4
Showing 1 changed file with 83 additions and 68 deletions.
151 changes: 83 additions & 68 deletions godi.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,119 +9,134 @@ import (
)

var (
// ErrCannotBeResolved returned when the container not able to resolve the dependency, not mapped with container.Set()
// ErrCannotBeResolved is returned when the container is not able to resolve the dependency.
ErrCannotBeResolved = errors.New("the DI parameter cannot be resolved")
// ErrCannotBeResolvedPossibleNeedExport returned when the container not able to resolve the dependency, autowire possiblie received a a non exported field
// ErrCannotBeResolvedPossibleNeedExport is returned when the container is not able to resolve the dependency, possibly due to an unexported field.
ErrCannotBeResolvedPossibleNeedExport = errors.New("the DI parameter cannot be resolved, possible unexported field for autowire notation")
// ErrCircularReference returned when the dependencies would end up in a forever loop. instead golang blowing up, it returns an error.
// ErrCircularReference is returned when there is a circular dependency reference.
ErrCircularReference = errors.New("circular reference")
)

// dependencyMap stores the dependency.
type dependencyMap struct {
dependency interface{}
}

// New creates a new dependency injector container
// New creates a new dependency injector container.
func New() *Cont {
c := &Cont{}
c.dependencies = make(map[string]*dependencyMap)
return c
return &Cont{
dependencies: make(map[string]*dependencyMap),
}
}

// Cont is the container returned by New
// Cont is the container returned by New.
type Cont struct {
callStack map[string]bool
dependencies map[string]*dependencyMap
}

// Set new dependency, provide a "packagePath.InterfaceName" as a string, and your dependency, which should always be an interface or struct
// Set registers a new dependency. Provide a "packagePath.InterfaceName" as a string and your dependency, which should be an interface or struct.
func (t *Cont) Set(paramName string, dependency interface{}) {
t.dependencies[paramName] = &dependencyMap{dependency: dependency}
}

// Get resolves dependencies. Use a Construct func with your dependency interface type hints. They will be resolved recursively
// Get resolves dependencies. Use a construct function with your dependency interface type hints. They will be resolved recursively.
func (t *Cont) Get(obj interface{}) (interface{}, error) {
callStack := make(map[string]bool)
return t.getRecursive(obj, callStack)
t.callStack = make(map[string]bool)
return t.getRecursive(obj)
}

// getRecursive resolves dependencies recursively, tracking call stack to detect circular references
func (t *Cont) getRecursive(obj interface{}, callStack map[string]bool) (interface{}, error) {
// getRecursive resolves dependencies recursively, tracking call stack to detect circular references.
func (t *Cont) getRecursive(obj interface{}) (interface{}, error) {
v := reflect.ValueOf(obj)

if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct {
rt := reflect.TypeOf(obj)

// Resolve constructor
method, found := rt.MethodByName("Construct")
if found {
passParams := []reflect.Value{v}
methodType := method.Type
numParams := methodType.NumIn()

for i := 1; i < numParams; i++ {
paramType := methodType.In(i)
param, fullTypeName, err := t.resolve(paramType)
if err != nil {
return nil, err
}
if callStack[fullTypeName] {
return nil, errors.Join(ErrCircularReference, fmt.Errorf("circular call: %s", fullTypeName))
}
callStack[fullTypeName] = true

_, err = t.getRecursive(param, callStack)
delete(callStack, fullTypeName)
if err != nil {
return nil, err
}
passParams = append(passParams, reflect.ValueOf(param))
}
rt := v.Type()

method.Func.Call(passParams)
if err := t.resolveConstructor(v, rt); err != nil {
return nil, err
}

// Resole autowire
vTyp := v.Elem().Type()
for i := 0; i < v.Elem().NumField(); i++ {
if err := t.resolveAutoWire(v); err != nil {
return nil, err
}
}

field := vTyp.Field(i)
tag := field.Tag.Get("di")
if tag == "autowire" {
return obj, nil
}

resolvedField := v.Elem().FieldByName(field.Name)
// resolveConstructor resolves dependencies for the constructor method.
func (t *Cont) resolveConstructor(v reflect.Value, rt reflect.Type) error {
if method, found := rt.MethodByName("Construct"); found {
passParams := []reflect.Value{v}
methodType := method.Type
numParams := methodType.NumIn()

for i := 1; i < numParams; i++ {
paramType := methodType.In(i)
param, fullTypeName, err := t.resolveConstructorParam(paramType)
if err != nil {
return err
}
if t.callStack[fullTypeName] {
return fmt.Errorf("%w: circular call: %s", ErrCircularReference, fullTypeName)
}
t.callStack[fullTypeName] = true

if !resolvedField.CanSet() {
return nil, errors.Join(ErrCannotBeResolvedPossibleNeedExport, fmt.Errorf("the field name: %s", field.Name))
}
if _, err := t.getRecursive(param); err != nil {
delete(t.callStack, fullTypeName)
return err
}
delete(t.callStack, fullTypeName)
passParams = append(passParams, reflect.ValueOf(param))
}

value, _, err := t.resolve(field.Type)
if err != nil {
return nil, err
}
method.Func.Call(passParams)
}

_, err = t.getRecursive(value, callStack)
if err != nil {
return nil, err
}
return nil
}

fieldValue := reflect.ValueOf(value)
// resolveAutoWire resolves dependencies for struct fields with the "di" tag.
func (t *Cont) resolveAutoWire(v reflect.Value) error {
vTyp := v.Elem().Type()
for i := 0; i < v.Elem().NumField(); i++ {
field := vTyp.Field(i)
tag := field.Tag.Get("di")
if tag == "autowire" {
resolvedField := v.Elem().FieldByName(field.Name)

if !resolvedField.CanSet() {
return fmt.Errorf("%w: the field name: %s", ErrCannotBeResolvedPossibleNeedExport, field.Name)
}

resolvedField.Set(fieldValue)
value, _, err := t.resolveConstructorParam(field.Type)
if err != nil {
return err
}
}

}
if _, err := t.getRecursive(value); err != nil {
return err
}

return obj, nil
fieldValue := reflect.ValueOf(value)
if field.Type.Kind() == reflect.Interface && !fieldValue.Type().Implements(field.Type) {
return fmt.Errorf("provided value does not implement the field's interface: %s", field.Name)
}

resolvedField.Set(fieldValue)
}
}
return nil
}

// resolve finds and returns the dependency, checking for circular references
func (t *Cont) resolve(paramType reflect.Type) (interface{}, string, error) {
// resolveConstructorParam resolves a constructor parameter by its type.
func (t *Cont) resolveConstructorParam(paramType reflect.Type) (interface{}, string, error) {
pkgPath := paramType.PkgPath() + "/" + paramType.Name()
fullTypeName := strings.Join(strings.Split(pkgPath, "/")[1:], ".")
param, ok := t.dependencies[fullTypeName]
if !ok {
return nil, fullTypeName, errors.Join(ErrCannotBeResolved, fmt.Errorf("dependency name: %s", fullTypeName))
return nil, fullTypeName, fmt.Errorf("%w: dependency name: %s", ErrCannotBeResolved, fullTypeName)
}

return param.dependency, fullTypeName, nil
Expand Down

0 comments on commit c6cd4a4

Please sign in to comment.