-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathdsn.go
149 lines (133 loc) · 4.95 KB
/
dsn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package ydb
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/ydb-platform/ydb-go-sdk/v3/balancers"
"github.com/ydb-platform/ydb-go-sdk/v3/credentials"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
)
// tablePathPrefixTransformer names the go_query_bind rewriter that installs
// bind.TablePathPrefix. Inside a DSN it is spelled table_path_prefix(<prefix>).
const (
	tablePathPrefixTransformer = "table_path_prefix"
)
// dsnParsers is the list of DSN parsers maintained by RegisterDsnParser and
// UnregisterDsnParser. Slot 0 holds the built-in connection-string parser.
// A slot becomes nil after its parser is unregistered.
var dsnParsers = []func(dsn string) (opts []Option, _ error){
	func(dsn string) ([]Option, error) {
		parsed, parseErr := parseConnectionString(dsn)
		if parseErr == nil {
			return parsed, nil
		}

		return nil, xerrors.WithStackTrace(parseErr)
	},
}
// RegisterDsnParser adds parser to the DSN parsers used by the ydb.Open and
// sql.Open driver constructors. It returns the parser's registration ID,
// which can later be passed to UnregisterDsnParser.
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) (registrationID int) {
	// The new parser lands at the current end of the slice, so its index
	// equals the length before appending.
	registrationID = len(dsnParsers)
	dsnParsers = append(dsnParsers, parser)

	return registrationID
}
// UnregisterDsnParser unregisters the DSN parser identified by registrationID,
// the value previously returned by RegisterDsnParser.
//
// The parser's slot is cleared (set to nil) rather than removed, so the IDs
// of other registered parsers stay valid. Slot 0 holds the built-in
// connection-string parser, so passing 0 disables it. An ID outside the
// range of existing slots is ignored instead of triggering an
// index-out-of-range panic.
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func UnregisterDsnParser(registrationID int) {
	if registrationID < 0 || registrationID >= len(dsnParsers) {
		return
	}
	dsnParsers[registrationID] = nil
}
// stringToType maps the query-mode names accepted in DSN parameters
// (go_query_mode, query_mode, go_fake_tx) to their QueryMode values.
var stringToType = map[string]QueryMode{
	"data":      DataQueryMode,
	"query":     QueryExecuteQueryMode,
	"scan":      ScanQueryMode,
	"scheme":    SchemeQueryMode,
	"scripting": ScriptingQueryMode,
}
// queryModeFromString looks up s in stringToType. It returns
// unknownQueryMode when s is not a recognized query-mode name.
func queryModeFromString(s string) QueryMode {
	mode, found := stringToType[s]
	if !found {
		return unknownQueryMode
	}

	return mode
}
// parseConnectionString is the built-in DSN parser (slot 0 of dsnParsers).
// It parses dataSourceName with dsn.Parse and converts the recognized query
// parameters into driver options:
//
//   - token: access-token credentials;
//   - go_balancer, falling back to balancer: a balancer built by balancers.FromConfig;
//   - go_query_mode, falling back to query_mode: a connector option selecting the
//     query service ("query") or a default query mode (other known names);
//   - go_fake_tx: comma-separated query-mode names, each passed to WithFakeTx;
//   - go_query_bind: comma-separated query rewriters ("declare", "positional",
//     "numeric", "table_path_prefix(<prefix>)").
//
// A go_-prefixed parameter is used when it is non-empty; otherwise its
// unprefixed alias is consulted. Unknown query modes and query rewriters
// are reported as errors.
func parseConnectionString(dataSourceName string) (opts []Option, _ error) {
	info, err := dsn.Parse(dataSourceName)
	if err != nil {
		return nil, xerrors.WithStackTrace(err)
	}
	opts = append(opts, With(info.Options...))

	if token := info.Params.Get("token"); token != "" {
		opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token)))
	}

	balancer := info.Params.Get("go_balancer")
	if balancer == "" {
		balancer = info.Params.Get("balancer")
	}
	if balancer != "" {
		opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
	}

	queryMode := info.Params.Get("go_query_mode")
	if queryMode == "" {
		queryMode = info.Params.Get("query_mode")
	}
	if queryMode != "" {
		modeOpt, modeErr := queryModeOptionFromDSN(queryMode)
		if modeErr != nil {
			// modeErr already carries a stack trace from the helper.
			return nil, modeErr
		}
		opts = append(opts, modeOpt)
	}

	if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" {
		for _, name := range strings.Split(fakeTx, ",") {
			mode := queryModeFromString(name)
			if mode == unknownQueryMode {
				return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", name))
			}
			opts = append(opts, withConnectorOptions(WithFakeTx(mode)))
		}
	}

	// Has (not Get) is checked so that an explicitly empty go_query_bind is
	// still parsed and rejected as an unknown rewriter.
	if info.Params.Has("go_query_bind") {
		binders, bindErr := queryBindOptionsFromDSN(info.Params.Get("go_query_bind"))
		if bindErr != nil {
			// bindErr already carries a stack trace from the helper.
			return nil, bindErr
		}
		opts = append(opts, withConnectorOptions(binders...))
	}

	return opts, nil
}

// queryModeOptionFromDSN maps a query-mode name from go_query_mode/query_mode to
// the matching connector option. "query" enables the query service, and the
// other known names set the default query mode.
func queryModeOptionFromDSN(queryMode string) (Option, error) {
	switch mode := queryModeFromString(queryMode); mode {
	case QueryExecuteQueryMode:
		return withConnectorOptions(xsql.WithQueryService(true)), nil
	case unknownQueryMode:
		return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
	default:
		return withConnectorOptions(xsql.WithDefaultQueryMode(modeToMode(mode))), nil
	}
}

// queryBindOptionsFromDSN converts the comma-separated go_query_bind value into
// one xsql query-bind option per listed rewriter.
func queryBindOptionsFromDSN(spec string) ([]xsql.Option, error) {
	transformers := strings.Split(spec, ",")
	binders := make([]xsql.Option, 0, len(transformers))

	for _, transformer := range transformers {
		switch transformer {
		case "declare":
			binders = append(binders, xsql.WithQueryBind(bind.AutoDeclare{}))
		case "positional":
			binders = append(binders, xsql.WithQueryBind(bind.PositionalArgs{}))
		case "numeric":
			binders = append(binders, xsql.WithQueryBind(bind.NumericArgs{}))
		default:
			if !strings.HasPrefix(transformer, tablePathPrefixTransformer) {
				return nil, xerrors.WithStackTrace(
					fmt.Errorf("unknown query rewriter: %s", transformer),
				)
			}
			prefix, err := extractTablePathPrefixFromBinderName(transformer)
			if err != nil {
				return nil, xerrors.WithStackTrace(err)
			}
			binders = append(binders, xsql.WithQueryBind(bind.TablePathPrefix(prefix)))
		}
	}

	return binders, nil
}
// tablePathPrefixRe captures the argument of a table_path_prefix(<prefix>)
// query transformer. The greedy `.*` runs up to the last closing
// parenthesis on the line. The pattern is not anchored.
var tablePathPrefixRe = regexp.MustCompile(tablePathPrefixTransformer + `\((.*)\)`)

// errWrongTablePathPrefix is the sentinel wrapped by errors reporting a
// malformed table_path_prefix(...) query transformer.
var errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer")
// extractTablePathPrefixFromBinderName returns the prefix inside a
// "table_path_prefix(<prefix>)" query-transformer name.
//
// The whole binderName must have that form. Extra characters before or
// after it, or an empty prefix, produce an error wrapping
// errWrongTablePathPrefix (test with errors.Is).
func extractTablePathPrefixFromBinderName(binderName string) (string, error) {
	m := tablePathPrefixRe.FindStringSubmatch(binderName)
	// tablePathPrefixRe is not anchored, so require the match to cover the
	// entire name. Otherwise "table_path_prefix(/a)junk" would silently
	// yield "/a".
	if len(m) != 2 || m[0] != binderName || m[1] == "" {
		return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName))
	}

	return m[1], nil
}