Skip to content

Commit

Permalink
Store BURNTSUSHI_TOML_110 in parser and lexer
Browse files Browse the repository at this point in the history
Setting a global is racy when multiple decodes are run in parallel.

Fixes #395
  • Loading branch information
arp242 committed Jun 8, 2023
1 parent d4c441a commit b324da5
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"uses": "actions/checkout@v3"
}, {
"name": "Test",
"run": "go test ./..."
"run": "go test -race ./..."
}, {
"name": "Test on 32bit",
"if": "runner.os == 'Linux'",
Expand Down
21 changes: 21 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -1201,6 +1202,26 @@ func TestMetaKeys(t *testing.T) {
}
}

func TestDecodeParallel(t *testing.T) {
doc, err := os.ReadFile("testdata/ja-JP.toml")
if err != nil {
t.Fatal(err)
}

var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := Unmarshal(doc, new(map[string]interface{}))
if err != nil {
t.Fatal(err)
}
}()
}
wg.Wait()
}

// errorContains checks if the error message in have contains the text in
// want.
//
Expand Down
40 changes: 21 additions & 19 deletions lex.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ func (p Position) String() string {
}

type lexer struct {
input string
start int
pos int
line int
state stateFn
items chan item
input string
start int
pos int
line int
state stateFn
items chan item
tomlNext bool

// Allow for backing up up to 4 runes. This is necessary because TOML
// contains 3-rune tokens (""" and ''').
Expand Down Expand Up @@ -87,13 +88,14 @@ func (lx *lexer) nextItem() item {
}
}

func lex(input string) *lexer {
func lex(input string, tomlNext bool) *lexer {
lx := &lexer{
input: input,
state: lexTop,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
line: 1,
input: input,
state: lexTop,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
line: 1,
tomlNext: tomlNext,
}
return lx
}
Expand Down Expand Up @@ -408,7 +410,7 @@ func lexTableNameEnd(lx *lexer) stateFn {
// Lexes only one part, e.g. only 'a' inside 'a.b'.
func lexBareName(lx *lexer) stateFn {
r := lx.next()
if isBareKeyChar(r) {
if isBareKeyChar(r, lx.tomlNext) {
return lexBareName
}
lx.backup()
Expand Down Expand Up @@ -618,7 +620,7 @@ func lexInlineTableValue(lx *lexer) stateFn {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
if tomlNext {
if lx.tomlNext {
return lexSkip(lx, lexInlineTableValue)
}
return lx.errorPrevLine(errLexInlineTableNL{})
Expand All @@ -643,7 +645,7 @@ func lexInlineTableValueEnd(lx *lexer) stateFn {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
if tomlNext {
if lx.tomlNext {
return lexSkip(lx, lexInlineTableValueEnd)
}
return lx.errorPrevLine(errLexInlineTableNL{})
Expand All @@ -654,7 +656,7 @@ func lexInlineTableValueEnd(lx *lexer) stateFn {
lx.ignore()
lx.skip(isWhitespace)
if lx.peek() == '}' {
if tomlNext {
if lx.tomlNext {
return lexInlineTableValueEnd
}
return lx.errorf("trailing comma not allowed in inline tables")
Expand Down Expand Up @@ -838,7 +840,7 @@ func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'e':
if !tomlNext {
if !lx.tomlNext {
return lx.error(errLexEscape{r})
}
fallthrough
Expand All @@ -861,7 +863,7 @@ func lexStringEscape(lx *lexer) stateFn {
case '\\':
return lx.pop()
case 'x':
if !tomlNext {
if !lx.tomlNext {
return lx.error(errLexEscape{r})
}
return lexHexEscape
Expand Down Expand Up @@ -1258,7 +1260,7 @@ func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')
}

func isBareKeyChar(r rune) bool {
func isBareKeyChar(r rune, tomlNext bool) bool {
if tomlNext {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
Expand Down
2 changes: 1 addition & 1 deletion meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (k Key) maybeQuoted(i int) string {
return `""`
}
for _, c := range k[i] {
if !isBareKeyChar(c) {
if !isBareKeyChar(c, false) {
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
}
}
Expand Down
15 changes: 7 additions & 8 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ import (
"github.com/BurntSushi/toml/internal"
)

var tomlNext bool

type parser struct {
lx *lexer
context Key // Full key for the current hash in scope.
currentKey string // Base key name for everything except hashes.
pos Position // Current position in the TOML file.
tomlNext bool

ordered []Key // List of keys in the order that they appear in the TOML data.

Expand All @@ -32,8 +31,7 @@ type keyInfo struct {
}

func parse(data string) (p *parser, err error) {
_, ok := os.LookupEnv("BURNTSUSHI_TOML_110")
tomlNext = ok
_, tomlNext := os.LookupEnv("BURNTSUSHI_TOML_110")

defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -74,9 +72,10 @@ func parse(data string) (p *parser, err error) {
p = &parser{
keyInfo: make(map[string]keyInfo),
mapping: make(map[string]interface{}),
lx: lex(data),
lx: lex(data, tomlNext),
ordered: make([]Key, 0),
implicits: make(map[string]struct{}),
tomlNext: tomlNext,
}
for {
item := p.next()
Expand Down Expand Up @@ -361,7 +360,7 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
err error
)
for _, dt := range dtTypes {
if dt.next && !tomlNext {
if dt.next && !p.tomlNext {
continue
}
t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone)
Expand Down Expand Up @@ -764,7 +763,7 @@ func (p *parser) replaceEscapes(it item, str string) string {
replaced = append(replaced, rune(0x000D))
r += 1
case 'e':
if tomlNext {
if p.tomlNext {
replaced = append(replaced, rune(0x001B))
r += 1
}
Expand All @@ -775,7 +774,7 @@ func (p *parser) replaceEscapes(it item, str string) string {
replaced = append(replaced, rune(0x005C))
r += 1
case 'x':
if tomlNext {
if p.tomlNext {
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+3])
replaced = append(replaced, escaped)
r += 3
Expand Down

0 comments on commit b324da5

Please sign in to comment.