Skip to content

Commit

Permalink
fix encoder/decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
mkideal committed Aug 12, 2024
1 parent f5d7081 commit 3d29d73
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 73 deletions.
62 changes: 23 additions & 39 deletions component/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package component

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand All @@ -15,7 +14,9 @@ import (
"strings"
"sync"
"sync/atomic"
"unicode"

"github.com/gopherd/core/encoding"
"github.com/gopherd/core/lifecycle"
"github.com/gopherd/core/types"
)
Expand Down Expand Up @@ -66,7 +67,7 @@ type Container interface {
GetComponent(uuid string) Component

// Decoder returns the decoder for decoding component configurations.
Decoder() types.Decoder
Decoder() encoding.Decoder

// Logger returns the logger instance for the container.
Logger() *slog.Logger
Expand Down Expand Up @@ -166,50 +167,33 @@ func (r Reference[T]) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals the referenced component UUID from JSON.
func (r *Reference[T]) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &r.uuid)
if err := json.Unmarshal(data, &r.uuid); err != nil {
return err
}
return r.validate()
}

// MarshalText marshals the referenced component UUID to text.
func (r Reference[T]) MarshalText() ([]byte, error) {
// Use strconv.Quote to properly escape the string
return []byte(strconv.Quote(r.uuid)), nil
// Marshal encodes the referenced component UUID as a double-quoted string,
// escaped by strconv.AppendQuote. It never returns a non-nil error.
func (r Reference[T]) Marshal() ([]byte, error) {
	// Reserve room for the UUID plus the two surrounding quote characters;
	// AppendQuote grows the slice if escaping needs more.
	const quoteOverhead = len(`""`)
	quoted := make([]byte, 0, len(r.uuid)+quoteOverhead)
	quoted = strconv.AppendQuote(quoted, r.uuid)
	return quoted, nil
}

// UnmarshalText unmarshals the referenced component UUID from text.
func (r *Reference[T]) UnmarshalText(data []byte) error {
// Trim leading and trailing whitespace
data = bytes.TrimSpace(data)

if len(data) < 2 {
return errors.New("invalid string: too short")
}

switch data[0] {
case '"':
// Basic string (double-quoted)
if data[len(data)-1] != '"' {
return errors.New("invalid string: mismatched quotes")
}
s, err := strconv.Unquote(string(data))
if err != nil {
return err
}
r.uuid = s
case '\'':
// Literal string (single-quoted)
if data[len(data)-1] != '\'' {
return errors.New("invalid string: mismatched quotes")
}
uuid := string(data[1 : len(data)-1])
// Check for illegal newlines in literal string
if strings.Contains(uuid, "\n") {
return errors.New("invalid string: newlines not allowed in literal string")
}
r.uuid = uuid
// Unmarshal unmarshals the referenced component UUID from quoted bytes.
func (r *Reference[T]) Unmarshal(data any) error {
switch v := data.(type) {
case string:
r.uuid = v
return r.validate()
default:
return errors.New("invalid string: must start with ' or \"")
return fmt.Errorf("unexpected type %T for reference UUID", data)
}
}

// validate reports an error when the stored UUID contains any character that
// unicode.IsSpace classifies as whitespace; otherwise it returns nil.
func (r Reference[T]) validate() error {
	if !strings.ContainsFunc(r.uuid, unicode.IsSpace) {
		return nil
	}
	return fmt.Errorf("unexpected whitespace in reference UUID: %q", r.uuid)
}

Expand Down
3 changes: 2 additions & 1 deletion component/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"

"github.com/gopherd/core/component"
"github.com/gopherd/core/encoding"
"github.com/gopherd/core/op"
"github.com/gopherd/core/types"
)
Expand Down Expand Up @@ -92,7 +93,7 @@ func (c *mockContainer) GetComponent(uuid string) component.Component {
return c.components[uuid]
}

func (c *mockContainer) Decoder() types.Decoder {
func (c *mockContainer) Decoder() encoding.Decoder {
return json.Unmarshal
}

Expand Down
80 changes: 80 additions & 0 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Package encoding provides interfaces and utilities for encoding and decoding data.
package encoding

import (
"bytes"
"errors"
"strconv"
"strings"
"unicode/utf8"
)

// ErrInvalidString is the sentinel error for input that cannot be decoded as a
// string. UnmarshalString returns it when the trimmed input begins with
// neither a double quote nor an enabled literal delimiter.
var ErrInvalidString = errors.New("invalid string")

// Encoder is a function type that encodes a value into bytes.
// It takes the value to encode and returns the encoded bytes or an error.
type Encoder func(any) ([]byte, error)

// Decoder is a function type that decodes bytes into a provided value.
// It takes the encoded bytes and the destination value, and returns an error
// on failure. json.Unmarshal has this signature and can be used directly.
type Decoder func([]byte, any) error

// Marshaler is an interface for types that can marshal themselves into bytes.
type Marshaler interface {
// Marshal returns the encoded form of the receiver, or an error.
Marshal() ([]byte, error)
}

// Unmarshaler is an interface for types that can unmarshal bytes into themselves.
type Unmarshaler interface {
// Unmarshal populates the receiver from the given bytes, or returns an error.
Unmarshal([]byte) error
}

// UnmarshalString decodes a string from byte slice data.
// It supports both quoted strings and literal strings.
//
// Leading and trailing whitespace is trimmed first. Input that starts with a
// double quote is decoded with Go string-literal escape rules. Input that
// starts with literalChar is taken verbatim between its two delimiters.
//
// Parameters:
//   - data: The byte slice containing the string to unmarshal.
//   - literalChar: The character used for literal strings. Use 0 for no literal strings.
//     Passing '"' has no effect because double-quoted input is always treated as a quoted string.
//   - allowNewline: Whether newlines are allowed in literal strings.
//
// Returns:
//   - The unmarshaled string and nil error if successful.
//   - An empty string and an error if unmarshaling fails. Every such error
//     matches ErrInvalidString under errors.Is; the specific cause (for
//     example strconv.ErrSyntax) remains reachable through errors.Is/As.
func UnmarshalString(data []byte, literalChar byte, allowNewline bool) (string, error) {
	data = bytes.TrimSpace(data)
	if len(data) < 2 {
		// Even an empty string needs an opening and a closing delimiter.
		return "", &invalidStringError{cause: errors.New("too short")}
	}

	var (
		str string
		err error
	)
	switch first := data[0]; {
	case first == '"':
		str, err = unquoteString(data)
	case literalChar != 0 && first == literalChar:
		// literalChar == 0 disables literal strings, so a leading NUL byte
		// must not be mistaken for a delimiter.
		str, err = extractLiteralString(data, literalChar, allowNewline)
	default:
		return "", ErrInvalidString
	}
	if err != nil {
		return "", &invalidStringError{cause: err}
	}
	return str, nil
}

// invalidStringError decorates a specific decoding failure so that it also
// matches ErrInvalidString under errors.Is while keeping the cause reachable.
type invalidStringError struct {
	cause error
}

// Error renders the sentinel message followed by the specific cause.
func (e *invalidStringError) Error() string {
	return ErrInvalidString.Error() + ": " + e.cause.Error()
}

// Is makes errors.Is(err, ErrInvalidString) succeed for wrapped failures.
func (e *invalidStringError) Is(target error) bool {
	return target == ErrInvalidString
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e *invalidStringError) Unwrap() error {
	return e.cause
}

// unquoteString decodes a double-quoted string using Go string-literal escape
// rules (strconv.Unquote). The caller must pass non-empty data; the function
// only verifies that the final byte is a closing double quote before decoding.
func unquoteString(data []byte) (string, error) {
	closing := data[len(data)-1]
	if closing == '"' {
		return strconv.Unquote(string(data))
	}
	return "", errors.New("mismatched quotes")
}

// extractLiteralString returns the text between the opening and closing
// literalChar delimiters, taken verbatim with no escape processing.
//
// The caller must pass data of at least two bytes whose first byte is the
// opening delimiter. An error is returned when the last byte is not
// literalChar, when the content contains '\n' while allowNewline is false, or
// when the content is not valid UTF-8 (checked in that order).
func extractLiteralString(data []byte, literalChar byte, allowNewline bool) (string, error) {
	end := len(data) - 1
	if data[end] != literalChar {
		return "", errors.New("mismatched quotes")
	}

	content := string(data[1:end])
	switch {
	case !allowNewline && strings.ContainsRune(content, '\n'):
		return "", errors.New("newlines not allowed in literal string")
	case !utf8.ValidString(content):
		return "", errors.New("invalid UTF-8 in string")
	}
	return content, nil
}
46 changes: 46 additions & 0 deletions encoding/encoding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package encoding

import (
"testing"
)

// TestUnmarshalString drives UnmarshalString through a table of quoted and
// literal inputs. Fields left out of a case keep their zero value: no literal
// delimiter, newlines disallowed, empty expected string, and no expected error.
func TestUnmarshalString(t *testing.T) {
	type testCase struct {
		name         string
		input        []byte
		literalChar  byte
		allowNewline bool
		want         string
		wantErr      bool
	}

	cases := []testCase{
		{name: "Empty input", input: []byte{}, wantErr: true},
		{name: "Single character", input: []byte{'"'}, wantErr: true},
		{name: "Valid quoted string", input: []byte(`"hello"`), want: "hello"},
		{name: "Valid literal string", input: []byte(`'hello'`), literalChar: '\'', want: "hello"},
		{name: "Quoted string with escapes", input: []byte(`"he\"llo"`), want: `he"llo`},
		{name: "Literal string with newline allowed", input: []byte("'hello\nworld'"), literalChar: '\'', allowNewline: true, want: "hello\nworld"},
		{name: "Literal string with newline disallowed", input: []byte("'hello\nworld'"), literalChar: '\'', wantErr: true},
		{name: "Invalid UTF-8 in literal string", input: []byte{'\'', 0xFF, '\''}, literalChar: '\'', wantErr: true},
		{name: "Mismatched quotes in quoted string", input: []byte(`"hello`), wantErr: true},
		{name: "Mismatched quotes in literal string", input: []byte(`'hello`), literalChar: '\'', wantErr: true},
		{name: "Invalid string start", input: []byte(`hello`), wantErr: true},
		{name: "Quoted string with literal char", input: []byte(`"'hello'"`), literalChar: '\'', want: "'hello'"},
		{name: "Literal string with quote char", input: []byte(`'"hello"'`), literalChar: '\'', want: `"hello"`},
		{name: "Literal char is quote", input: []byte(`"hello"`), literalChar: '"', want: "hello"},
		{name: "Whitespace before valid string", input: []byte(" 'hello'"), literalChar: '\'', want: "hello"},
		{name: "Whitespace after valid string", input: []byte("'hello' "), literalChar: '\'', want: "hello"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := UnmarshalString(tc.input, tc.literalChar, tc.allowNewline)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("UnmarshalString() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("UnmarshalString() = %v, want %v", got, tc.want)
			}
		})
	}
}
6 changes: 3 additions & 3 deletions service/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
"time"

"github.com/gopherd/core/component"
"github.com/gopherd/core/encoding"
"github.com/gopherd/core/op"
"github.com/gopherd/core/text/templateutil"
"github.com/gopherd/core/types"
)

// Config represents a generic configuration structure for services.
Expand All @@ -27,7 +27,7 @@ type Config[T any] struct {

// load processes the configuration based on the provided source.
// It returns an error if the configuration cannot be loaded or decoded.
func (c *Config[T]) load(decoder types.Decoder, source string, isJSONC bool) error {
func (c *Config[T]) load(decoder encoding.Decoder, source string, isJSONC bool) error {
if source == "" {
return nil
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func (c *Config[T]) processTemplate(enableTemplate bool, source string) error {

// output encodes the configuration with the encoder and writes it to stdout.
// It uses indentation for better readability.
func (c Config[T]) output(encoder types.Encoder) {
func (c Config[T]) output(encoder encoding.Encoder) {
if data, err := encoder(c); err != nil {
fmt.Fprintf(os.Stderr, "Encode config failed: %v\n", err)
} else {
Expand Down
22 changes: 11 additions & 11 deletions service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
"github.com/gopherd/core/builder"
"github.com/gopherd/core/component"
"github.com/gopherd/core/container/pair"
"github.com/gopherd/core/encoding"
"github.com/gopherd/core/errkit"
"github.com/gopherd/core/lifecycle"
"github.com/gopherd/core/types"
)

// Service represents a process with lifecycle management and component handling capabilities.
Expand All @@ -42,8 +42,8 @@ type BaseService[T any] struct {
}
versionFunc func()
stderr io.Writer
encoder types.Encoder
decoder types.Decoder
encoder encoding.Encoder
decoder encoding.Decoder
isJSONC bool

config Config[T]
Expand Down Expand Up @@ -75,22 +75,22 @@ func (s *BaseService[T]) GetComponent(uuid string) component.Component {
}

// Encoder returns the encoder function for the service.
func (s *BaseService[T]) Encoder() types.Encoder {
func (s *BaseService[T]) Encoder() encoding.Encoder {
return s.encoder
}

// Decoder returns the decoder function for the service.
func (s *BaseService[T]) Decoder() types.Decoder {
func (s *BaseService[T]) Decoder() encoding.Decoder {
return s.decoder
}

// SetEncoder sets the encoder functions for the service.
func (s *BaseService[T]) SetEncoder(encoder types.Encoder) {
func (s *BaseService[T]) SetEncoder(encoder encoding.Encoder) {
s.encoder = encoder
}

// SetDecoder sets the decoder functions for the service.
func (s *BaseService[T]) SetDecoder(decoder types.Decoder) {
func (s *BaseService[T]) SetDecoder(decoder encoding.Decoder) {
s.decoder = decoder
s.isJSONC = false
}
Expand Down Expand Up @@ -250,8 +250,8 @@ func (s *BaseService[T]) Shutdown(ctx context.Context) error {
}

type runOptions struct {
encoder types.Encoder
decoder types.Decoder
encoder encoding.Encoder
decoder encoding.Decoder
}

// apply applies the options to the given options.
Expand All @@ -265,14 +265,14 @@ func (o *runOptions) apply(opts []RunOption) {
type RunOption func(*runOptions)

// WithEncoder sets the encoder function for the Run function.
func WithEncoder(encoder types.Encoder) RunOption {
func WithEncoder(encoder encoding.Encoder) RunOption {
return func(o *runOptions) {
o.encoder = encoder
}
}

// WithDecoder sets the decoder function for the Run function.
func WithDecoder(decoder types.Decoder) RunOption {
func WithDecoder(decoder encoding.Decoder) RunOption {
return func(o *runOptions) {
o.decoder = decoder
}
Expand Down
Loading

0 comments on commit 3d29d73

Please sign in to comment.