diff --git a/go/ql/test/query-tests/Security/CWE-918/go.mod b/go/ql/test/query-tests/Security/CWE-918/go.mod index ce6c493a190d..5d81e787137a 100644 --- a/go/ql/test/query-tests/Security/CWE-918/go.mod +++ b/go/ql/test/query-tests/Security/CWE-918/go.mod @@ -5,7 +5,8 @@ go 1.14 require ( github.com/gobwas/ws v1.0.3 github.com/gorilla/websocket v1.4.2 + github.com/sacOO7/go-logger v0.0.0-20180719173527-9ac9add5a50d // indirect github.com/sacOO7/gowebsocket v0.0.0-20180719182212-1436bb906a4e - golang.org/x/net v0.0.0-20200421231249-e086a090c8fd + golang.org/x/net v0.7.0 nhooyr.io/websocket v1.8.5 ) diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/LICENSE b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/LICENSE new file mode 100644 index 000000000000..274431766fa1 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Sergey Kamardin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/README.md b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/README.md new file mode 100644 index 000000000000..67a97fdbe926 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/README.md @@ -0,0 +1,63 @@ +# httphead.[go](https://golang.org) + +[![GoDoc][godoc-image]][godoc-url] + +> Tiny HTTP header value parsing library in go. + +## Overview + +This library contains low-level functions for scanning HTTP RFC2616 compatible header value grammars. + +## Install + +```shell + go get github.com/gobwas/httphead +``` + +## Example + +The example below shows how multiple-choise HTTP header value could be parsed with this library: + +```go + options, ok := httphead.ParseOptions([]byte(`foo;bar=1,baz`), nil) + fmt.Println(options, ok) + // Output: [{foo map[bar:1]} {baz map[]}] true +``` + +The low-level example below shows how to optimize keys skipping and selection +of some key: + +```go + // The right part of full header line like: + // X-My-Header: key;foo=bar;baz,key;baz + header := []byte(`foo;a=0,foo;a=1,foo;a=2,foo;a=3`) + + // We want to search key "foo" with an "a" parameter that equal to "2". 
+ var ( + foo = []byte(`foo`) + a = []byte(`a`) + v = []byte(`2`) + ) + var found bool + httphead.ScanOptions(header, func(i int, key, param, value []byte) Control { + if !bytes.Equal(key, foo) { + return ControlSkip + } + if !bytes.Equal(param, a) { + if bytes.Equal(value, v) { + // Found it! + found = true + return ControlBreak + } + return ControlSkip + } + return ControlContinue + }) +``` + +For more usage examples please see [docs][godoc-url] or package tests. + +[godoc-image]: https://godoc.org/github.com/gobwas/httphead?status.svg +[godoc-url]: https://godoc.org/github.com/gobwas/httphead +[travis-image]: https://travis-ci.org/gobwas/httphead.svg?branch=master +[travis-url]: https://travis-ci.org/gobwas/httphead diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/cookie.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/cookie.go new file mode 100644 index 000000000000..05c9a1fb6a16 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/cookie.go @@ -0,0 +1,200 @@ +package httphead + +import ( + "bytes" +) + +// ScanCookie scans cookie pairs from data using DefaultCookieScanner.Scan() +// method. +func ScanCookie(data []byte, it func(key, value []byte) bool) bool { + return DefaultCookieScanner.Scan(data, it) +} + +// DefaultCookieScanner is a CookieScanner which is used by ScanCookie(). +// Note that it is intended to have the same behavior as http.Request.Cookies() +// has. +var DefaultCookieScanner = CookieScanner{} + +// CookieScanner contains options for scanning cookie pairs. +// See https://tools.ietf.org/html/rfc6265#section-4.1.1 +type CookieScanner struct { + // DisableNameValidation disables name validation of a cookie. If false, + // only RFC2616 "tokens" are accepted. + DisableNameValidation bool + + // DisableValueValidation disables value validation of a cookie. If false, + // only RFC6265 "cookie-octet" characters are accepted. + // + // Note that Strict option also affects validation of a value. + // + // If Strict is false, then scanner begins to allow space and comma + // characters inside the value for better compatibility with non standard + // cookies implementations. + DisableValueValidation bool + + // BreakOnPairError sets scanner to immediately return after first pair syntax + // validation error. + // If false, scanner will try to skip invalid pair bytes and go ahead. + BreakOnPairError bool + + // Strict enables strict RFC6265 mode scanning. It affects name and value + // validation, as also some other rules. + // If false, it is intended to bring the same behavior as + // http.Request.Cookies(). + Strict bool +} + +// Scan maps data to name and value pairs. Usually data represents value of the +// Cookie header. +func (c CookieScanner) Scan(data []byte, it func(name, value []byte) bool) bool { + lexer := &Scanner{data: data} + + const ( + statePair = iota + stateBefore + ) + + state := statePair + + for lexer.Buffered() > 0 { + switch state { + case stateBefore: + // Pairs separated by ";" and space, according to the RFC6265: + // cookie-pair *( ";" SP cookie-pair ) + // + // Cookie pairs MUST be separated by (";" SP). So our only option + // here is to fail as syntax error. 
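+			// For example, scanning "a=1; b=2" re-enters this state at
+			// "; b=2": Peek2() yields ';' and ' ', and the scanner advances
+			// past both bytes. Without Strict, a lone ';' is also accepted;
+			// with Strict, a missing SP is a syntax error.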
+ a, b := lexer.Peek2() + if a != ';' { + return false + } + + state = statePair + + advance := 1 + if b == ' ' { + advance++ + } else if c.Strict { + return false + } + + lexer.Advance(advance) + + case statePair: + if !lexer.FetchUntil(';') { + return false + } + + var value []byte + name := lexer.Bytes() + if i := bytes.IndexByte(name, '='); i != -1 { + value = name[i+1:] + name = name[:i] + } else if c.Strict { + if !c.BreakOnPairError { + goto nextPair + } + return false + } + + if !c.Strict { + trimLeft(name) + } + if !c.DisableNameValidation && !ValidCookieName(name) { + if !c.BreakOnPairError { + goto nextPair + } + return false + } + + if !c.Strict { + value = trimRight(value) + } + value = stripQuotes(value) + if !c.DisableValueValidation && !ValidCookieValue(value, c.Strict) { + if !c.BreakOnPairError { + goto nextPair + } + return false + } + + if !it(name, value) { + return true + } + + nextPair: + state = stateBefore + } + } + + return true +} + +// ValidCookieValue reports whether given value is a valid RFC6265 +// "cookie-octet" bytes. +// +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// +// Note that the false strict parameter disables errors on space 0x20 and comma +// 0x2c. This could be useful to bring some compatibility with non-compliant +// clients/servers in the real world. +// It acts the same as standard library cookie parser if strict is false. +func ValidCookieValue(value []byte, strict bool) bool { + if len(value) == 0 { + return true + } + for _, c := range value { + switch c { + case '"', ';', '\\': + return false + case ',', ' ': + if strict { + return false + } + default: + if c <= 0x20 { + return false + } + if c >= 0x7f { + return false + } + } + } + return true +} + +// ValidCookieName reports wheter given bytes is a valid RFC2616 "token" bytes. +func ValidCookieName(name []byte) bool { + for _, c := range name { + if !OctetTypes[c].IsToken() { + return false + } + } + return true +} + +func stripQuotes(bts []byte) []byte { + if last := len(bts) - 1; last > 0 && bts[0] == '"' && bts[last] == '"' { + return bts[1:last] + } + return bts +} + +func trimLeft(p []byte) []byte { + var i int + for i < len(p) && OctetTypes[p[i]].IsSpace() { + i++ + } + return p[i:] +} + +func trimRight(p []byte) []byte { + j := len(p) + for j > 0 && OctetTypes[p[j-1]].IsSpace() { + j-- + } + return p[:j] +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/head.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/head.go new file mode 100644 index 000000000000..a50e907dd18e --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/head.go @@ -0,0 +1,275 @@ +package httphead + +import ( + "bufio" + "bytes" +) + +// Version contains protocol major and minor version. +type Version struct { + Major int + Minor int +} + +// RequestLine contains parameters parsed from the first request line. +type RequestLine struct { + Method []byte + URI []byte + Version Version +} + +// ResponseLine contains parameters parsed from the first response line. +type ResponseLine struct { + Version Version + Status int + Reason []byte +} + +// SplitRequestLine splits given slice of bytes into three chunks without +// parsing. 
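+// For example, splitting "GET /index.html HTTP/1.1" yields "GET",
+// "/index.html" and "HTTP/1.1"; lines with fewer than two separators come
+// back as (line, nil, nil).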
+func SplitRequestLine(line []byte) (method, uri, version []byte) { + return split3(line, ' ') +} + +// ParseRequestLine parses http request line like "GET / HTTP/1.0". +func ParseRequestLine(line []byte) (r RequestLine, ok bool) { + var i int + for i = 0; i < len(line); i++ { + c := line[i] + if !OctetTypes[c].IsToken() { + if i > 0 && c == ' ' { + break + } + return + } + } + if i == len(line) { + return + } + + var proto []byte + r.Method = line[:i] + r.URI, proto = split2(line[i+1:], ' ') + if len(r.URI) == 0 { + return + } + if major, minor, ok := ParseVersion(proto); ok { + r.Version.Major = major + r.Version.Minor = minor + return r, true + } + + return r, false +} + +// SplitResponseLine splits given slice of bytes into three chunks without +// parsing. +func SplitResponseLine(line []byte) (version, status, reason []byte) { + return split3(line, ' ') +} + +// ParseResponseLine parses first response line into ResponseLine struct. +func ParseResponseLine(line []byte) (r ResponseLine, ok bool) { + var ( + proto []byte + status []byte + ) + proto, status, r.Reason = split3(line, ' ') + if major, minor, ok := ParseVersion(proto); ok { + r.Version.Major = major + r.Version.Minor = minor + } else { + return r, false + } + if n, ok := IntFromASCII(status); ok { + r.Status = n + } else { + return r, false + } + // TODO(gobwas): parse here r.Reason fot TEXT rule: + // TEXT = + return r, true +} + +var ( + httpVersion10 = []byte("HTTP/1.0") + httpVersion11 = []byte("HTTP/1.1") + httpVersionPrefix = []byte("HTTP/") +) + +// ParseVersion parses major and minor version of HTTP protocol. +// It returns parsed values and true if parse is ok. +func ParseVersion(bts []byte) (major, minor int, ok bool) { + switch { + case bytes.Equal(bts, httpVersion11): + return 1, 1, true + case bytes.Equal(bts, httpVersion10): + return 1, 0, true + case len(bts) < 8: + return + case !bytes.Equal(bts[:5], httpVersionPrefix): + return + } + + bts = bts[5:] + + dot := bytes.IndexByte(bts, '.') + if dot == -1 { + return + } + major, ok = IntFromASCII(bts[:dot]) + if !ok { + return + } + minor, ok = IntFromASCII(bts[dot+1:]) + if !ok { + return + } + + return major, minor, true +} + +// ReadLine reads line from br. It reads until '\n' and returns bytes without +// '\n' or '\r\n' at the end. +// It returns err if and only if line does not end in '\n'. Note that read +// bytes returned in any case of error. +// +// It is much like the textproto/Reader.ReadLine() except the thing that it +// returns raw bytes, instead of string. That is, it avoids copying bytes read +// from br. +// +// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be +// safe with future I/O operations on br. +// +// We could control I/O operations on br and do not need to make additional +// copy for safety. +func ReadLine(br *bufio.Reader) ([]byte, error) { + var line []byte + for { + bts, err := br.ReadSlice('\n') + if err == bufio.ErrBufferFull { + // Copy bytes because next read will discard them. + line = append(line, bts...) + continue + } + // Avoid copy of single read. + if line == nil { + line = bts + } else { + line = append(line, bts...) + } + if err != nil { + return line, err + } + // Size of line is at least 1. + // In other case bufio.ReadSlice() returns error. + n := len(line) + // Cut '\n' or '\r\n'. + if n > 1 && line[n-2] == '\r' { + line = line[:n-2] + } else { + line = line[:n-1] + } + return line, nil + } +} + +// ParseHeaderLine parses HTTP header as key-value pair. 
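Both key and value are trimmed of surrounding spaces and tabs; for example, "Host: example.com" parses to key "Host" and value "example.com".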
It returns parsed +// values and true if parse is ok. +func ParseHeaderLine(line []byte) (k, v []byte, ok bool) { + colon := bytes.IndexByte(line, ':') + if colon == -1 { + return + } + k = trim(line[:colon]) + for _, c := range k { + if !OctetTypes[c].IsToken() { + return nil, nil, false + } + } + v = trim(line[colon+1:]) + return k, v, true +} + +// IntFromASCII converts ascii encoded decimal numeric value from HTTP entities +// to an integer. +func IntFromASCII(bts []byte) (ret int, ok bool) { + // ASCII numbers all start with the high-order bits 0011. + // If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those + // bits and interpret them directly as an integer. + var n int + if n = len(bts); n < 1 { + return 0, false + } + for i := 0; i < n; i++ { + if bts[i]&0xf0 != 0x30 { + return 0, false + } + ret += int(bts[i]&0xf) * pow(10, n-i-1) + } + return ret, true +} + +const ( + toLower = 'a' - 'A' // for use with OR. + toUpper = ^byte(toLower) // for use with AND. +) + +// CanonicalizeHeaderKey is like standard textproto/CanonicalMIMEHeaderKey, +// except that it operates with slice of bytes and modifies it inplace without +// copying. +func CanonicalizeHeaderKey(k []byte) { + upper := true + for i, c := range k { + if upper && 'a' <= c && c <= 'z' { + k[i] &= toUpper + } else if !upper && 'A' <= c && c <= 'Z' { + k[i] |= toLower + } + upper = c == '-' + } +} + +// pow for integers implementation. +// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3 +func pow(a, b int) int { + p := 1 + for b > 0 { + if b&1 != 0 { + p *= a + } + b >>= 1 + a *= a + } + return p +} + +func split3(p []byte, sep byte) (p1, p2, p3 []byte) { + a := bytes.IndexByte(p, sep) + b := bytes.IndexByte(p[a+1:], sep) + if a == -1 || b == -1 { + return p, nil, nil + } + b += a + 1 + return p[:a], p[a+1 : b], p[b+1:] +} + +func split2(p []byte, sep byte) (p1, p2 []byte) { + i := bytes.IndexByte(p, sep) + if i == -1 { + return p, nil + } + return p[:i], p[i+1:] +} + +func trim(p []byte) []byte { + var i, j int + for i = 0; i < len(p) && (p[i] == ' ' || p[i] == '\t'); { + i++ + } + for j = len(p); j > i && (p[j-1] == ' ' || p[j-1] == '\t'); { + j-- + } + return p[i:j] +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/httphead.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/httphead.go new file mode 100644 index 000000000000..2387e8033c94 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/httphead.go @@ -0,0 +1,331 @@ +// Package httphead contains utils for parsing HTTP and HTTP-grammar compatible +// text protocols headers. +// +// That is, this package first aim is to bring ability to easily parse +// constructions, described here https://tools.ietf.org/html/rfc2616#section-2 +package httphead + +import ( + "bytes" + "strings" +) + +// ScanTokens parses data in this form: +// +// list = 1#token +// +// It returns false if data is malformed. +func ScanTokens(data []byte, it func([]byte) bool) bool { + lexer := &Scanner{data: data} + + var ok bool + for lexer.Next() { + switch lexer.Type() { + case ItemToken: + ok = true + if !it(lexer.Bytes()) { + return true + } + case ItemSeparator: + if !isComma(lexer.Bytes()) { + return false + } + default: + return false + } + } + + return ok && !lexer.err +} + +// ParseOptions parses all header options and appends it to given slice of +// Option. It returns flag of successful (wellformed input) parsing. 
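+// For example, parsing `foo;bar=1,baz` into a nil slice yields the option
+// foo with parameter bar=1 followed by the option baz with no parameters.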
+// +// Note that appended options are all consist of subslices of data. That is, +// mutation of data will mutate appended options. +func ParseOptions(data []byte, options []Option) ([]Option, bool) { + var i int + index := -1 + return options, ScanOptions(data, func(idx int, name, attr, val []byte) Control { + if idx != index { + index = idx + i = len(options) + options = append(options, Option{Name: name}) + } + if attr != nil { + options[i].Parameters.Set(attr, val) + } + return ControlContinue + }) +} + +// SelectFlag encodes way of options selection. +type SelectFlag byte + +// String represetns flag as string. +func (f SelectFlag) String() string { + var flags [2]string + var n int + if f&SelectCopy != 0 { + flags[n] = "copy" + n++ + } + if f&SelectUnique != 0 { + flags[n] = "unique" + n++ + } + return "[" + strings.Join(flags[:n], "|") + "]" +} + +const ( + // SelectCopy causes selector to copy selected option before appending it + // to resulting slice. + // If SelectCopy flag is not passed to selector, then appended options will + // contain sub-slices of the initial data. + SelectCopy SelectFlag = 1 << iota + + // SelectUnique causes selector to append only not yet existing option to + // resulting slice. Unique is checked by comparing option names. + SelectUnique +) + +// OptionSelector contains configuration for selecting Options from header value. +type OptionSelector struct { + // Check is a filter function that applied to every Option that possibly + // could be selected. + // If Check is nil all options will be selected. + Check func(Option) bool + + // Flags contains flags for options selection. + Flags SelectFlag + + // Alloc used to allocate slice of bytes when selector is configured with + // SelectCopy flag. It will be called with number of bytes needed for copy + // of single Option. + // If Alloc is nil make is used. + Alloc func(n int) []byte +} + +// Select parses header data and appends it to given slice of Option. +// It also returns flag of successful (wellformed input) parsing. +func (s OptionSelector) Select(data []byte, options []Option) ([]Option, bool) { + var current Option + var has bool + index := -1 + + alloc := s.Alloc + if alloc == nil { + alloc = defaultAlloc + } + check := s.Check + if check == nil { + check = defaultCheck + } + + ok := ScanOptions(data, func(idx int, name, attr, val []byte) Control { + if idx != index { + if has && check(current) { + if s.Flags&SelectCopy != 0 { + current = current.Copy(alloc(current.Size())) + } + options = append(options, current) + has = false + } + if s.Flags&SelectUnique != 0 { + for i := len(options) - 1; i >= 0; i-- { + if bytes.Equal(options[i].Name, name) { + return ControlSkip + } + } + } + index = idx + current = Option{Name: name} + has = true + } + if attr != nil { + current.Parameters.Set(attr, val) + } + + return ControlContinue + }) + if has && check(current) { + if s.Flags&SelectCopy != 0 { + current = current.Copy(alloc(current.Size())) + } + options = append(options, current) + } + + return options, ok +} + +func defaultAlloc(n int) []byte { return make([]byte, n) } +func defaultCheck(Option) bool { return true } + +// Control represents operation that scanner should perform. +type Control byte + +const ( + // ControlContinue causes scanner to continue scan tokens. + ControlContinue Control = iota + // ControlBreak causes scanner to stop scan tokens. + ControlBreak + // ControlSkip causes scanner to skip current entity. 
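+	// In ScanOptions this seeks past the remaining parameters of the
+	// current option, up to the next non-escaped comma.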
+ ControlSkip +) + +// ScanOptions parses data in this form: +// +// values = 1#value +// value = token *( ";" param ) +// param = token [ "=" (token | quoted-string) ] +// +// It calls given callback with the index of the option, option itself and its +// parameter (attribute and its value, both could be nil). Index is useful when +// header contains multiple choises for the same named option. +// +// Given callback should return one of the defined Control* values. +// ControlSkip means that passed key is not in caller's interest. That is, all +// parameters of that key will be skipped. +// ControlBreak means that no more keys and parameters should be parsed. That +// is, it must break parsing immediately. +// ControlContinue means that caller want to receive next parameter and its +// value or the next key. +// +// It returns false if data is malformed. +func ScanOptions(data []byte, it func(index int, option, attribute, value []byte) Control) bool { + lexer := &Scanner{data: data} + + var ok bool + var state int + const ( + stateKey = iota + stateParamBeforeName + stateParamName + stateParamBeforeValue + stateParamValue + ) + + var ( + index int + key, param, value []byte + mustCall bool + ) + for lexer.Next() { + var ( + call bool + growIndex int + ) + + t := lexer.Type() + v := lexer.Bytes() + + switch t { + case ItemToken: + switch state { + case stateKey, stateParamBeforeName: + key = v + state = stateParamBeforeName + mustCall = true + case stateParamName: + param = v + state = stateParamBeforeValue + mustCall = true + case stateParamValue: + value = v + state = stateParamBeforeName + call = true + default: + return false + } + + case ItemString: + if state != stateParamValue { + return false + } + value = v + state = stateParamBeforeName + call = true + + case ItemSeparator: + switch { + case isComma(v) && state == stateKey: + // Nothing to do. + + case isComma(v) && state == stateParamBeforeName: + state = stateKey + // Make call only if we have not called this key yet. + call = mustCall + if !call { + // If we have already called callback with the key + // that just ended. + index++ + } else { + // Else grow the index after calling callback. + growIndex = 1 + } + + case isComma(v) && state == stateParamBeforeValue: + state = stateKey + growIndex = 1 + call = true + + case isSemicolon(v) && state == stateParamBeforeName: + state = stateParamName + + case isSemicolon(v) && state == stateParamBeforeValue: + state = stateParamName + call = true + + case isEquality(v) && state == stateParamBeforeValue: + state = stateParamValue + + default: + return false + } + + default: + return false + } + + if call { + switch it(index, key, param, value) { + case ControlBreak: + // User want to stop to parsing parameters. + return true + + case ControlSkip: + // User want to skip current param. + state = stateKey + lexer.SkipEscaped(',') + + case ControlContinue: + // User is interested in rest of parameters. + // Nothing to do. 
+ + default: + panic("unexpected control value") + } + ok = true + param = nil + value = nil + mustCall = false + index += growIndex + } + } + if mustCall { + ok = true + it(index, key, param, value) + } + + return ok && !lexer.err +} + +func isComma(b []byte) bool { + return len(b) == 1 && b[0] == ',' +} +func isSemicolon(b []byte) bool { + return len(b) == 1 && b[0] == ';' +} +func isEquality(b []byte) bool { + return len(b) == 1 && b[0] == '=' +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/lexer.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/lexer.go new file mode 100644 index 000000000000..729855ed0d31 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/lexer.go @@ -0,0 +1,360 @@ +package httphead + +import ( + "bytes" +) + +// ItemType encodes type of the lexing token. +type ItemType int + +const ( + // ItemUndef reports that token is undefined. + ItemUndef ItemType = iota + // ItemToken reports that token is RFC2616 token. + ItemToken + // ItemSeparator reports that token is RFC2616 separator. + ItemSeparator + // ItemString reports that token is RFC2616 quouted string. + ItemString + // ItemComment reports that token is RFC2616 comment. + ItemComment + // ItemOctet reports that token is octet slice. + ItemOctet +) + +// Scanner represents header tokens scanner. +// See https://tools.ietf.org/html/rfc2616#section-2 +type Scanner struct { + data []byte + pos int + + itemType ItemType + itemBytes []byte + + err bool +} + +// NewScanner creates new RFC2616 data scanner. +func NewScanner(data []byte) *Scanner { + return &Scanner{data: data} +} + +// Next scans for next token. It returns true on successful scanning, and false +// on error or EOF. +func (l *Scanner) Next() bool { + c, ok := l.nextChar() + if !ok { + return false + } + switch c { + case '"': // quoted-string; + return l.fetchQuotedString() + + case '(': // comment; + return l.fetchComment() + + case '\\', ')': // unexpected chars; + l.err = true + return false + + default: + return l.fetchToken() + } +} + +// FetchUntil fetches ItemOctet from current scanner position to first +// occurence of the c or to the end of the underlying data. +func (l *Scanner) FetchUntil(c byte) bool { + l.resetItem() + if l.pos == len(l.data) { + return false + } + return l.fetchOctet(c) +} + +// Peek reads byte at current position without advancing it. On end of data it +// returns 0. +func (l *Scanner) Peek() byte { + if l.pos == len(l.data) { + return 0 + } + return l.data[l.pos] +} + +// Peek2 reads two first bytes at current position without advancing it. +// If there not enough data it returs 0. +func (l *Scanner) Peek2() (a, b byte) { + if l.pos == len(l.data) { + return 0, 0 + } + if l.pos+1 == len(l.data) { + return l.data[l.pos], 0 + } + return l.data[l.pos], l.data[l.pos+1] +} + +// Buffered reporst how many bytes there are left to scan. +func (l *Scanner) Buffered() int { + return len(l.data) - l.pos +} + +// Advance moves current position index at n bytes. It returns true on +// successful move. +func (l *Scanner) Advance(n int) bool { + l.pos += n + if l.pos > len(l.data) { + l.pos = len(l.data) + return false + } + return true +} + +// Skip skips all bytes until first occurence of c. +func (l *Scanner) Skip(c byte) { + if l.err { + return + } + // Reset scanner state. + l.resetItem() + + if i := bytes.IndexByte(l.data[l.pos:], c); i == -1 { + // Reached the end of data. 
+ l.pos = len(l.data) + } else { + l.pos += i + 1 + } +} + +// SkipEscaped skips all bytes until first occurence of non-escaped c. +func (l *Scanner) SkipEscaped(c byte) { + if l.err { + return + } + // Reset scanner state. + l.resetItem() + + if i := ScanUntil(l.data[l.pos:], c); i == -1 { + // Reached the end of data. + l.pos = len(l.data) + } else { + l.pos += i + 1 + } +} + +// Type reports current token type. +func (l *Scanner) Type() ItemType { + return l.itemType +} + +// Bytes returns current token bytes. +func (l *Scanner) Bytes() []byte { + return l.itemBytes +} + +func (l *Scanner) nextChar() (byte, bool) { + // Reset scanner state. + l.resetItem() + + if l.err { + return 0, false + } + l.pos += SkipSpace(l.data[l.pos:]) + if l.pos == len(l.data) { + return 0, false + } + return l.data[l.pos], true +} + +func (l *Scanner) resetItem() { + l.itemType = ItemUndef + l.itemBytes = nil +} + +func (l *Scanner) fetchOctet(c byte) bool { + i := l.pos + if j := bytes.IndexByte(l.data[l.pos:], c); j == -1 { + // Reached the end of data. + l.pos = len(l.data) + } else { + l.pos += j + } + + l.itemType = ItemOctet + l.itemBytes = l.data[i:l.pos] + + return true +} + +func (l *Scanner) fetchToken() bool { + n, t := ScanToken(l.data[l.pos:]) + if n == -1 { + l.err = true + return false + } + + l.itemType = t + l.itemBytes = l.data[l.pos : l.pos+n] + l.pos += n + + return true +} + +func (l *Scanner) fetchQuotedString() (ok bool) { + l.pos++ + + n := ScanUntil(l.data[l.pos:], '"') + if n == -1 { + l.err = true + return false + } + + l.itemType = ItemString + l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\') + l.pos += n + 1 + + return true +} + +func (l *Scanner) fetchComment() (ok bool) { + l.pos++ + + n := ScanPairGreedy(l.data[l.pos:], '(', ')') + if n == -1 { + l.err = true + return false + } + + l.itemType = ItemComment + l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\') + l.pos += n + 1 + + return true +} + +// ScanUntil scans for first non-escaped character c in given data. +// It returns index of matched c and -1 if c is not found. +func ScanUntil(data []byte, c byte) (n int) { + for { + i := bytes.IndexByte(data[n:], c) + if i == -1 { + return -1 + } + n += i + if n == 0 || data[n-1] != '\\' { + break + } + n++ + } + return +} + +// ScanPairGreedy scans for complete pair of opening and closing chars in greedy manner. +// Note that first opening byte must not be present in data. +func ScanPairGreedy(data []byte, open, close byte) (n int) { + var m int + opened := 1 + for { + i := bytes.IndexByte(data[n:], close) + if i == -1 { + return -1 + } + n += i + // If found index is not escaped then it is the end. + if n == 0 || data[n-1] != '\\' { + opened-- + } + + for m < i { + j := bytes.IndexByte(data[m:i], open) + if j == -1 { + break + } + m += j + 1 + opened++ + } + + if opened == 0 { + break + } + + n++ + m = n + } + return +} + +// RemoveByte returns data without c. If c is not present in data it returns +// the same slice. If not, it copies data without c. +func RemoveByte(data []byte, c byte) []byte { + j := bytes.IndexByte(data, c) + if j == -1 { + return data + } + + n := len(data) - 1 + + // If character is present, than allocate slice with n-1 capacity. That is, + // resulting bytes could be at most n-1 length. 
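+	// For example, RemoveByte([]byte(`a\"b`), '\\') returns `a"b`.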
+ result := make([]byte, n) + k := copy(result, data[:j]) + + for i := j + 1; i < n; { + j = bytes.IndexByte(data[i:], c) + if j != -1 { + k += copy(result[k:], data[i:i+j]) + i = i + j + 1 + } else { + k += copy(result[k:], data[i:]) + break + } + } + + return result[:k] +} + +// SkipSpace skips spaces and lws-sequences from p. +// It returns number ob bytes skipped. +func SkipSpace(p []byte) (n int) { + for len(p) > 0 { + switch { + case len(p) >= 3 && + p[0] == '\r' && + p[1] == '\n' && + OctetTypes[p[2]].IsSpace(): + p = p[3:] + n += 3 + case OctetTypes[p[0]].IsSpace(): + p = p[1:] + n++ + default: + return + } + } + return +} + +// ScanToken scan for next token in p. It returns length of the token and its +// type. It do not trim p. +func ScanToken(p []byte) (n int, t ItemType) { + if len(p) == 0 { + return 0, ItemUndef + } + + c := p[0] + switch { + case OctetTypes[c].IsSeparator(): + return 1, ItemSeparator + + case OctetTypes[c].IsToken(): + for n = 1; n < len(p); n++ { + c := p[n] + if !OctetTypes[c].IsToken() { + break + } + } + return n, ItemToken + + default: + return -1, ItemUndef + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/octet.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/octet.go new file mode 100644 index 000000000000..2a04cdd0909f --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/octet.go @@ -0,0 +1,83 @@ +package httphead + +// OctetType desribes character type. +// +// From the "Basic Rules" chapter of RFC2616 +// See https://tools.ietf.org/html/rfc2616#section-2.2 +// +// OCTET = +// CHAR = +// UPALPHA = +// LOALPHA = +// ALPHA = UPALPHA | LOALPHA +// DIGIT = +// CTL = +// CR = +// LF = +// SP = +// HT = +// <"> = +// CRLF = CR LF +// LWS = [CRLF] 1*( SP | HT ) +// +// Many HTTP/1.1 header field values consist of words separated by LWS +// or special characters. These special characters MUST be in a quoted +// string to be used within a parameter value (as defined in section +// 3.6). +// +// token = 1* +// separators = "(" | ")" | "<" | ">" | "@" +// | "," | ";" | ":" | "\" | <"> +// | "/" | "[" | "]" | "?" | "=" +// | "{" | "}" | SP | HT +type OctetType byte + +// IsChar reports whether octet is CHAR. +func (t OctetType) IsChar() bool { return t&octetChar != 0 } + +// IsControl reports whether octet is CTL. +func (t OctetType) IsControl() bool { return t&octetControl != 0 } + +// IsSeparator reports whether octet is separator. +func (t OctetType) IsSeparator() bool { return t&octetSeparator != 0 } + +// IsSpace reports whether octet is space (SP or HT). +func (t OctetType) IsSpace() bool { return t&octetSpace != 0 } + +// IsToken reports whether octet is token. +func (t OctetType) IsToken() bool { return t&octetToken != 0 } + +const ( + octetChar OctetType = 1 << iota + octetControl + octetSpace + octetSeparator + octetToken +) + +// OctetTypes is a table of octets. 
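+// Each entry classifies its byte value as CHAR, CTL, SP/HT, separator and/or
+// token character; the table is populated by the init function below.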
+var OctetTypes [256]OctetType + +func init() { + for c := 32; c < 256; c++ { + var t OctetType + if c <= 127 { + t |= octetChar + } + if 0 <= c && c <= 31 || c == 127 { + t |= octetControl + } + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '"', '/', '[', ']', '?', '=', '{', '}', '\\': + t |= octetSeparator + case ' ', '\t': + t |= octetSpace | octetSeparator + } + + if t.IsChar() && !t.IsControl() && !t.IsSeparator() && !t.IsSpace() { + t |= octetToken + } + + OctetTypes[c] = t + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/option.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/option.go new file mode 100644 index 000000000000..243be08c9a03 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/option.go @@ -0,0 +1,187 @@ +package httphead + +import ( + "bytes" + "sort" +) + +// Option represents a header option. +type Option struct { + Name []byte + Parameters Parameters +} + +// Size returns number of bytes need to be allocated for use in opt.Copy. +func (opt Option) Size() int { + return len(opt.Name) + opt.Parameters.bytes +} + +// Copy copies all underlying []byte slices into p and returns new Option. +// Note that p must be at least of opt.Size() length. +func (opt Option) Copy(p []byte) Option { + n := copy(p, opt.Name) + opt.Name = p[:n] + opt.Parameters, p = opt.Parameters.Copy(p[n:]) + return opt +} + +// String represents option as a string. +func (opt Option) String() string { + return "{" + string(opt.Name) + " " + opt.Parameters.String() + "}" +} + +// NewOption creates named option with given parameters. +func NewOption(name string, params map[string]string) Option { + p := Parameters{} + for k, v := range params { + p.Set([]byte(k), []byte(v)) + } + return Option{ + Name: []byte(name), + Parameters: p, + } +} + +// Equal reports whether option is equal to b. +func (opt Option) Equal(b Option) bool { + if bytes.Equal(opt.Name, b.Name) { + return opt.Parameters.Equal(b.Parameters) + } + return false +} + +// Parameters represents option's parameters. +type Parameters struct { + pos int + bytes int + arr [8]pair + dyn []pair +} + +// Equal reports whether a equal to b. +func (p Parameters) Equal(b Parameters) bool { + switch { + case p.dyn == nil && b.dyn == nil: + case p.dyn != nil && b.dyn != nil: + default: + return false + } + + ad, bd := p.data(), b.data() + if len(ad) != len(bd) { + return false + } + + sort.Sort(pairs(ad)) + sort.Sort(pairs(bd)) + + for i := 0; i < len(ad); i++ { + av, bv := ad[i], bd[i] + if !bytes.Equal(av.key, bv.key) || !bytes.Equal(av.value, bv.value) { + return false + } + } + return true +} + +// Size returns number of bytes that needed to copy p. +func (p *Parameters) Size() int { + return p.bytes +} + +// Copy copies all underlying []byte slices into dst and returns new +// Parameters. +// Note that dst must be at least of p.Size() length. +func (p *Parameters) Copy(dst []byte) (Parameters, []byte) { + ret := Parameters{ + pos: p.pos, + bytes: p.bytes, + } + if p.dyn != nil { + ret.dyn = make([]pair, len(p.dyn)) + for i, v := range p.dyn { + ret.dyn[i], dst = v.copy(dst) + } + } else { + for i, p := range p.arr { + ret.arr[i], dst = p.copy(dst) + } + } + return ret, dst +} + +// Get returns value by key and flag about existence such value. 
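+// For example, for an option parsed from "foo;bar=1", Get("bar") returns
+// []byte("1") and true.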
+func (p *Parameters) Get(key string) (value []byte, ok bool) { + for _, v := range p.data() { + if string(v.key) == key { + return v.value, true + } + } + return nil, false +} + +// Set sets value by key. +func (p *Parameters) Set(key, value []byte) { + p.bytes += len(key) + len(value) + + if p.pos < len(p.arr) { + p.arr[p.pos] = pair{key, value} + p.pos++ + return + } + + if p.dyn == nil { + p.dyn = make([]pair, len(p.arr), len(p.arr)+1) + copy(p.dyn, p.arr[:]) + } + p.dyn = append(p.dyn, pair{key, value}) +} + +// ForEach iterates over parameters key-value pairs and calls cb for each one. +func (p *Parameters) ForEach(cb func(k, v []byte) bool) { + for _, v := range p.data() { + if !cb(v.key, v.value) { + break + } + } +} + +// String represents parameters as a string. +func (p *Parameters) String() (ret string) { + ret = "[" + for i, v := range p.data() { + if i > 0 { + ret += " " + } + ret += string(v.key) + ":" + string(v.value) + } + return ret + "]" +} + +func (p *Parameters) data() []pair { + if p.dyn != nil { + return p.dyn + } + return p.arr[:p.pos] +} + +type pair struct { + key, value []byte +} + +func (p pair) copy(dst []byte) (pair, []byte) { + n := copy(dst, p.key) + p.key = dst[:n] + m := n + copy(dst[n:], p.value) + p.value = dst[n:m] + + dst = dst[m:] + + return p, dst +} + +type pairs []pair + +func (p pairs) Len() int { return len(p) } +func (p pairs) Less(a, b int) bool { return bytes.Compare(p[a].key, p[b].key) == -1 } +func (p pairs) Swap(a, b int) { p[a], p[b] = p[b], p[a] } diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/writer.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/writer.go new file mode 100644 index 000000000000..e5df3ddf4046 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/httphead/writer.go @@ -0,0 +1,101 @@ +package httphead + +import "io" + +var ( + comma = []byte{','} + equality = []byte{'='} + semicolon = []byte{';'} + quote = []byte{'"'} + escape = []byte{'\\'} +) + +// WriteOptions write options list to the dest. +// It uses the same form as {Scan,Parse}Options functions: +// values = 1#value +// value = token *( ";" param ) +// param = token [ "=" (token | quoted-string) ] +// +// It wraps valuse into the quoted-string sequence if it contains any +// non-token characters. +func WriteOptions(dest io.Writer, options []Option) (n int, err error) { + w := writer{w: dest} + for i, opt := range options { + if i > 0 { + w.write(comma) + } + + writeTokenSanitized(&w, opt.Name) + + for _, p := range opt.Parameters.data() { + w.write(semicolon) + writeTokenSanitized(&w, p.key) + if len(p.value) != 0 { + w.write(equality) + writeTokenSanitized(&w, p.value) + } + } + } + return w.result() +} + +// writeTokenSanitized writes token as is or as quouted string if it contains +// non-token characters. +// +// Note that is is not expects LWS sequnces be in s, cause LWS is used only as +// header field continuation: +// "A CRLF is allowed in the definition of TEXT only as part of a header field +// continuation. It is expected that the folding LWS will be replaced with a +// single SP before interpretation of the TEXT value." +// See https://tools.ietf.org/html/rfc2616#section-2 +// +// That is we sanitizing s for writing, so there could not be any header field +// continuation. +// That is any CRLF will be escaped as any other control characters not allowd in TEXT. 
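+// For example, the token abc is written as is, while ab,c is written as the
+// quoted-string "ab,c" and a"b is written as "a\"b".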
+func writeTokenSanitized(bw *writer, bts []byte) { + var qt bool + var pos int + for i := 0; i < len(bts); i++ { + c := bts[i] + if !OctetTypes[c].IsToken() && !qt { + qt = true + bw.write(quote) + } + if OctetTypes[c].IsControl() || c == '"' { + if !qt { + qt = true + bw.write(quote) + } + bw.write(bts[pos:i]) + bw.write(escape) + bw.write(bts[i : i+1]) + pos = i + 1 + } + } + if !qt { + bw.write(bts) + } else { + bw.write(bts[pos:]) + bw.write(quote) + } +} + +type writer struct { + w io.Writer + n int + err error +} + +func (w *writer) write(p []byte) { + if w.err != nil { + return + } + var n int + n, w.err = w.w.Write(p) + w.n += n + return +} + +func (w *writer) result() (int, error) { + return w.n, w.err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/README.md b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/README.md new file mode 100644 index 000000000000..45685581daee --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/README.md @@ -0,0 +1,107 @@ +# pool + +[![GoDoc][godoc-image]][godoc-url] + +> Tiny memory reuse helpers for Go. + +## generic + +Without use of subpackages, `pool` allows to reuse any struct distinguishable +by size in generic way: + +```go +package main + +import "github.com/gobwas/pool" + +func main() { + x, n := pool.Get(100) // Returns object with size 128 or nil. + if x == nil { + // Create x somehow with knowledge that n is 128. + } + defer pool.Put(x, n) + + // Work with x. +} +``` + +Pool allows you to pass specific options for constructing custom pool: + +```go +package main + +import "github.com/gobwas/pool" + +func main() { + p := pool.Custom( + pool.WithLogSizeMapping(), // Will ceil size n passed to Get(n) to nearest power of two. + pool.WithLogSizeRange(64, 512), // Will reuse objects in logarithmic range [64, 512]. + pool.WithSize(65536), // Will reuse object with size 65536. + ) + x, n := p.Get(1000) // Returns nil and 1000 because mapped size 1000 => 1024 is not reusing by the pool. + defer pool.Put(x, n) // Will not reuse x. + + // Work with x. +} +``` + +Note that there are few non-generic pooling implementations inside subpackages. + +## pbytes + +Subpackage `pbytes` is intended for `[]byte` reuse. + +```go +package main + +import "github.com/gobwas/pool/pbytes" + +func main() { + bts := pbytes.GetCap(100) // Returns make([]byte, 0, 128). + defer pbytes.Put(bts) + + // Work with bts. +} +``` + +You can also create your own range for pooling: + +```go +package main + +import "github.com/gobwas/pool/pbytes" + +func main() { + // Reuse only slices whose capacity is 128, 256, 512 or 1024. + pool := pbytes.New(128, 1024) + + bts := pool.GetCap(100) // Returns make([]byte, 0, 128). + defer pool.Put(bts) + + // Work with bts. +} +``` + +## pbufio + +Subpackage `pbufio` is intended for `*bufio.{Reader, Writer}` reuse. + +```go +package main + +import "github.com/gobwas/pool/pbufio" + +func main() { + bw := pbufio.GetWriter(os.Stdout, 100) // Returns bufio.NewWriterSize(128). + defer pbufio.PutWriter(bw) + + // Work with bw. +} +``` + +Like with `pbytes`, you can also create pool with custom reuse bounds. 
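+
+The same custom-bounds idea applies to `pbufio`; a minimal sketch using its
+`NewWriterPool` constructor:
+
+```go
+package main
+
+import (
+	"os"
+
+	"github.com/gobwas/pool/pbufio"
+)
+
+func main() {
+	// Reuse only writers whose buffer size is 256, 512 or 1024 bytes.
+	pool := pbufio.NewWriterPool(256, 1024)
+
+	bw := pool.Get(os.Stdout, 200) // Returns a writer with a 256-byte buffer.
+	defer pool.Put(bw)
+
+	// Work with bw.
+	bw.WriteString("hello")
+	bw.Flush()
+}
+```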
+ + + +[godoc-image]: https://godoc.org/github.com/gobwas/pool?status.svg +[godoc-url]: https://godoc.org/github.com/gobwas/pool diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/generic.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/generic.go new file mode 100644 index 000000000000..d40b362458bb --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/generic.go @@ -0,0 +1,87 @@ +package pool + +import ( + "sync" + + "github.com/gobwas/pool/internal/pmath" +) + +var DefaultPool = New(128, 65536) + +// Get pulls object whose generic size is at least of given size. It also +// returns a real size of x for further pass to Put(). It returns -1 as real +// size for nil x. Size >-1 does not mean that x is non-nil, so checks must be +// done. +// +// Note that size could be ceiled to the next power of two. +// +// Get is a wrapper around DefaultPool.Get(). +func Get(size int) (interface{}, int) { return DefaultPool.Get(size) } + +// Put takes x and its size for future reuse. +// Put is a wrapper around DefaultPool.Put(). +func Put(x interface{}, size int) { DefaultPool.Put(x, size) } + +// Pool contains logic of reusing objects distinguishable by size in generic +// way. +type Pool struct { + pool map[int]*sync.Pool + size func(int) int +} + +// New creates new Pool that reuses objects which size is in logarithmic range +// [min, max]. +// +// Note that it is a shortcut for Custom() constructor with Options provided by +// WithLogSizeMapping() and WithLogSizeRange(min, max) calls. +func New(min, max int) *Pool { + return Custom( + WithLogSizeMapping(), + WithLogSizeRange(min, max), + ) +} + +// Custom creates new Pool with given options. +func Custom(opts ...Option) *Pool { + p := &Pool{ + pool: make(map[int]*sync.Pool), + size: pmath.Identity, + } + + c := (*poolConfig)(p) + for _, opt := range opts { + opt(c) + } + + return p +} + +// Get pulls object whose generic size is at least of given size. +// It also returns a real size of x for further pass to Put() even if x is nil. +// Note that size could be ceiled to the next power of two. +func (p *Pool) Get(size int) (interface{}, int) { + n := p.size(size) + if pool := p.pool[n]; pool != nil { + return pool.Get(), n + } + return nil, size +} + +// Put takes x and its size for future reuse. +func (p *Pool) Put(x interface{}, size int) { + if pool := p.pool[size]; pool != nil { + pool.Put(x) + } +} + +type poolConfig Pool + +// AddSize adds size n to the map. +func (p *poolConfig) AddSize(n int) { + p.pool[n] = new(sync.Pool) +} + +// SetSizeMapping sets up incoming size mapping function. +func (p *poolConfig) SetSizeMapping(size func(int) int) { + p.size = size +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/internal/pmath/pmath.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/internal/pmath/pmath.go new file mode 100644 index 000000000000..df152ed12a54 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/internal/pmath/pmath.go @@ -0,0 +1,65 @@ +package pmath + +const ( + bitsize = 32 << (^uint(0) >> 63) + maxint = int(1<<(bitsize-1) - 1) + maxintHeadBit = 1 << (bitsize - 2) +) + +// LogarithmicRange iterates from ceiled to power of two min to max, +// calling cb on each iteration. 
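+// For example, LogarithmicRange(100, 1000, cb) invokes cb with 128, 256 and
+// 512.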
+func LogarithmicRange(min, max int, cb func(int)) { + if min == 0 { + min = 1 + } + for n := CeilToPowerOfTwo(min); n <= max; n <<= 1 { + cb(n) + } +} + +// IsPowerOfTwo reports whether given integer is a power of two. +func IsPowerOfTwo(n int) bool { + return n&(n-1) == 0 +} + +// Identity is identity. +func Identity(n int) int { + return n +} + +// CeilToPowerOfTwo returns the least power of two integer value greater than +// or equal to n. +func CeilToPowerOfTwo(n int) int { + if n&maxintHeadBit != 0 && n > maxintHeadBit { + panic("argument is too large") + } + if n <= 2 { + return n + } + n-- + n = fillBits(n) + n++ + return n +} + +// FloorToPowerOfTwo returns the greatest power of two integer value less than +// or equal to n. +func FloorToPowerOfTwo(n int) int { + if n <= 2 { + return n + } + n = fillBits(n) + n >>= 1 + n++ + return n +} + +func fillBits(n int) int { + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + return n +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/option.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/option.go new file mode 100644 index 000000000000..d6e42b700551 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/option.go @@ -0,0 +1,43 @@ +package pool + +import "github.com/gobwas/pool/internal/pmath" + +// Option configures pool. +type Option func(Config) + +// Config describes generic pool configuration. +type Config interface { + AddSize(n int) + SetSizeMapping(func(int) int) +} + +// WithSizeLogRange returns an Option that will add logarithmic range of +// pooling sizes containing [min, max] values. +func WithLogSizeRange(min, max int) Option { + return func(c Config) { + pmath.LogarithmicRange(min, max, func(n int) { + c.AddSize(n) + }) + } +} + +// WithSize returns an Option that will add given pooling size to the pool. +func WithSize(n int) Option { + return func(c Config) { + c.AddSize(n) + } +} + +func WithSizeMapping(sz func(int) int) Option { + return func(c Config) { + c.SetSizeMapping(sz) + } +} + +func WithLogSizeMapping() Option { + return WithSizeMapping(pmath.CeilToPowerOfTwo) +} + +func WithIdentitySizeMapping() Option { + return WithSizeMapping(pmath.Identity) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio.go new file mode 100644 index 000000000000..d526bd80da85 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio.go @@ -0,0 +1,106 @@ +// Package pbufio contains tools for pooling bufio.Reader and bufio.Writers. +package pbufio + +import ( + "bufio" + "io" + + "github.com/gobwas/pool" +) + +var ( + DefaultWriterPool = NewWriterPool(256, 65536) + DefaultReaderPool = NewReaderPool(256, 65536) +) + +// GetWriter returns bufio.Writer whose buffer has at least size bytes. +// Note that size could be ceiled to the next power of two. +// GetWriter is a wrapper around DefaultWriterPool.Get(). +func GetWriter(w io.Writer, size int) *bufio.Writer { return DefaultWriterPool.Get(w, size) } + +// PutWriter takes bufio.Writer for future reuse. +// It does not reuse bufio.Writer which underlying buffer size is not power of +// PutWriter is a wrapper around DefaultWriterPool.Put(). +func PutWriter(bw *bufio.Writer) { DefaultWriterPool.Put(bw) } + +// GetReader returns bufio.Reader whose buffer has at least size bytes. 
It returns +// its capacity for further pass to Put(). +// Note that size could be ceiled to the next power of two. +// GetReader is a wrapper around DefaultReaderPool.Get(). +func GetReader(w io.Reader, size int) *bufio.Reader { return DefaultReaderPool.Get(w, size) } + +// PutReader takes bufio.Reader and its size for future reuse. +// It does not reuse bufio.Reader if size is not power of two or is out of pool +// min/max range. +// PutReader is a wrapper around DefaultReaderPool.Put(). +func PutReader(bw *bufio.Reader) { DefaultReaderPool.Put(bw) } + +// WriterPool contains logic of *bufio.Writer reuse with various size. +type WriterPool struct { + pool *pool.Pool +} + +// NewWriterPool creates new WriterPool that reuses writers which size is in +// logarithmic range [min, max]. +func NewWriterPool(min, max int) *WriterPool { + return &WriterPool{pool.New(min, max)} +} + +// CustomWriterPool creates new WriterPool with given options. +func CustomWriterPool(opts ...pool.Option) *WriterPool { + return &WriterPool{pool.Custom(opts...)} +} + +// Get returns bufio.Writer whose buffer has at least size bytes. +func (wp *WriterPool) Get(w io.Writer, size int) *bufio.Writer { + v, n := wp.pool.Get(size) + if v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } + return bufio.NewWriterSize(w, n) +} + +// Put takes ownership of bufio.Writer for further reuse. +func (wp *WriterPool) Put(bw *bufio.Writer) { + // Should reset even if we do Reset() inside Get(). + // This is done to prevent locking underlying io.Writer from GC. + bw.Reset(nil) + wp.pool.Put(bw, writerSize(bw)) +} + +// ReaderPool contains logic of *bufio.Reader reuse with various size. +type ReaderPool struct { + pool *pool.Pool +} + +// NewReaderPool creates new ReaderPool that reuses writers which size is in +// logarithmic range [min, max]. +func NewReaderPool(min, max int) *ReaderPool { + return &ReaderPool{pool.New(min, max)} +} + +// CustomReaderPool creates new ReaderPool with given options. +func CustomReaderPool(opts ...pool.Option) *ReaderPool { + return &ReaderPool{pool.Custom(opts...)} +} + +// Get returns bufio.Reader whose buffer has at least size bytes. +func (rp *ReaderPool) Get(r io.Reader, size int) *bufio.Reader { + v, n := rp.pool.Get(size) + if v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br + } + return bufio.NewReaderSize(r, n) +} + +// Put takes ownership of bufio.Reader for further reuse. +func (rp *ReaderPool) Put(br *bufio.Reader) { + // Should reset even if we do Reset() inside Get(). + // This is done to prevent locking underlying io.Reader from GC. 
+ br.Reset(nil) + rp.pool.Put(br, readerSize(br)) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go new file mode 100644 index 000000000000..c736ae56e110 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go @@ -0,0 +1,13 @@ +// +build go1.10 + +package pbufio + +import "bufio" + +func writerSize(bw *bufio.Writer) int { + return bw.Size() +} + +func readerSize(br *bufio.Reader) int { + return br.Size() +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go new file mode 100644 index 000000000000..e71dd447d2ab --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go @@ -0,0 +1,27 @@ +// +build !go1.10 + +package pbufio + +import "bufio" + +func writerSize(bw *bufio.Writer) int { + return bw.Available() + bw.Buffered() +} + +// readerSize returns buffer size of the given buffered reader. +// NOTE: current workaround implementation resets underlying io.Reader. +func readerSize(br *bufio.Reader) int { + br.Reset(sizeReader) + br.ReadByte() + n := br.Buffered() + 1 + br.Reset(nil) + return n +} + +var sizeReader optimisticReader + +type optimisticReader struct{} + +func (optimisticReader) Read(p []byte) (int, error) { + return len(p), nil +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pool.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pool.go new file mode 100644 index 000000000000..1fe9e602fc5d --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/pool/pool.go @@ -0,0 +1,25 @@ +// Package pool contains helpers for pooling structures distinguishable by +// size. +// +// Quick example: +// +// import "github.com/gobwas/pool" +// +// func main() { +// // Reuse objects in logarithmic range from 0 to 64 (0,1,2,4,6,8,16,32,64). +// p := pool.New(0, 64) +// +// buf, n := p.Get(10) // Returns buffer with 16 capacity. +// if buf == nil { +// buf = bytes.NewBuffer(make([]byte, n)) +// } +// defer p.Put(buf, n) +// +// // Work with buf. 
+// } +// +// There are non-generic implementations for pooling: +// - pool/pbytes for []byte reuse; +// - pool/pbufio for *bufio.Reader and *bufio.Writer reuse; +// +package pool diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/.gitignore b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/.gitignore new file mode 100644 index 000000000000..e3e2b1080d07 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/.gitignore @@ -0,0 +1,5 @@ +bin/ +reports/ +cpu.out +mem.out +ws.test diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/.travis.yml b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/.travis.yml new file mode 100644 index 000000000000..cf74f1bee3c5 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/.travis.yml @@ -0,0 +1,25 @@ +sudo: required + +language: go + +services: + - docker + +os: + - linux + - windows + +go: + - 1.8.x + - 1.9.x + - 1.10.x + - 1.11.x + - 1.x + +install: + - go get github.com/gobwas/pool + - go get github.com/gobwas/httphead + +script: + - if [ "$TRAVIS_OS_NAME" = "windows" ]; then go test ./...; fi + - if [ "$TRAVIS_OS_NAME" = "linux" ]; then make test autobahn; fi diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/Makefile b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/Makefile new file mode 100644 index 000000000000..075e83c74bc0 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/Makefile @@ -0,0 +1,47 @@ +BENCH ?=. +BENCH_BASE?=master + +clean: + rm -f bin/reporter + rm -fr autobahn/report/* + +bin/reporter: + go build -o bin/reporter ./autobahn + +bin/gocovmerge: + go build -o bin/gocovmerge github.com/wadey/gocovmerge + +.PHONY: autobahn +autobahn: clean bin/reporter + ./autobahn/script/test.sh --build + bin/reporter $(PWD)/autobahn/report/index.json + +test: + go test -coverprofile=ws.coverage . + go test -coverprofile=wsutil.coverage ./wsutil + +cover: bin/gocovmerge test autobahn + bin/gocovmerge ws.coverage wsutil.coverage autobahn/report/server.coverage > total.coverage + +benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD) +benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX) +benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX) +benchcmp: + if [ ! -z "$(shell git status -s)" ]; then\ + echo "could not compare with $(BENCH_BASE) – found unstaged changes";\ + exit 1;\ + fi;\ + if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\ + echo "comparing the same branches";\ + exit 1;\ + fi;\ + echo "benchmarking $(BENCH_BRANCH)...";\ + go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\ + echo "benchmarking $(BENCH_BASE)...";\ + git checkout -q $(BENCH_BASE);\ + go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\ + git checkout -q $(BENCH_BRANCH);\ + echo "\nresults:";\ + echo "========\n";\ + benchcmp $(BENCH_OLD) $(BENCH_NEW);\ + diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/README.md b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/README.md new file mode 100644 index 000000000000..74acd78bd08f --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/README.md @@ -0,0 +1,360 @@ +# ws + +[![GoDoc][godoc-image]][godoc-url] +[![Travis][travis-image]][travis-url] + +> [RFC6455][rfc-url] WebSocket implementation in Go. 
+
+# Features
+
+- Zero-copy upgrade
+- No intermediate allocations during I/O
+- Low-level API which allows you to build your own logic of packet handling
+  and buffer reuse
+- High-level wrappers and helpers around the API in the `wsutil` package,
+  which let you start fast without digging into the protocol internals
+
+# Documentation
+
+[GoDoc][godoc-url].
+
+# Why
+
+Existing WebSocket implementations do not allow users to reuse I/O buffers
+between connections in a clear way. This library aims to export an efficient
+low-level interface for working with the protocol without forcing a single
+way it must be used.
+
+If you want higher-level tools, you can use the `wsutil` package.
+
+# Status
+
+The library is tagged as `v1*`, so its API will not be broken by improvements
+or refactoring.
+
+This implementation of RFC6455 passes [Autobahn Test
+Suite](https://github.com/crossbario/autobahn-testsuite) and currently has
+about 78% coverage.
+
+# Examples
+
+Example applications using `ws` are developed in a separate repository
+[ws-examples](https://github.com/gobwas/ws-examples).
+
+# Usage
+
+A higher-level example of a WebSocket echo server:
+
+```go
+package main
+
+import (
+	"net/http"
+
+	"github.com/gobwas/ws"
+	"github.com/gobwas/ws/wsutil"
+)
+
+func main() {
+	http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, _, _, err := ws.UpgradeHTTP(r, w)
+		if err != nil {
+			// handle error
+		}
+		go func() {
+			defer conn.Close()
+
+			for {
+				msg, op, err := wsutil.ReadClientData(conn)
+				if err != nil {
+					// handle error
+				}
+				err = wsutil.WriteServerMessage(conn, op, msg)
+				if err != nil {
+					// handle error
+				}
+			}
+		}()
+	}))
+}
+```
+
+A lower-level, but still high-level example:
+
+```go
+package main
+
+import (
+	"io"
+	"net/http"
+
+	"github.com/gobwas/ws"
+	"github.com/gobwas/ws/wsutil"
+)
+
+func main() {
+	http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, _, _, err := ws.UpgradeHTTP(r, w)
+		if err != nil {
+			// handle error
+		}
+		go func() {
+			defer conn.Close()
+
+			var (
+				state  = ws.StateServerSide
+				reader = wsutil.NewReader(conn, state)
+				writer = wsutil.NewWriter(conn, state, ws.OpText)
+			)
+			for {
+				header, err := reader.NextFrame()
+				if err != nil {
+					// handle error
+				}
+
+				// Reset the writer to write a frame with the right operation code.
+				writer.Reset(conn, state, header.OpCode)
+
+				if _, err = io.Copy(writer, reader); err != nil {
+					// handle error
+				}
+				if err = writer.Flush(); err != nil {
+					// handle error
+				}
+			}
+		}()
+	}))
+}
+```
+
+We can apply the same pattern to read and write structured responses through
+a JSON encoder and decoder:
+
+```go
+	...
+	var (
+		r       = wsutil.NewReader(conn, ws.StateServerSide)
+		w       = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText)
+		decoder = json.NewDecoder(r)
+		encoder = json.NewEncoder(w)
+	)
+	for {
+		hdr, err := r.NextFrame()
+		if err != nil {
+			return err
+		}
+		if hdr.OpCode == ws.OpClose {
+			return io.EOF
+		}
+		var req Request
+		if err := decoder.Decode(&req); err != nil {
+			return err
+		}
+		var resp Response
+		if err := encoder.Encode(&resp); err != nil {
+			return err
+		}
+		if err = w.Flush(); err != nil {
+			return err
+		}
+	}
+	...
+```
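+
+The examples above are all servers. For completeness, here is a minimal
+client sketch (not part of the original examples; it assumes one of the echo
+servers above is listening on localhost:8080) that uses `ws.Dial` together
+with `wsutil`:
+
+```go
+package main
+
+import (
+	"context"
+	"log"
+
+	"github.com/gobwas/ws"
+	"github.com/gobwas/ws/wsutil"
+)
+
+func main() {
+	// Dial performs the TCP connect and the WebSocket handshake in one call.
+	conn, _, _, err := ws.Dial(context.Background(), "ws://localhost:8080")
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer conn.Close()
+
+	// Client frames must be masked; wsutil handles that for us.
+	if err := wsutil.WriteClientMessage(conn, ws.OpText, []byte("hello")); err != nil {
+		log.Fatal(err)
+	}
+	msg, _, err := wsutil.ReadServerData(conn)
+	if err != nil {
+		log.Fatal(err)
+	}
+	log.Printf("received: %s", msg)
+}
+```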
+
+The lower-level example without `wsutil`:
+
+```go
+package main
+
+import (
+	"io"
+	"log"
+	"net"
+
+	"github.com/gobwas/ws"
+)
+
+func main() {
+	ln, err := net.Listen("tcp", "localhost:8080")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			// handle error
+		}
+		_, err = ws.Upgrade(conn)
+		if err != nil {
+			// handle error
+		}
+
+		go func() {
+			defer conn.Close()
+
+			for {
+				header, err := ws.ReadHeader(conn)
+				if err != nil {
+					// handle error
+				}
+
+				payload := make([]byte, header.Length)
+				_, err = io.ReadFull(conn, payload)
+				if err != nil {
+					// handle error
+				}
+				if header.Masked {
+					ws.Cipher(payload, header.Mask, 0)
+				}
+
+				// Reset the Masked flag; server frames must not be masked,
+				// as RFC6455 says.
+				header.Masked = false
+
+				if err := ws.WriteHeader(conn, header); err != nil {
+					// handle error
+				}
+				if _, err := conn.Write(payload); err != nil {
+					// handle error
+				}
+
+				if header.OpCode == ws.OpClose {
+					return
+				}
+			}
+		}()
+	}
+}
+```
+
+# Zero-copy upgrade
+
+Zero-copy upgrade helps to avoid unnecessary allocations and copying while
+handling the HTTP Upgrade request.
+
+Processing of all non-websocket headers is done in place using registered
+user callbacks whose arguments are only valid until the callback returns.
+
+A simple example looks like this:
+
+```go
+package main
+
+import (
+	"log"
+	"net"
+
+	"github.com/gobwas/ws"
+)
+
+func main() {
+	ln, err := net.Listen("tcp", "localhost:8080")
+	if err != nil {
+		log.Fatal(err)
+	}
+	u := ws.Upgrader{
+		OnHeader: func(key, value []byte) (err error) {
+			log.Printf("non-websocket header: %q=%q", key, value)
+			return
+		},
+	}
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			// handle error
+		}
+
+		_, err = u.Upgrade(conn)
+		if err != nil {
+			// handle error
+		}
+	}
+}
+```
+
+Using `ws.Upgrader` here makes it possible to control incoming connections at
+the TCP level and simply not accept them based on some logic.
+
+Zero-copy upgrade is for high-load services which have to control many
+resources such as connection buffers.
+
+A real-life example could look like this:
+
+```go
+package main
+
+import (
+	"log"
+	"net"
+	"net/http"
+	"runtime"
+
+	"github.com/gobwas/httphead"
+	"github.com/gobwas/ws"
+)
+
+func main() {
+	ln, err := net.Listen("tcp", "localhost:8080")
+	if err != nil {
+		// handle error
+	}
+
+	// Prepare handshake header writer from http.Header mapping.
+	header := ws.HandshakeHeaderHTTP(http.Header{
+		"X-Go-Version": []string{runtime.Version()},
+	})
+
+	u := ws.Upgrader{
+		OnHost: func(host []byte) error {
+			if string(host) == "github.com" {
+				return nil
+			}
+			return ws.RejectConnectionError(
+				ws.RejectionStatus(403),
+				ws.RejectionHeader(ws.HandshakeHeaderString(
+					"X-Want-Host: github.com\r\n",
+				)),
+			)
+		},
+		OnHeader: func(key, value []byte) error {
+			if string(key) != "Cookie" {
+				return nil
+			}
+			ok := httphead.ScanCookie(value, func(key, value []byte) bool {
+				// Check session here or do some other stuff with cookies.
+				// Maybe copy some values for future use.
+				return true
+			})
+			if ok {
+				return nil
+			}
+			return ws.RejectConnectionError(
+				ws.RejectionReason("bad cookie"),
+				ws.RejectionStatus(400),
+			)
+		},
+		OnBeforeUpgrade: func() (ws.HandshakeHeader, error) {
+			return header, nil
+		},
+	}
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			log.Fatal(err)
+		}
+		_, err = u.Upgrade(conn)
+		if err != nil {
+			log.Printf("upgrade error: %s", err)
+		}
+	}
+}
+```
+
+
+[rfc-url]: https://tools.ietf.org/html/rfc6455
+[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg
+[godoc-url]: https://godoc.org/github.com/gobwas/ws
+[travis-image]: https://travis-ci.org/gobwas/ws.svg?branch=master
+[travis-url]: https://travis-ci.org/gobwas/ws
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/check.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/check.go
new file mode 100644
index 000000000000..8aa0df8cc28f
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/check.go
@@ -0,0 +1,145 @@
+package ws
+
+import "unicode/utf8"
+
+// State represents the state of a websocket endpoint.
+// It is used by some functions to be more strict when checking compatibility
+// with RFC6455.
+type State uint8
+
+const (
+	// StateServerSide means that the endpoint (caller) is a server.
+	StateServerSide State = 0x1 << iota
+	// StateClientSide means that the endpoint (caller) is a client.
+	StateClientSide
+	// StateExtended means that an extension was negotiated during the
+	// handshake.
+	StateExtended
+	// StateFragmented means that the endpoint (caller) has received a
+	// fragmented frame and waits for continuation parts.
+	StateFragmented
+)
+
+// Is checks whether s has v enabled.
+func (s State) Is(v State) bool {
+	return uint8(s)&uint8(v) != 0
+}
+
+// Set enables the v state on s.
+func (s State) Set(v State) State {
+	return s | v
+}
+
+// Clear disables the v state on s.
+func (s State) Clear(v State) State {
+	return s & (^v)
+}
+
+// ServerSide reports whether the state represents the server side.
+func (s State) ServerSide() bool { return s.Is(StateServerSide) }
+
+// ClientSide reports whether the state represents the client side.
+func (s State) ClientSide() bool { return s.Is(StateClientSide) }
+
+// Extended reports whether the state is extended.
+func (s State) Extended() bool { return s.Is(StateExtended) }
+
+// Fragmented reports whether the state is fragmented.
+func (s State) Fragmented() bool { return s.Is(StateFragmented) }
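+
+// For example (an illustrative sketch, not part of the original source), a
+// server-side endpoint that negotiated an extension and then received the
+// first fragment of a message could track this as:
+//
+//	s := StateServerSide | StateExtended
+//	s = s.Set(StateFragmented)   // first fragment received
+//	_ = s.Is(StateFragmented)    // true until the final fragment arrives
+//	s = s.Clear(StateFragmented) // final fragment processed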
+
+// ProtocolError describes an error during the checking/parsing of websocket
+// frames or headers.
+type ProtocolError string
+
+// Error implements the error interface.
+func (p ProtocolError) Error() string { return string(p) }
+
+// Errors used by the protocol checkers.
+var (
+	ErrProtocolOpCodeReserved             = ProtocolError("use of reserved op code")
+	ErrProtocolControlPayloadOverflow     = ProtocolError("control frame payload limit exceeded")
+	ErrProtocolControlNotFinal            = ProtocolError("control frame is not final")
+	ErrProtocolNonZeroRsv                 = ProtocolError("non-zero rsv bits with no extension negotiated")
+	ErrProtocolMaskRequired               = ProtocolError("frames from client to server must be masked")
+	ErrProtocolMaskUnexpected             = ProtocolError("frames from server to client must not be masked")
+	ErrProtocolContinuationExpected       = ProtocolError("unexpected non-continuation data frame")
+	ErrProtocolContinuationUnexpected     = ProtocolError("unexpected continuation data frame")
+	ErrProtocolStatusCodeNotInUse         = ProtocolError("status code is not in use")
+	ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
+	ErrProtocolStatusCodeNoMeaning        = ProtocolError("status code has no meaning yet")
+	ErrProtocolStatusCodeUnknown          = ProtocolError("status code is not defined in spec")
+	ErrProtocolInvalidUTF8                = ProtocolError("invalid utf8 sequence in close reason")
+)
+
+// CheckHeader checks h to contain valid header data for the given state s.
+//
+// Note that the zero state (0) means that the state is clean:
+// neither server nor client side, nor fragmented, nor extended.
+func CheckHeader(h Header, s State) error {
+	if h.OpCode.IsReserved() {
+		return ErrProtocolOpCodeReserved
+	}
+	if h.OpCode.IsControl() {
+		if h.Length > MaxControlFramePayloadSize {
+			return ErrProtocolControlPayloadOverflow
+		}
+		if !h.Fin {
+			return ErrProtocolControlNotFinal
+		}
+	}
+
+	switch {
+	// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
+	// non-zero values. If a nonzero value is received and none of the
+	// negotiated extensions defines the meaning of such a nonzero value, the
+	// receiving endpoint MUST _Fail the WebSocket Connection_.
+	case h.Rsv != 0 && !s.Extended():
+		return ErrProtocolNonZeroRsv
+
+	// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
+	// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
+	// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
+	// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
+	// status code 1002 (protocol error) as defined in Section 7.4.1.
+	case s.ServerSide() && !h.Masked:
+		return ErrProtocolMaskRequired
+	case s.ClientSide() && h.Masked:
+		return ErrProtocolMaskUnexpected
+
+	// [RFC6455]: See the detailed explanation in section 5.4.
+	case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
+		return ErrProtocolContinuationExpected
+	case !s.Fragmented() && h.OpCode == OpContinuation:
+		return ErrProtocolContinuationUnexpected
+
+	default:
+		return nil
+	}
+}
+
+// CheckCloseFrameData checks the received close information
+// to be valid RFC6455-compatible close info.
+//
+// Note that code.Empty() or code.IsAppLevel() will raise an error.
+//
+// If an endpoint sends a close frame without a status code (with
+// frame.Length = 0), the application should not check its payload.
+func CheckCloseFrameData(code StatusCode, reason string) error {
+	switch {
+	case code.IsNotUsed():
+		return ErrProtocolStatusCodeNotInUse
+
+	case code.IsProtocolReserved():
+		return ErrProtocolStatusCodeApplicationLevel
+
+	case code == StatusNoMeaningYet:
+		return ErrProtocolStatusCodeNoMeaning
+
+	case code.IsProtocolSpec() && !code.IsProtocolDefined():
+		return ErrProtocolStatusCodeUnknown
+
+	case !utf8.ValidString(reason):
+		return ErrProtocolInvalidUTF8
+
+	default:
+		return nil
+	}
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/cipher.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/cipher.go
new file mode 100644
index 000000000000..11a2af99bfc4
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/cipher.go
@@ -0,0 +1,59 @@
+package ws
+
+import (
+	"encoding/binary"
+	"unsafe"
+)
+
+// Cipher applies an XOR cipher to the payload using mask.
+// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
+//
+// To convert masked data into unmasked data, or vice versa, the following
+// algorithm is applied. The same algorithm applies regardless of the
+// direction of the translation, e.g., the same steps are applied to
+// mask the data as to unmask the data.
+func Cipher(payload []byte, mask [4]byte, offset int) {
+	n := len(payload)
+	if n < 8 {
+		for i := 0; i < n; i++ {
+			payload[i] ^= mask[(offset+i)%4]
+		}
+		return
+	}
+
+	// Calculate the position in the mask based on the number of previously
+	// processed bytes.
+	mpos := offset % 4
+	// Count the number of bytes to be processed one by one from the beginning
+	// of the payload.
+	ln := remain[mpos]
+	// Count the number of bytes to be processed one by one from the end of
+	// the payload. This is done so the main loop can process the payload
+	// 8 bytes per iteration.
+	rn := (n - ln) % 8
+
+	for i := 0; i < ln; i++ {
+		payload[i] ^= mask[(mpos+i)%4]
+	}
+	for i := n - rn; i < n; i++ {
+		payload[i] ^= mask[(mpos+i)%4]
+	}
+
+	// We cast mask to uint32 with unsafe instead of encoding.BigEndian to
+	// avoid caring about OS-dependent byte order. That is, on any endianness
+	// the mask and the payload will be presented in the same order. In other
+	// words, we could not use encoding.BigEndian when XORing the payload as
+	// uint64.
+	m := *(*uint32)(unsafe.Pointer(&mask))
+	m2 := uint64(m)<<32 | uint64(m)
+
+	// Skip the already processed edge parts.
+	// Get the number of uint64 chunks remaining to process.
+	n = (n - ln - rn) >> 3
+	for i := 0; i < n; i++ {
+		idx := ln + (i << 3)
+		p := binary.LittleEndian.Uint64(payload[idx : idx+8])
+		p = p ^ m2
+		binary.LittleEndian.PutUint64(payload[idx:idx+8], p)
+	}
+}
+
+// remain maps a position in the masking key [0,4) to the number
+// of bytes that need to be processed manually inside Cipher().
+var remain = [4]int{0, 3, 2, 1}
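+
+// For example (an illustrative sketch, not part of the original source), a
+// payload received in two chunks can be unmasked with the same mask by
+// passing the number of already processed bytes as the offset:
+//
+//	Cipher(part1, mask, 0)
+//	Cipher(part2, mask, len(part1))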
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer.go
new file mode 100644
index 000000000000..4357be2142b9
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer.go
@@ -0,0 +1,556 @@
+package ws
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"crypto/tls"
+	"fmt"
+	"io"
+	"net"
+	"net/url"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gobwas/httphead"
+	"github.com/gobwas/pool/pbufio"
+)
+
+// Constants used by Dialer.
+const (
+	DefaultClientReadBufferSize  = 4096
+	DefaultClientWriteBufferSize = 4096
+)
+
+// Handshake represents the handshake result.
+type Handshake struct {
+	// Protocol is the subprotocol selected during the handshake.
+	Protocol string
+
+	// Extensions is the list of negotiated extensions.
+	Extensions []httphead.Option
+}
+
+// Errors used by the websocket client.
+var (
+	ErrHandshakeBadStatus      = fmt.Errorf("unexpected http status")
+	ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
+	ErrHandshakeBadExtensions  = fmt.Errorf("unexpected extensions in %q header", headerSecExtensions)
+)
+
+// DefaultDialer is a dialer that holds no options and is used by the Dial
+// function.
+var DefaultDialer Dialer
+
+// Dial is like Dialer{}.Dial().
+func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
+	return DefaultDialer.Dial(ctx, urlstr)
+}
+
+// Dialer contains options for establishing a websocket connection to a URL.
+type Dialer struct {
+	// ReadBufferSize and WriteBufferSize are I/O buffer sizes.
+	// They are used to read and write HTTP data while upgrading to WebSocket.
+	// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
+	//
+	// If a size is zero then the default value is used.
+	ReadBufferSize, WriteBufferSize int
+
+	// Timeout is the maximum amount of time a Dial() will wait for a connect
+	// and a handshake to complete.
+	//
+	// The default is no timeout.
+	Timeout time.Duration
+
+	// Protocols is the list of subprotocols that the client wants to speak,
+	// ordered by preference.
+	//
+	// See https://tools.ietf.org/html/rfc6455#section-4.1
+	Protocols []string
+
+	// Extensions is the list of extensions that the client wants to speak.
+	//
+	// Note that if the server decides to use some of these extensions, Dial()
+	// will return a Handshake struct containing a slice of items which are
+	// shallow copies of the items from this list. That is, the internals of
+	// Extensions items are shared during Dial().
+	//
+	// See https://tools.ietf.org/html/rfc6455#section-4.1
+	// See https://tools.ietf.org/html/rfc6455#section-9.1
+	Extensions []httphead.Option
+
+	// Header is an optional HandshakeHeader instance that could be used to
+	// write additional headers to the handshake request.
+	//
+	// It is used instead of any key-value mappings to avoid allocations in
+	// user land.
+	Header HandshakeHeader
+
+	// OnStatusError is the callback that will be called after receiving a
+	// non-"101 Switching Protocols" HTTP response status. It receives an
+	// io.Reader object representing the server response bytes. That is, it
+	// gives the ability to parse the HTTP response somehow (probably with an
+	// http.ReadResponse call) and decide on further logic.
+	//
+	// The arguments are only valid until the callback returns.
+	OnStatusError func(status int, reason []byte, resp io.Reader)
+
+	// OnHeader is the callback that will be called after successful parsing
+	// of a header that is not used during the WebSocket handshake procedure.
+	// That is, it will be called with non-websocket headers, which could be
+	// relevant for application-level logic.
+	//
+	// The arguments are only valid until the callback returns.
+	//
+	// The returned value could be used to stop processing the response.
+	OnHeader func(key, value []byte) (err error)
+
+	// NetDial is the function that is used to get a plain TCP connection.
+	// If it is not nil, then it is used instead of net.Dialer.
+	NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
+
+	// TLSClient is the callback that will be called after a successful dial
+	// with the received connection and its remote host name. If it is nil,
+	// then the default tls.Client() will be used.
+	// If it is not nil, then the TLSConfig field is ignored.
+	TLSClient func(conn net.Conn, hostname string) net.Conn
+
+	// TLSConfig is passed to tls.Client() to start TLS over the established
+	// connection. If TLSClient is not nil, then it is ignored. If TLSConfig
+	// is non-nil and its ServerName is empty, then for every Dial() it will
+	// be cloned and an appropriate ServerName will be set.
+	TLSConfig *tls.Config
+
+	// WrapConn is the optional callback that will be called when the
+	// connection is ready for I/O. That is, it will be called after a
+	// successful dial and TLS initialization (for "wss" schemes). It may be
+	// helpful for different user-land purposes such as end-to-end encryption.
+	//
+	// Note that for debugging an HTTP handshake (e.g. the sent request and
+	// received response), there is a wsutil.DebugDialer struct.
+	WrapConn func(conn net.Conn) net.Conn
+}
+
+// Dial connects to the URL host and upgrades the connection to WebSocket.
+//
+// If the server has sent frames right after the successful handshake, then
+// the returned buffer will be non-nil. In other cases the buffer is always
+// nil. For better memory efficiency the received non-nil bufio.Reader should
+// be returned to the inner pool with the PutReader() function after use.
+//
+// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
+// If you want to dial a non-ASCII host name, take care of its name
+// serialization to avoid bad request issues. For more info see the net/http
+// Request.Write() implementation, especially the cleanHost() function.
+func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
+	u, err := url.ParseRequestURI(urlstr)
+	if err != nil {
+		return
+	}
+
+	// Prepare the context to dial with. Initially it is the same as the
+	// original, but if d.Timeout is non-zero and points to a time before
+	// ctx.Deadline, we use a shorter context for the dial.
+	dialctx := ctx
+
+	var deadline time.Time
+	if t := d.Timeout; t != 0 {
+		deadline = time.Now().Add(t)
+		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+			var cancel context.CancelFunc
+			dialctx, cancel = context.WithDeadline(ctx, deadline)
+			defer cancel()
+		}
+	}
+	if conn, err = d.dial(dialctx, u); err != nil {
+		return
+	}
+	defer func() {
+		if err != nil {
+			conn.Close()
+		}
+	}()
+	if ctx == context.Background() {
+		// No need to start the I/O interrupter goroutine, which is not
+		// zero-cost.
+		conn.SetDeadline(deadline)
+		defer conn.SetDeadline(noDeadline)
+	} else {
+		// The context could be canceled or its deadline could be exceeded.
+		// Start the interrupter goroutine to handle context cancelation.
+		done := setupContextDeadliner(ctx, conn)
+		defer func() {
+			// Map the Upgrade() error to a possible context expiration
+			// error. That is, even if the Upgrade() err is nil, the context
+			// could be already expired and the connection be "poisoned" by
+			// the SetDeadline() call. In that case we must return the
+			// ctx.Err() error.
+			done(&err)
+		}()
+	}
+
+	br, hs, err = d.Upgrade(conn, u)
+
+	return
+}
+
+var (
+	// netEmptyDialer is a net.Dialer without options, used in Dialer.dial()
+	// if Dialer.NetDial is not provided.
+	netEmptyDialer net.Dialer
+	// tlsEmptyConfig is an empty tls.Config used as the default one.
+	tlsEmptyConfig tls.Config
+)
+
+func tlsDefaultConfig() *tls.Config {
+	return &tlsEmptyConfig
+}
+
+func hostport(host string, defaultPort string) (hostname, addr string) {
+	var (
+		colon   = strings.LastIndexByte(host, ':')
+		bracket = strings.IndexByte(host, ']')
+	)
+	if colon > bracket {
+		return host[:colon], host
+	}
+	return host, host + defaultPort
+}
+
+func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
+	dial := d.NetDial
+	if dial == nil {
+		dial = netEmptyDialer.DialContext
+	}
+	switch u.Scheme {
+	case "ws":
+		_, addr := hostport(u.Host, ":80")
+		conn, err = dial(ctx, "tcp", addr)
+	case "wss":
+		hostname, addr := hostport(u.Host, ":443")
+		conn, err = dial(ctx, "tcp", addr)
+		if err != nil {
+			return
+		}
+		tlsClient := d.TLSClient
+		if tlsClient == nil {
+			tlsClient = d.tlsClient
+		}
+		conn = tlsClient(conn, hostname)
+	default:
+		return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
+	}
+	if wrap := d.WrapConn; wrap != nil {
+		conn = wrap(conn)
+	}
+	return
+}
+
+func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
+	config := d.TLSConfig
+	if config == nil {
+		config = tlsDefaultConfig()
+	}
+	if config.ServerName == "" {
+		config = tlsCloneConfig(config)
+		config.ServerName = hostname
+	}
+	// Do not call conn.Handshake() here because below we will prepare I/O on
+	// this conn with proper context timeout handling.
+	return tls.Client(conn, config)
+}
+
+var (
+	// These variables are set like in net/net.go.
+	// noDeadline is just the zero value, named for readability.
+	noDeadline = time.Time{}
+	// aLongTimeAgo is a non-zero time, far in the past, used for immediate
+	// cancelation of dials.
+	aLongTimeAgo = time.Unix(42, 0)
+)
+
+// Upgrade writes an upgrade request to the given io.ReadWriter conn at the
+// given url u and reads a response from it.
+//
+// It is the caller's responsibility to manage I/O deadlines on conn.
+//
+// It returns handshake info and any bytes which were written by the peer
+// right after the response and caught by us during the buffered read.
+func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
+	// headerSeen constants help to report whether or not some header was
+	// seen while reading the response bytes.
+	const (
+		headerSeenUpgrade = 1 << iota
+		headerSeenConnection
+		headerSeenSecAccept
+
+		// headerSeenAll is the value that we expect to receive at the end of
+		// the headers read/parse loop.
+		headerSeenAll = 0 |
+			headerSeenUpgrade |
+			headerSeenConnection |
+			headerSeenSecAccept
+	)
+
+	br = pbufio.GetReader(conn,
+		nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
+	)
+	bw := pbufio.GetWriter(conn,
+		nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
+	)
+	defer func() {
+		pbufio.PutWriter(bw)
+		if br.Buffered() == 0 || err != nil {
+			// The server did not write additional bytes to the connection,
+			// or an error occurred. That is, there is no reason to return
+			// the buffer.
+			pbufio.PutReader(br)
+			br = nil
+		}
+	}()
+
+	nonce := make([]byte, nonceSize)
+	initNonce(nonce)
+
+	httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
+	if err = bw.Flush(); err != nil {
+		return
+	}
+
+	// Read the HTTP status line like "HTTP/1.1 101 Switching Protocols".
+	sl, err := readLine(br)
+	if err != nil {
+		return
+	}
+	// Begin validation of the response.
+	// See https://tools.ietf.org/html/rfc6455#section-4.2.2
+	// Parse the response line data such as HTTP version and status.
+	resp, err := httpParseResponseLine(sl)
+	if err != nil {
+		return
+	}
+	// Even though the RFC says "1.1 or higher" without mentioning which part
+	// of the version, we apply it only to the minor part.
+	if resp.major != 1 || resp.minor < 1 {
+		err = ErrHandshakeBadProtocol
+		return
+	}
+	if resp.status != 101 {
+		err = StatusError(resp.status)
+		if onStatusError := d.OnStatusError; onStatusError != nil {
+			// Invoke the callback with a MultiReader of the status-line
+			// bytes and br.
+			onStatusError(resp.status, resp.reason,
+				io.MultiReader(
+					bytes.NewReader(sl),
+					strings.NewReader(crlf),
+					br,
+				),
+			)
+		}
+		return
+	}
+	// If the response status is 101 then we expect all technical headers to
+	// be valid. If not, then we stop processing the response without giving
+	// the user the ability to read non-technical headers. That is, we do not
+	// distinguish between technical errors (such as parsing errors) and
+	// protocol errors.
+	var headerSeen byte
+	for {
+		line, e := readLine(br)
+		if e != nil {
+			err = e
+			return
+		}
+		if len(line) == 0 {
+			// Blank line, no more lines to read.
+			break
+		}
+
+		k, v, ok := httpParseHeaderLine(line)
+		if !ok {
+			err = ErrMalformedResponse
+			return
+		}
+
+		switch btsToString(k) {
+		case headerUpgradeCanonical:
+			headerSeen |= headerSeenUpgrade
+			if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
+				err = ErrHandshakeBadUpgrade
+				return
+			}
+
+		case headerConnectionCanonical:
+			headerSeen |= headerSeenConnection
+			// Note that as RFC6455 says:
+			// > A |Connection| header field with value "Upgrade".
+			// That is, on the server side the "Connection" header could
+			// contain multiple tokens. But in a response it must contain
+			// exactly one.
+			if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
+				err = ErrHandshakeBadConnection
+				return
+			}
+
+		case headerSecAcceptCanonical:
+			headerSeen |= headerSeenSecAccept
+			if !checkAcceptFromNonce(v, nonce) {
+				err = ErrHandshakeBadSecAccept
+				return
+			}
+
+		case headerSecProtocolCanonical:
+			// RFC6455 1.3:
+			// "The server selects one or none of the acceptable protocols
+			// and echoes that value in its handshake to indicate that it has
+			// selected that protocol."
+			for _, want := range d.Protocols {
+				if string(v) == want {
+					hs.Protocol = want
+					break
+				}
+			}
+			if hs.Protocol == "" {
+				// The server echoed a subprotocol that is not present in
+				// the client-requested protocols.
+				err = ErrHandshakeBadSubProtocol
+				return
+			}
+
+		case headerSecExtensionsCanonical:
+			hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
+			if err != nil {
+				return
+			}
+
+		default:
+			if onHeader := d.OnHeader; onHeader != nil {
+				if e := onHeader(k, v); e != nil {
+					err = e
+					return
+				}
+			}
+		}
+	}
+	if err == nil && headerSeen != headerSeenAll {
+		switch {
+		case headerSeen&headerSeenUpgrade == 0:
+			err = ErrHandshakeBadUpgrade
+		case headerSeen&headerSeenConnection == 0:
+			err = ErrHandshakeBadConnection
+		case headerSeen&headerSeenSecAccept == 0:
+			err = ErrHandshakeBadSecAccept
+		default:
+			panic("unknown headers state")
+		}
+	}
+	return
+}
+
+// PutReader returns a bufio.Reader instance to the inner reuse pool.
+// It is useful in rare cases when Dialer.Dial() returns a non-nil buffer
+// which contains unprocessed buffered data that was sent by the server
+// right after the handshake.
+func PutReader(br *bufio.Reader) {
+	pbufio.PutReader(br)
+}
+
+// StatusError contains an unexpected status-line code from the server.
+type StatusError int
+
+func (s StatusError) Error() string {
+	return "unexpected HTTP response status: " + strconv.Itoa(int(s))
+}
+
+func isTimeoutError(err error) bool {
+	t, ok := err.(net.Error)
+	return ok && t.Timeout()
+}
+
+func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
+	if len(selected) == 0 {
+		return received, nil
+	}
+	var (
+		index  int
+		option httphead.Option
+		err    error
+	)
+	index = -1
+	match := func() (ok bool) {
+		for _, want := range wanted {
+			if option.Equal(want) {
+				// Check that the parsed extension is present in the
+				// client-requested extensions. We move the matched extension
+				// from the client list to avoid allocation.
+				received = append(received, want)
+				return true
+			}
+		}
+		return false
+	}
+	ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
+		if i != index {
+			// Met the next option.
+			index = i
+			if i != 0 && !match() {
+				// The server returned a non-requested extension.
+				err = ErrHandshakeBadExtensions
+				return httphead.ControlBreak
+			}
+			option = httphead.Option{Name: name}
+		}
+		if attr != nil {
+			option.Parameters.Set(attr, val)
+		}
+		return httphead.ControlContinue
+	})
+	if !ok {
+		err = ErrMalformedResponse
+		return received, err
+	}
+	if !match() {
+		return received, ErrHandshakeBadExtensions
+	}
+	return received, err
+}
+
+// setupContextDeadliner is a helper function that starts the connection I/O
+// interrupter goroutine.
+//
+// The started goroutine calls SetDeadline() with a long-ago value when the
+// context becomes expired, to make any I/O operations fail. It returns a done
+// function that stops the started goroutine and maps an error received from
+// conn I/O methods to a possible context expiration error.
+//
+// Because of the possible SetDeadline() call inside the interrupter
+// goroutine, the caller passes a pointer to its I/O error (even if it is nil)
+// to done(&err). That is, even if the I/O error is nil, the context could be
+// already expired and the connection "poisoned" by the SetDeadline() call. In
+// that case done(&err) will store the ctx.Err() result at *err. If err is not
+// caused by a timeout, it will be left untouched.
+func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
+	var (
+		quit      = make(chan struct{})
+		interrupt = make(chan error, 1)
+	)
+	go func() {
+		select {
+		case <-quit:
+			interrupt <- nil
+		case <-ctx.Done():
+			// Cancel I/O immediately.
+			conn.SetDeadline(aLongTimeAgo)
+			interrupt <- ctx.Err()
+		}
+	}()
+	return func(err *error) {
+		close(quit)
+		// If ctx.Err() is non-nil and the original err is a net.Error with
+		// Timeout() == true, then it means that I/O was canceled by our
+		// SetDeadline(aLongTimeAgo) call, or by somebody else previously
+		// by conn.SetDeadline(x).
+		//
+		// Even on a race condition when both deadlines have expired (a
+		// SetDeadline() made not by us, and the context's), we prefer
+		// ctx.Err() to be returned.
+		if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
+			*err = ctxErr
+		}
+	}
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer_tls_go17.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer_tls_go17.go
new file mode 100644
index 000000000000..b606e0ad909b
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer_tls_go17.go
@@ -0,0 +1,35 @@
+// +build !go1.8
+
+package ws
+
+import "crypto/tls"
+
+func tlsCloneConfig(c *tls.Config) *tls.Config {
+	// NOTE: we copy SessionTicketsDisabled and SessionTicketKey here without
+	// somehow calling the inner c.initOnceServer, because we can only get
+	// here from the ws.Dialer code, which is obviously a client and calls
+	// tls.Client() when it gets a new net.Conn.
+	return &tls.Config{
+		Rand:                        c.Rand,
+		Time:                        c.Time,
+		Certificates:                c.Certificates,
+		NameToCertificate:           c.NameToCertificate,
+		GetCertificate:              c.GetCertificate,
+		RootCAs:                     c.RootCAs,
+		NextProtos:                  c.NextProtos,
+		ServerName:                  c.ServerName,
+		ClientAuth:                  c.ClientAuth,
+		ClientCAs:                   c.ClientCAs,
+		InsecureSkipVerify:          c.InsecureSkipVerify,
+		CipherSuites:                c.CipherSuites,
+		PreferServerCipherSuites:    c.PreferServerCipherSuites,
+		SessionTicketsDisabled:      c.SessionTicketsDisabled,
+		SessionTicketKey:            c.SessionTicketKey,
+		ClientSessionCache:          c.ClientSessionCache,
+		MinVersion:                  c.MinVersion,
+		MaxVersion:                  c.MaxVersion,
+		CurvePreferences:            c.CurvePreferences,
+		DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
+		Renegotiation:               c.Renegotiation,
+	}
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer_tls_go18.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer_tls_go18.go
new file mode 100644
index 000000000000..a6704d5173a3
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/dialer_tls_go18.go
@@ -0,0 +1,9 @@
+// +build go1.8
+
+package ws
+
+import "crypto/tls"
+
+func tlsCloneConfig(c *tls.Config) *tls.Config {
+	return c.Clone()
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/doc.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/doc.go
new file mode 100644
index 000000000000..c9d5791570c1
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/doc.go
@@ -0,0 +1,81 @@
+/*
+Package ws implements a client and server for the WebSocket protocol as
+specified in RFC 6455.
+
+The main purpose of this package is to provide a simple low-level API for
+efficient work with the protocol.
+
+Overview.
+
+Upgrade to WebSocket (or the WebSocket handshake) can be done in two ways.
+
+The first way is to use the `net/http` server:
+
+	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		conn, _, _, err := ws.UpgradeHTTP(r, w)
+	})
+
+The second and much more efficient way is the so-called "zero-copy upgrade".
+It avoids redundant allocations and copying of unused headers or other
+request data. The user decides which data should be copied.
+
+	ln, err := net.Listen("tcp", ":8080")
+	if err != nil {
+		// handle error
+	}
+
+	conn, err := ln.Accept()
+	if err != nil {
+		// handle error
+	}
+
+	handshake, err := ws.Upgrade(conn)
+	if err != nil {
+		// handle error
+	}
+
+For customization details see the `ws.Upgrader` documentation.
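+
+For example (an illustrative sketch, not part of the original documentation),
+ws.Upgrader allows hooking into the handshake of a zero-copy upgrade to
+inspect non-websocket headers:
+
+	u := ws.Upgrader{
+		OnHeader: func(key, value []byte) error {
+			log.Printf("non-websocket header: %q=%q", key, value)
+			return nil
+		},
+	}
+
+	_, err = u.Upgrade(conn)
+	if err != nil {
+		// handle error
+	}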
+
+After the WebSocket handshake you can work with the connection in multiple
+ways. That is, `ws` does not force only one way of working with WebSocket:
+
+	header, err := ws.ReadHeader(conn)
+	if err != nil {
+		// handle err
+	}
+
+	buf := make([]byte, header.Length)
+	_, err = io.ReadFull(conn, buf)
+	if err != nil {
+		// handle err
+	}
+
+	frame := ws.NewBinaryFrame([]byte("hello, world!"))
+	if err := ws.WriteFrame(conn, frame); err != nil {
+		// handle err
+	}
+
+As you can see, it is stream friendly:
+
+	const N = 42
+
+	ws.WriteHeader(conn, ws.Header{
+		Fin:    true,
+		Length: N,
+		OpCode: ws.OpBinary,
+	})
+
+	io.CopyN(conn, rand.Reader, N)
+
+Or:
+
+	header, err := ws.ReadHeader(conn)
+	if err != nil {
+		// handle err
+	}
+
+	io.CopyN(ioutil.Discard, conn, header.Length)
+
+For more info see the documentation.
+*/
+package ws
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/errors.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/errors.go
new file mode 100644
index 000000000000..48fce3b72c12
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/errors.go
@@ -0,0 +1,54 @@
+package ws
+
+// RejectOption represents an option used to control the way a connection is
+// rejected.
+type RejectOption func(*rejectConnectionError)
+
+// RejectionReason returns an option that makes the connection be rejected
+// with the given reason.
+func RejectionReason(reason string) RejectOption {
+	return func(err *rejectConnectionError) {
+		err.reason = reason
+	}
+}
+
+// RejectionStatus returns an option that makes the connection be rejected
+// with the given HTTP status code.
+func RejectionStatus(code int) RejectOption {
+	return func(err *rejectConnectionError) {
+		err.code = code
+	}
+}
+
+// RejectionHeader returns an option that makes the connection be rejected
+// with the given HTTP headers.
+func RejectionHeader(h HandshakeHeader) RejectOption {
+	return func(err *rejectConnectionError) {
+		err.header = h
+	}
+}
+
+// RejectConnectionError constructs an error that could be used to control
+// the way the handshake is rejected by Upgrader.
+func RejectConnectionError(options ...RejectOption) error {
+	err := new(rejectConnectionError)
+	for _, opt := range options {
+		opt(err)
+	}
+	return err
+}
+
+// rejectConnectionError represents an upgrade rejection error.
+//
+// It can be returned by Upgrader's On* hooks to control the way the WebSocket
+// handshake is rejected.
+type rejectConnectionError struct {
+	reason string
+	code   int
+	header HandshakeHeader
+}
+
+// Error implements the error interface.
+func (r *rejectConnectionError) Error() string {
+	return r.reason
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/frame.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/frame.go
new file mode 100644
index 000000000000..f157ee3e9ff6
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/frame.go
@@ -0,0 +1,389 @@
+package ws
+
+import (
+	"bytes"
+	"encoding/binary"
+	"math/rand"
+)
+
+// Constants defined by the specification.
+const (
+	// All control frames MUST have a payload length of 125 bytes or less and
+	// MUST NOT be fragmented.
+	MaxControlFramePayloadSize = 125
+)
+
+// OpCode represents an operation code.
+type OpCode byte
+
+// Operation codes defined by the specification.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+const (
+	OpContinuation OpCode = 0x0
+	OpText         OpCode = 0x1
+	OpBinary       OpCode = 0x2
+	OpClose        OpCode = 0x8
+	OpPing         OpCode = 0x9
+	OpPong         OpCode = 0xa
+)
+
+// IsControl checks whether c is a control operation code.
+// See https://tools.ietf.org/html/rfc6455#section-5.5
+func (c OpCode) IsControl() bool {
+	// RFC6455: Control frames are identified by opcodes where
+	// the most significant bit of the opcode is 1.
+	//
+	// Note that OpCode is only 4 bits long.
+	return c&0x8 != 0
+}
+
+// IsData checks whether c is a data operation code.
+// See https://tools.ietf.org/html/rfc6455#section-5.6
+func (c OpCode) IsData() bool {
+	// RFC6455: Data frames (e.g., non-control frames) are identified by
+	// opcodes where the most significant bit of the opcode is 0.
+	//
+	// Note that OpCode is only 4 bits long.
+	return c&0x8 == 0
+}
+
+// IsReserved checks whether c is a reserved operation code.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+func (c OpCode) IsReserved() bool {
+	// RFC6455:
+	// %x3-7 are reserved for further non-control frames
+	// %xB-F are reserved for further control frames
+	return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
+}
+
+// StatusCode represents the encoded reason for closure of a websocket
+// connection.
+//
+// There are a few helper methods on StatusCode that help to determine the
+// range in which a given code lies, according to the ranges defined in the
+// specification.
+//
+// See https://tools.ietf.org/html/rfc6455#section-7.4
+type StatusCode uint16
+
+// StatusCodeRange describes a range of StatusCode values.
+type StatusCodeRange struct {
+	Min, Max StatusCode
+}
+
+// Status code ranges defined by the specification.
+// See https://tools.ietf.org/html/rfc6455#section-7.4.2
+var (
+	StatusRangeNotInUse    = StatusCodeRange{0, 999}
+	StatusRangeProtocol    = StatusCodeRange{1000, 2999}
+	StatusRangeApplication = StatusCodeRange{3000, 3999}
+	StatusRangePrivate     = StatusCodeRange{4000, 4999}
+)
+
+// Status codes defined by the specification.
+// See https://tools.ietf.org/html/rfc6455#section-7.4.1
+const (
+	StatusNormalClosure           StatusCode = 1000
+	StatusGoingAway               StatusCode = 1001
+	StatusProtocolError           StatusCode = 1002
+	StatusUnsupportedData         StatusCode = 1003
+	StatusNoMeaningYet            StatusCode = 1004
+	StatusInvalidFramePayloadData StatusCode = 1007
+	StatusPolicyViolation         StatusCode = 1008
+	StatusMessageTooBig           StatusCode = 1009
+	StatusMandatoryExt            StatusCode = 1010
+	StatusInternalServerError     StatusCode = 1011
+	StatusTLSHandshake            StatusCode = 1015
+
+	// StatusAbnormalClosure is a special code designated for use in
+	// applications.
+	StatusAbnormalClosure StatusCode = 1006
+
+	// StatusNoStatusRcvd is a special code designated for use in
+	// applications.
+	StatusNoStatusRcvd StatusCode = 1005
+)
+
+// In reports whether the code is defined in the given range.
+func (s StatusCode) In(r StatusCodeRange) bool {
+	return r.Min <= s && s <= r.Max
+}
+
+// Empty reports whether the code is empty.
+// An empty code has no meaning, neither as an application-level code nor
+// otherwise. This method is useful just to check that the code is Go's
+// default zero value, 0.
+func (s StatusCode) Empty() bool {
+	return s == 0
+}
+
+// IsNotUsed reports whether the code lies in the not-in-use range.
+func (s StatusCode) IsNotUsed() bool {
+	return s.In(StatusRangeNotInUse)
+}
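+
+// For example (an illustrative sketch, not part of the original source):
+//
+//	StatusCode(1002).In(StatusRangeProtocol) // true: defined by RFC6455
+//	StatusCode(3500).IsApplicationSpec()     // true: for frameworks/libraries
+//	StatusCode(500).IsNotUsed()              // true: below the defined ranges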
+
+// IsApplicationSpec reports whether the code should be defined by an
+// application, framework, or library specification.
+func (s StatusCode) IsApplicationSpec() bool {
+	return s.In(StatusRangeApplication)
+}
+
+// IsPrivateSpec reports whether the code should be defined privately.
+func (s StatusCode) IsPrivateSpec() bool {
+	return s.In(StatusRangePrivate)
+}
+
+// IsProtocolSpec reports whether the code should be defined by the protocol
+// specification.
+func (s StatusCode) IsProtocolSpec() bool {
+	return s.In(StatusRangeProtocol)
+}
+
+// IsProtocolDefined reports whether the code is already defined by the
+// protocol specification.
+func (s StatusCode) IsProtocolDefined() bool {
+	switch s {
+	case StatusNormalClosure,
+		StatusGoingAway,
+		StatusProtocolError,
+		StatusUnsupportedData,
+		StatusInvalidFramePayloadData,
+		StatusPolicyViolation,
+		StatusMessageTooBig,
+		StatusMandatoryExt,
+		StatusInternalServerError,
+		StatusNoStatusRcvd,
+		StatusAbnormalClosure,
+		StatusTLSHandshake:
+		return true
+	}
+	return false
+}
+
+// IsProtocolReserved reports whether the code is defined by the protocol
+// specification as reserved, and must not be set as a status code in a
+// Close control frame by an endpoint.
+func (s StatusCode) IsProtocolReserved() bool {
+	switch s {
+	// [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
+	// Close control frame by an endpoint.
+	case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
+		return true
+	default:
+		return false
+	}
+}
+
+// Compiled control frames for common use cases.
+// These avoid repeated construct-and-serialize work.
+var (
+	CompiledPing  = MustCompileFrame(NewPingFrame(nil))
+	CompiledPong  = MustCompileFrame(NewPongFrame(nil))
+	CompiledClose = MustCompileFrame(NewCloseFrame(nil))
+
+	CompiledCloseNormalClosure           = MustCompileFrame(closeFrameNormalClosure)
+	CompiledCloseGoingAway               = MustCompileFrame(closeFrameGoingAway)
+	CompiledCloseProtocolError           = MustCompileFrame(closeFrameProtocolError)
+	CompiledCloseUnsupportedData         = MustCompileFrame(closeFrameUnsupportedData)
+	CompiledCloseNoMeaningYet            = MustCompileFrame(closeFrameNoMeaningYet)
+	CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData)
+	CompiledClosePolicyViolation         = MustCompileFrame(closeFramePolicyViolation)
+	CompiledCloseMessageTooBig           = MustCompileFrame(closeFrameMessageTooBig)
+	CompiledCloseMandatoryExt            = MustCompileFrame(closeFrameMandatoryExt)
+	CompiledCloseInternalServerError     = MustCompileFrame(closeFrameInternalServerError)
+	CompiledCloseTLSHandshake            = MustCompileFrame(closeFrameTLSHandshake)
+)
+
+// Header represents a websocket frame header.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+type Header struct {
+	Fin    bool
+	Rsv    byte
+	OpCode OpCode
+	Masked bool
+	Mask   [4]byte
+	Length int64
+}
+
+// Rsv1 reports whether the header has the first rsv bit set.
+func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
+
+// Rsv2 reports whether the header has the second rsv bit set.
+func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
+
+// Rsv3 reports whether the header has the third rsv bit set.
+func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
+
+// Frame represents a websocket frame.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+type Frame struct {
+	Header  Header
+	Payload []byte
+}
+
+// NewFrame creates a frame with the given operation code,
+// completeness flag, and payload bytes.
+func NewFrame(op OpCode, fin bool, p []byte) Frame {
+	return Frame{
+		Header: Header{
+			Fin:    fin,
+			OpCode: op,
+			Length: int64(len(p)),
+		},
+		Payload: p,
+	}
+}
+
+// NewTextFrame creates a text frame with p as payload.
+// Note that p is not copied.
+func NewTextFrame(p []byte) Frame {
+	return NewFrame(OpText, true, p)
+}
+
+// NewBinaryFrame creates a binary frame with p as payload.
+// Note that p is not copied.
+func NewBinaryFrame(p []byte) Frame {
+	return NewFrame(OpBinary, true, p)
+}
+
+// NewPingFrame creates a ping frame with p as payload.
+// Note that p is not copied.
+// Note that p must be MaxControlFramePayloadSize bytes long or less due to
+// the RFC.
+func NewPingFrame(p []byte) Frame {
+	return NewFrame(OpPing, true, p)
+}
+
+// NewPongFrame creates a pong frame with p as payload.
+// Note that p is not copied.
+// Note that p must be MaxControlFramePayloadSize bytes long or less due to
+// the RFC.
+func NewPongFrame(p []byte) Frame {
+	return NewFrame(OpPong, true, p)
+}
+
+// NewCloseFrame creates a close frame with the given close body.
+// Note that p is not copied.
+// Note that p must be MaxControlFramePayloadSize bytes long or less due to
+// the RFC.
+func NewCloseFrame(p []byte) Frame {
+	return NewFrame(OpClose, true, p)
+}
+
+// NewCloseFrameBody encodes a closure code and a reason into a binary
+// representation.
+//
+// It returns a slice which is at most MaxControlFramePayloadSize bytes long.
+// If the reason is too big it will be cropped to fit the limit defined by
+// the spec.
+//
+// See https://tools.ietf.org/html/rfc6455#section-5.5
+func NewCloseFrameBody(code StatusCode, reason string) []byte {
+	n := min(2+len(reason), MaxControlFramePayloadSize)
+	p := make([]byte, n)
+
+	crop := min(MaxControlFramePayloadSize-2, len(reason))
+	PutCloseFrameBody(p, code, reason[:crop])
+
+	return p
+}
+
+// PutCloseFrameBody encodes code and reason into p.
+//
+// It will panic if the buffer is too small to accommodate the code or the
+// reason.
+//
+// PutCloseFrameBody does not check the buffer for RFC compliance, but note
+// that per the RFC it must be at most MaxControlFramePayloadSize.
+func PutCloseFrameBody(p []byte, code StatusCode, reason string) {
+	_ = p[1+len(reason)]
+	binary.BigEndian.PutUint16(p, uint16(code))
+	copy(p[2:], reason)
+}
+
+// MaskFrame masks the frame and returns a frame with a masked payload and
+// the Mask header field set.
+// Note that it copies the payload of f to prevent collisions.
+// For fewer allocations you could use MaskFrameInPlace or construct the
+// frame manually.
+func MaskFrame(f Frame) Frame {
+	return MaskFrameWith(f, NewMask())
+}
+
+// MaskFrameWith masks the frame with the given mask and returns a frame
+// with a masked payload and the Mask header field set.
+// Note that it copies the payload of f to prevent collisions.
+// For fewer allocations you could use MaskFrameInPlaceWith or construct the
+// frame manually.
+func MaskFrameWith(f Frame, mask [4]byte) Frame {
+	// TODO(gobwas): check CopyCipher ws copy() Cipher().
+	p := make([]byte, len(f.Payload))
+	copy(p, f.Payload)
+	f.Payload = p
+	return MaskFrameInPlaceWith(f, mask)
+}
+
+// MaskFrameInPlace masks the frame and returns a frame with a masked payload
+// and the Mask header field set.
+// Note that it applies the xor cipher to f.Payload without copying, that is,
+// it modifies f.Payload in place.
+func MaskFrameInPlace(f Frame) Frame {
+	return MaskFrameInPlaceWith(f, NewMask())
+}
+
+// MaskFrameInPlaceWith masks the frame with the given mask and returns a
+// frame with a masked payload and the Mask header field set.
+// Note that it applies the xor cipher to f.Payload without copying, that is,
+// it modifies f.Payload in place.
+func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame { + f.Header.Masked = true + f.Header.Mask = m + Cipher(f.Payload, m, 0) + return f +} + +// NewMask creates new random mask. +func NewMask() (ret [4]byte) { + binary.BigEndian.PutUint32(ret[:], rand.Uint32()) + return +} + +// CompileFrame returns byte representation of given frame. +// In terms of memory consumption it is useful to precompile static frames +// which are often used. +func CompileFrame(f Frame) (bts []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 16)) + err = WriteFrame(buf, f) + bts = buf.Bytes() + return +} + +// MustCompileFrame is like CompileFrame but panics if frame can not be +// encoded. +func MustCompileFrame(f Frame) []byte { + bts, err := CompileFrame(f) + if err != nil { + panic(err) + } + return bts +} + +// Rsv creates rsv byte representation. +func Rsv(r1, r2, r3 bool) (rsv byte) { + if r1 { + rsv |= bit5 + } + if r2 { + rsv |= bit6 + } + if r3 { + rsv |= bit7 + } + return rsv +} + +func makeCloseFrame(code StatusCode) Frame { + return NewCloseFrame(NewCloseFrameBody(code, "")) +} + +var ( + closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure) + closeFrameGoingAway = makeCloseFrame(StatusGoingAway) + closeFrameProtocolError = makeCloseFrame(StatusProtocolError) + closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData) + closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet) + closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData) + closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation) + closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig) + closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt) + closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError) + closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake) +) diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/http.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/http.go new file mode 100644 index 000000000000..e18df441b47e --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/http.go @@ -0,0 +1,468 @@ +package ws + +import ( + "bufio" + "bytes" + "io" + "net/http" + "net/textproto" + "net/url" + "strconv" + + "github.com/gobwas/httphead" +) + +const ( + crlf = "\r\n" + colonAndSpace = ": " + commaAndSpace = ", " +) + +const ( + textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n" +) + +var ( + textHeadBadRequest = statusText(http.StatusBadRequest) + textHeadInternalServerError = statusText(http.StatusInternalServerError) + textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired) + + textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol) + textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod) + textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost) + textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade) + textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection) + textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept) + textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey) + textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion) + textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired) +) + +var ( + headerHost = "Host" + headerUpgrade = "Upgrade" + headerConnection = "Connection" + headerSecVersion = "Sec-WebSocket-Version" + headerSecProtocol = "Sec-WebSocket-Protocol" + 
headerSecExtensions = "Sec-WebSocket-Extensions" + headerSecKey = "Sec-WebSocket-Key" + headerSecAccept = "Sec-WebSocket-Accept" + + headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost) + headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade) + headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection) + headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion) + headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol) + headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions) + headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey) + headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept) +) + +var ( + specHeaderValueUpgrade = []byte("websocket") + specHeaderValueConnection = []byte("Upgrade") + specHeaderValueConnectionLower = []byte("upgrade") + specHeaderValueSecVersion = []byte("13") +) + +var ( + httpVersion1_0 = []byte("HTTP/1.0") + httpVersion1_1 = []byte("HTTP/1.1") + httpVersionPrefix = []byte("HTTP/") +) + +type httpRequestLine struct { + method, uri []byte + major, minor int +} + +type httpResponseLine struct { + major, minor int + status int + reason []byte +} + +// httpParseRequestLine parses http request line like "GET / HTTP/1.0". +func httpParseRequestLine(line []byte) (req httpRequestLine, err error) { + var proto []byte + req.method, req.uri, proto = bsplit3(line, ' ') + + var ok bool + req.major, req.minor, ok = httpParseVersion(proto) + if !ok { + err = ErrMalformedRequest + return + } + + return +} + +func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) { + var ( + proto []byte + status []byte + ) + proto, status, resp.reason = bsplit3(line, ' ') + + var ok bool + resp.major, resp.minor, ok = httpParseVersion(proto) + if !ok { + return resp, ErrMalformedResponse + } + + var convErr error + resp.status, convErr = asciiToInt(status) + if convErr != nil { + return resp, ErrMalformedResponse + } + + return resp, nil +} + +// httpParseVersion parses major and minor version of HTTP protocol. It returns +// parsed values and true if parse is ok. +func httpParseVersion(bts []byte) (major, minor int, ok bool) { + switch { + case bytes.Equal(bts, httpVersion1_0): + return 1, 0, true + case bytes.Equal(bts, httpVersion1_1): + return 1, 1, true + case len(bts) < 8: + return + case !bytes.Equal(bts[:5], httpVersionPrefix): + return + } + + bts = bts[5:] + + dot := bytes.IndexByte(bts, '.') + if dot == -1 { + return + } + var err error + major, err = asciiToInt(bts[:dot]) + if err != nil { + return + } + minor, err = asciiToInt(bts[dot+1:]) + if err != nil { + return + } + + return major, minor, true +} + +// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed +// values and true if parse is ok. +func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) { + colon := bytes.IndexByte(line, ':') + if colon == -1 { + return + } + + k = btrim(line[:colon]) + // TODO(gobwas): maybe use just lower here? + canonicalizeHeaderKey(k) + + v = btrim(line[colon+1:]) + + return k, v, true +} + +// httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing, +// that key is already canonical. This helps to increase performance. +func httpGetHeader(h http.Header, key string) string { + if h == nil { + return "" + } + v := h[key] + if len(v) == 0 { + return "" + } + return v[0] +} + +// The request MAY include a header field with the name +// |Sec-WebSocket-Protocol|. 
If present, this value indicates one or more +// comma-separated subprotocol the client wishes to speak, ordered by +// preference. The elements that comprise this value MUST be non-empty strings +// with characters in the range U+0021 to U+007E not including separator +// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF +// for the value of this header field is 1#token, where the definitions of +// constructs and rules are as given in [RFC2616]. +func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) { + ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool { + if check(btsToString(v)) { + ret = string(v) + return false + } + return true + }) + return +} +func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) { + var selected []byte + ok = httphead.ScanTokens(h, func(v []byte) bool { + if check(v) { + selected = v + return false + } + return true + }) + if ok && selected != nil { + return string(selected), true + } + return +} + +func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { + return btsSelectExtensions(strToBytes(h), selected, check) +} + +func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { + s := httphead.OptionSelector{ + Flags: httphead.SelectUnique | httphead.SelectCopy, + Check: check, + } + return s.Select(h, selected) +} + +func httpWriteHeader(bw *bufio.Writer, key, value string) { + httpWriteHeaderKey(bw, key) + bw.WriteString(value) + bw.WriteString(crlf) +} + +func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) { + httpWriteHeaderKey(bw, key) + bw.Write(value) + bw.WriteString(crlf) +} + +func httpWriteHeaderKey(bw *bufio.Writer, key string) { + bw.WriteString(key) + bw.WriteString(colonAndSpace) +} + +func httpWriteUpgradeRequest( + bw *bufio.Writer, + u *url.URL, + nonce []byte, + protocols []string, + extensions []httphead.Option, + header HandshakeHeader, +) { + bw.WriteString("GET ") + bw.WriteString(u.RequestURI()) + bw.WriteString(" HTTP/1.1\r\n") + + httpWriteHeader(bw, headerHost, u.Host) + + httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade) + httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection) + httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion) + + // NOTE: write nonce bytes as a string to prevent heap allocation – + // WriteString() copy given string into its inner buffer, unlike Write() + // which may write p directly to the underlying io.Writer – which in turn + // will lead to p escape. 
+ httpWriteHeader(bw, headerSecKey, btsToString(nonce)) + + if len(protocols) > 0 { + httpWriteHeaderKey(bw, headerSecProtocol) + for i, p := range protocols { + if i > 0 { + bw.WriteString(commaAndSpace) + } + bw.WriteString(p) + } + bw.WriteString(crlf) + } + + if len(extensions) > 0 { + httpWriteHeaderKey(bw, headerSecExtensions) + httphead.WriteOptions(bw, extensions) + bw.WriteString(crlf) + } + + if header != nil { + header.WriteTo(bw) + } + + bw.WriteString(crlf) +} + +func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) { + bw.WriteString(textHeadUpgrade) + + httpWriteHeaderKey(bw, headerSecAccept) + writeAccept(bw, nonce) + bw.WriteString(crlf) + + if hs.Protocol != "" { + httpWriteHeader(bw, headerSecProtocol, hs.Protocol) + } + if len(hs.Extensions) > 0 { + httpWriteHeaderKey(bw, headerSecExtensions) + httphead.WriteOptions(bw, hs.Extensions) + bw.WriteString(crlf) + } + if header != nil { + header(bw) + } + + bw.WriteString(crlf) +} + +func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) { + switch code { + case http.StatusBadRequest: + bw.WriteString(textHeadBadRequest) + case http.StatusInternalServerError: + bw.WriteString(textHeadInternalServerError) + case http.StatusUpgradeRequired: + bw.WriteString(textHeadUpgradeRequired) + default: + writeStatusText(bw, code) + } + + // Write custom headers. + if header != nil { + header(bw) + } + + switch err { + case ErrHandshakeBadProtocol: + bw.WriteString(textTailErrHandshakeBadProtocol) + case ErrHandshakeBadMethod: + bw.WriteString(textTailErrHandshakeBadMethod) + case ErrHandshakeBadHost: + bw.WriteString(textTailErrHandshakeBadHost) + case ErrHandshakeBadUpgrade: + bw.WriteString(textTailErrHandshakeBadUpgrade) + case ErrHandshakeBadConnection: + bw.WriteString(textTailErrHandshakeBadConnection) + case ErrHandshakeBadSecAccept: + bw.WriteString(textTailErrHandshakeBadSecAccept) + case ErrHandshakeBadSecKey: + bw.WriteString(textTailErrHandshakeBadSecKey) + case ErrHandshakeBadSecVersion: + bw.WriteString(textTailErrHandshakeBadSecVersion) + case ErrHandshakeUpgradeRequired: + bw.WriteString(textTailErrUpgradeRequired) + case nil: + bw.WriteString(crlf) + default: + writeErrorText(bw, err) + } +} + +func writeStatusText(bw *bufio.Writer, code int) { + bw.WriteString("HTTP/1.1 ") + bw.WriteString(strconv.Itoa(code)) + bw.WriteByte(' ') + bw.WriteString(http.StatusText(code)) + bw.WriteString(crlf) + bw.WriteString("Content-Type: text/plain; charset=utf-8") + bw.WriteString(crlf) +} + +func writeErrorText(bw *bufio.Writer, err error) { + body := err.Error() + bw.WriteString("Content-Length: ") + bw.WriteString(strconv.Itoa(len(body))) + bw.WriteString(crlf) + bw.WriteString(crlf) + bw.WriteString(body) +} + +// httpError is like the http.Error with WebSocket context exception. +func httpError(w http.ResponseWriter, body string, code int) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(code) + w.Write([]byte(body)) +} + +// statusText is a non-performant status text generator. +// NOTE: Used only to generate constants. +func statusText(code int) string { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writeStatusText(bw, code) + bw.Flush() + return buf.String() +} + +// errorText is a non-performant error text generator. +// NOTE: Used only to generate constants. 
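+// For example, errorText(ErrHandshakeBadMethod) produces the
+// "Content-Length: ...\r\n\r\n<message>" tail that (presumably) backs the
+// textTailErrHandshakeBadMethod constant written by httpWriteResponseError
+// above.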
+func errorText(err error) string { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writeErrorText(bw, err) + bw.Flush() + return buf.String() +} + +// HandshakeHeader is the interface that writes both upgrade request or +// response headers into a given io.Writer. +type HandshakeHeader interface { + io.WriterTo +} + +// HandshakeHeaderString is an adapter to allow the use of headers represented +// by ordinary string as HandshakeHeader. +type HandshakeHeaderString string + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) { + n, err := io.WriteString(w, string(s)) + return int64(n), err +} + +// HandshakeHeaderBytes is an adapter to allow the use of headers represented +// by ordinary slice of bytes as HandshakeHeader. +type HandshakeHeaderBytes []byte + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(b) + return int64(n), err +} + +// HandshakeHeaderFunc is an adapter to allow the use of headers represented by +// ordinary function as HandshakeHeader. +type HandshakeHeaderFunc func(io.Writer) (int64, error) + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) { + return f(w) +} + +// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as +// HandshakeHeader. +type HandshakeHeaderHTTP http.Header + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) { + wr := writer{w: w} + err := http.Header(h).Write(&wr) + return wr.n, err +} + +type writer struct { + n int64 + w io.Writer +} + +func (w *writer) WriteString(s string) (int, error) { + n, err := io.WriteString(w.w, s) + w.n += int64(n) + return n, err +} + +func (w *writer) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + w.n += int64(n) + return n, err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/nonce.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/nonce.go new file mode 100644 index 000000000000..e694da7c3084 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/nonce.go @@ -0,0 +1,80 @@ +package ws + +import ( + "bufio" + "bytes" + "crypto/sha1" + "encoding/base64" + "fmt" + "math/rand" +) + +const ( + // RFC6455: The value of this header field MUST be a nonce consisting of a + // randomly selected 16-byte value that has been base64-encoded (see + // Section 4 of [RFC4648]). The nonce MUST be selected randomly for each + // connection. + nonceKeySize = 16 + nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) + + // RFC6455: The value of this header field is constructed by concatenating + // /key/, defined above in step 4 in Section 4.2.2, with the string + // "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this + // concatenated value to obtain a 20-byte value and base64- encoding (see + // Section 4 of [RFC4648]) this 20-byte hash. + acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size) +) + +// initNonce fills given slice with random base64-encoded nonce bytes. +func initNonce(dst []byte) { + // NOTE: bts does not escape. 
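+ // For example, a 16-byte key always base64-encodes to exactly
+ // nonceSize (24) bytes: sixteen zero bytes encode to
+ // "AAAAAAAAAAAAAAAAAAAAAA==".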
+ bts := make([]byte, nonceKeySize) + if _, err := rand.Read(bts); err != nil { + panic(fmt.Sprintf("rand read error: %s", err)) + } + base64.StdEncoding.Encode(dst, bts) +} + +// checkAcceptFromNonce reports whether given accept bytes are valid for given +// nonce bytes. +func checkAcceptFromNonce(accept, nonce []byte) bool { + if len(accept) != acceptSize { + return false + } + // NOTE: expect does not escape. + expect := make([]byte, acceptSize) + initAcceptFromNonce(expect, nonce) + return bytes.Equal(expect, accept) +} + +// initAcceptFromNonce fills given slice with accept bytes generated from given +// nonce bytes. Given buffer should be exactly acceptSize bytes. +func initAcceptFromNonce(accept, nonce []byte) { + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + if len(accept) != acceptSize { + panic("accept buffer is invalid") + } + if len(nonce) != nonceSize { + panic("nonce is invalid") + } + + p := make([]byte, nonceSize+len(magic)) + copy(p[:nonceSize], nonce) + copy(p[nonceSize:], magic) + + sum := sha1.Sum(p) + base64.StdEncoding.Encode(accept, sum[:]) + + return +} + +func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) { + accept := make([]byte, acceptSize) + initAcceptFromNonce(accept, nonce) + // NOTE: write accept bytes as a string to prevent heap allocation – + // WriteString() copy given string into its inner buffer, unlike Write() + // which may write p directly to the underlying io.Writer – which in turn + // will lead to p escape. + return bw.WriteString(btsToString(accept)) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/read.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/read.go new file mode 100644 index 000000000000..bc653e4690f4 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/read.go @@ -0,0 +1,147 @@ +package ws + +import ( + "encoding/binary" + "fmt" + "io" +) + +// Errors used by frame reader. +var ( + ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0") + ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits") +) + +// ReadHeader reads a frame header from r. +func ReadHeader(r io.Reader) (h Header, err error) { + // Make slice of bytes with capacity 12 that could hold any header. + // + // The maximum header size is 14, but due to the 2 hop reads, + // after first hop that reads first 2 constant bytes, we could reuse 2 bytes. + // So 14 - 2 = 12. + bts := make([]byte, 2, MaxHeaderSize-2) + + // Prepare to hold first 2 bytes to choose size of next read. + _, err = io.ReadFull(r, bts) + if err != nil { + return + } + + h.Fin = bts[0]&bit0 != 0 + h.Rsv = (bts[0] & 0x70) >> 4 + h.OpCode = OpCode(bts[0] & 0x0f) + + var extra int + + if bts[1]&bit0 != 0 { + h.Masked = true + extra += 4 + } + + length := bts[1] & 0x7f + switch { + case length < 126: + h.Length = int64(length) + + case length == 126: + extra += 2 + + case length == 127: + extra += 8 + + default: + err = ErrHeaderLengthUnexpected + return + } + + if extra == 0 { + return + } + + // Increase len of bts to extra bytes need to read. + // Overwrite first 2 bytes that was read before. 
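+ // For example, a masked frame with a 200-byte payload sets the 7-bit
+ // length field to 126 and is followed by 2 extended-length bytes and a
+ // 4-byte mask, so extra == 6 here.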
+ bts = bts[:extra]
+ _, err = io.ReadFull(r, bts)
+ if err != nil {
+ return
+ }
+
+ switch {
+ case length == 126:
+ h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
+ bts = bts[2:]
+
+ case length == 127:
+ if bts[0]&0x80 != 0 {
+ err = ErrHeaderLengthMSB
+ return
+ }
+ h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
+ bts = bts[8:]
+ }
+
+ if h.Masked {
+ copy(h.Mask[:], bts)
+ }
+
+ return
+}
+
+// ReadFrame reads a frame from r.
+// It is not designed for highly optimized use cases because it allocates a
+// buffer of frame.Header.Length bytes to read the frame payload into.
+//
+// Note that ReadFrame does not unmask payload.
+func ReadFrame(r io.Reader) (f Frame, err error) {
+ f.Header, err = ReadHeader(r)
+ if err != nil {
+ return
+ }
+
+ if f.Header.Length > 0 {
+ // int(f.Header.Length) is safe here because we have
+ // checked it for overflow above in ReadHeader.
+ f.Payload = make([]byte, int(f.Header.Length))
+ _, err = io.ReadFull(r, f.Payload)
+ }
+
+ return
+}
+
+// MustReadFrame is like ReadFrame but panics if the frame can not be read.
+func MustReadFrame(r io.Reader) Frame {
+ f, err := ReadFrame(r)
+ if err != nil {
+ panic(err)
+ }
+ return f
+}
+
+// ParseCloseFrameData parses the close frame status code and closure reason,
+// if any is provided. If there is no status code in the payload, the empty
+// status code is returned (code.Empty()) with an empty string as the reason.
+func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
+ if len(payload) < 2 {
+ // We return an empty StatusCode here to prevent the situation
+ // where the endpoint really sent code 1005 and we would report
+ // ProtocolError for it.
+ //
+ // In other words, we ignore this rule [RFC6455:7.1.5]:
+ // If this Close control frame contains no status code, _The WebSocket
+ // Connection Close Code_ is considered to be 1005.
+ return
+ }
+ code = StatusCode(binary.BigEndian.Uint16(payload))
+ reason = string(payload[2:])
+ return
+}
+
+// ParseCloseFrameDataUnsafe is like ParseCloseFrameData, except that it does
+// not copy the payload bytes into reason, but performs an unsafe cast instead.
+func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
+ if len(payload) < 2 {
+ return
+ }
+ code = StatusCode(binary.BigEndian.Uint16(payload))
+ reason = btsToString(payload[2:])
+ return
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/server.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/server.go
new file mode 100644
index 000000000000..48059aded492
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/server.go
@@ -0,0 +1,607 @@
+package ws
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gobwas/httphead"
+ "github.com/gobwas/pool/pbufio"
+)
+
+// Constants used by ConnUpgrader.
+const (
+ DefaultServerReadBufferSize = 4096
+ DefaultServerWriteBufferSize = 512
+)
+
+// Errors used by both client and server when preparing WebSocket handshake.
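+// Each one is a RejectConnectionError that carries the HTTP status code
+// which httpWriteResponseError will send back to the peer.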
+var ( + ErrHandshakeBadProtocol = RejectConnectionError( + RejectionStatus(http.StatusHTTPVersionNotSupported), + RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")), + ) + ErrHandshakeBadMethod = RejectConnectionError( + RejectionStatus(http.StatusMethodNotAllowed), + RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")), + ) + ErrHandshakeBadHost = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)), + ) + ErrHandshakeBadUpgrade = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)), + ) + ErrHandshakeBadConnection = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)), + ) + ErrHandshakeBadSecAccept = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)), + ) + ErrHandshakeBadSecKey = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)), + ) + ErrHandshakeBadSecVersion = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)), + ) +) + +// ErrMalformedResponse is returned by Dialer to indicate that server response +// can not be parsed. +var ErrMalformedResponse = fmt.Errorf("malformed HTTP response") + +// ErrMalformedRequest is returned when HTTP request can not be parsed. +var ErrMalformedRequest = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason("malformed HTTP request"), +) + +// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that +// connection is rejected because given WebSocket version is malformed. +// +// According to RFC6455: +// If this version does not match a version understood by the server, the +// server MUST abort the WebSocket handshake described in this section and +// instead send an appropriate HTTP error code (such as 426 Upgrade Required) +// and a |Sec-WebSocket-Version| header field indicating the version(s) the +// server is capable of understanding. +var ErrHandshakeUpgradeRequired = RejectConnectionError( + RejectionStatus(http.StatusUpgradeRequired), + RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)), +) + +// ErrNotHijacker is an error returned when http.ResponseWriter does not +// implement http.Hijacker interface. +var ErrNotHijacker = RejectConnectionError( + RejectionStatus(http.StatusInternalServerError), + RejectionReason("given http.ResponseWriter is not a http.Hijacker"), +) + +// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by +// UpgradeHTTP function. +var DefaultHTTPUpgrader HTTPUpgrader + +// UpgradeHTTP is like HTTPUpgrader{}.Upgrade(). +func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) { + return DefaultHTTPUpgrader.Upgrade(r, w) +} + +// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade +// function. +var DefaultUpgrader Upgrader + +// Upgrade is like Upgrader{}.Upgrade(). 
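+//
+// A minimal, illustrative accept loop (ln is an assumed net.Listener):
+//
+//    conn, _ := ln.Accept()
+//    hs, err := Upgrade(conn)
+//    _ = hs // negotiated subprotocol and extensions, when err == nil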
+func Upgrade(conn io.ReadWriter) (Handshake, error) {
+ return DefaultUpgrader.Upgrade(conn)
+}
+
+// HTTPUpgrader contains options for upgrading connection to websocket from
+// net/http Handler arguments.
+type HTTPUpgrader struct {
+ // Timeout is the maximum amount of time an Upgrade() will spend while
+ // writing the handshake response.
+ //
+ // The default is no timeout.
+ Timeout time.Duration
+
+ // Header is an optional http.Header mapping that could be used to
+ // write additional headers to the handshake response.
+ //
+ // Note that if present, it will be written in any result of handshake.
+ Header http.Header
+
+ // Protocol is the select function that is used to select subprotocol from
+ // the list requested by the client. If this field is set, then the first
+ // matched protocol is sent to the client as negotiated.
+ Protocol func(string) bool
+
+ // Extension is the select function that is used to select extensions from
+ // the list requested by the client. If this field is set, then all matched
+ // extensions are sent to the client as negotiated.
+ Extension func(httphead.Option) bool
+}
+
+// Upgrade upgrades an HTTP connection to a WebSocket connection.
+//
+// It hijacks net.Conn from w and returns the received net.Conn and
+// bufio.ReadWriter. On successful handshake it returns a Handshake struct
+// describing the handshake info.
+func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
+ // Hijack the connection first to get the ability to write rejection errors
+ // the same way as in Upgrader.
+ hj, ok := w.(http.Hijacker)
+ if ok {
+ conn, rw, err = hj.Hijack()
+ } else {
+ err = ErrNotHijacker
+ }
+ if err != nil {
+ httpError(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ // See https://tools.ietf.org/html/rfc6455#section-4.1
+ // The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
+ var nonce string
+ if r.Method != http.MethodGet {
+ err = ErrHandshakeBadMethod
+ } else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
+ err = ErrHandshakeBadProtocol
+ } else if r.Host == "" {
+ err = ErrHandshakeBadHost
+ } else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
+ err = ErrHandshakeBadUpgrade
+ } else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
+ err = ErrHandshakeBadConnection
+ } else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
+ err = ErrHandshakeBadSecKey
+ } else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
+ // According to RFC6455:
+ //
+ // If this version does not match a version understood by the server,
+ // the server MUST abort the WebSocket handshake described in this
+ // section and instead send an appropriate HTTP error code (such as 426
+ // Upgrade Required) and a |Sec-WebSocket-Version| header field
+ // indicating the version(s) the server is capable of understanding.
+ //
+ // So we branch here because an empty or absent version does not meet
+ // the ABNF rules of RFC6455:
+ //
+ // version = DIGIT | (NZDIGIT DIGIT) |
+ // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
+ // ; Limited to 0-255 range, with no leading zeros
+ //
+ // That is, if the version is present but invalid we send a 426 status;
+ // if it is absent or empty we send 400.
+ if v != "" { + err = ErrHandshakeUpgradeRequired + } else { + err = ErrHandshakeBadSecVersion + } + } + if check := u.Protocol; err == nil && check != nil { + ps := r.Header[headerSecProtocolCanonical] + for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ { + var ok bool + hs.Protocol, ok = strSelectProtocol(ps[i], check) + if !ok { + err = ErrMalformedRequest + } + } + } + if check := u.Extension; err == nil && check != nil { + xs := r.Header[headerSecExtensionsCanonical] + for i := 0; i < len(xs) && err == nil; i++ { + var ok bool + hs.Extensions, ok = strSelectExtensions(xs[i], hs.Extensions, check) + if !ok { + err = ErrMalformedRequest + } + } + } + + // Clear deadlines set by server. + conn.SetDeadline(noDeadline) + if t := u.Timeout; t != 0 { + conn.SetWriteDeadline(time.Now().Add(t)) + defer conn.SetWriteDeadline(noDeadline) + } + + var header handshakeHeader + if h := u.Header; h != nil { + header[0] = HandshakeHeaderHTTP(h) + } + if err == nil { + httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo) + err = rw.Writer.Flush() + } else { + var code int + if rej, ok := err.(*rejectConnectionError); ok { + code = rej.code + header[1] = rej.header + } + if code == 0 { + code = http.StatusInternalServerError + } + httpWriteResponseError(rw.Writer, err, code, header.WriteTo) + // Do not store Flush() error to not override already existing one. + rw.Writer.Flush() + } + return +} + +// Upgrader contains options for upgrading connection to websocket. +type Upgrader struct { + // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. + // They used to read and write http data while upgrading to WebSocket. + // Allocated buffers are pooled with sync.Pool to avoid extra allocations. + // + // If a size is zero then default value is used. + // + // Usually it is useful to set read buffer size bigger than write buffer + // size because incoming request could contain long header values, such as + // Cookie. Response, in other way, could be big only if user write multiple + // custom headers. Usually response takes less than 256 bytes. + ReadBufferSize, WriteBufferSize int + + // Protocol is a select function that is used to select subprotocol + // from list requested by client. If this field is set, then the first matched + // protocol is sent to a client as negotiated. + // + // The argument is only valid until the callback returns. + Protocol func([]byte) bool + + // ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually. + // Note that returned bytes must be valid until Upgrade returns. + // If ProtocolCustom is set, it used instead of Protocol function. + ProtocolCustom func([]byte) (string, bool) + + // Extension is a select function that is used to select extensions + // from list requested by client. If this field is set, then the all matched + // extensions are sent to a client as negotiated. + // + // The argument is only valid until the callback returns. + // + // According to the RFC6455 order of extensions passed by a client is + // significant. That is, returning true from this function means that no + // other extension with the same name should be checked because server + // accepted the most preferable extension right now: + // "Note that the order of extensions is significant. Any interactions between + // multiple extensions MAY be defined in the documents defining the extensions. 
+ // In the absence of such definitions, the interpretation is that the header
+ // fields listed by the client in its request represent a preference of the
+ // header fields it wishes to use, with the first options listed being most
+ // preferable."
+ Extension func(httphead.Option) bool
+
+ // ExtensionCustom allows the user to parse the Sec-WebSocket-Extensions
+ // header manually.
+ // Note that returned options should be valid until Upgrade returns.
+ // If ExtensionCustom is set, it is used instead of the Extension function.
+ ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
+
+ // Header is an optional HandshakeHeader instance that could be used to
+ // write additional headers to the handshake response.
+ //
+ // It is used instead of any key-value mappings to avoid allocations in
+ // user land.
+ //
+ // Note that if present, it will be written in any result of handshake.
+ Header HandshakeHeader
+
+ // OnRequest is a callback that will be called after successful parsing of
+ // the request line.
+ //
+ // The arguments are only valid until the callback returns.
+ //
+ // If the returned error is non-nil then the connection is rejected and the
+ // response is sent with an appropriate HTTP error code and the body set to
+ // the error message.
+ //
+ // RejectConnectionError could be used to get more control over the response.
+ OnRequest func(uri []byte) error
+
+ // OnHost is a callback that will be called after successful parsing of the
+ // "Host" header.
+ //
+ // It is separated from the OnHeader callback because the Host header must
+ // be present in each request since HTTP/1.1. Thus the Host header is
+ // non-optional and required for every WebSocket handshake.
+ //
+ // The arguments are only valid until the callback returns.
+ //
+ // If the returned error is non-nil then the connection is rejected and the
+ // response is sent with an appropriate HTTP error code and the body set to
+ // the error message.
+ //
+ // RejectConnectionError could be used to get more control over the response.
+ OnHost func(host []byte) error
+
+ // OnHeader is a callback that will be called after successful parsing of a
+ // header that is not used during the WebSocket handshake procedure. That
+ // is, it will be called with non-websocket headers, which could be relevant
+ // for application-level logic.
+ //
+ // The arguments are only valid until the callback returns.
+ //
+ // If the returned error is non-nil then the connection is rejected and the
+ // response is sent with an appropriate HTTP error code and the body set to
+ // the error message.
+ //
+ // RejectConnectionError could be used to get more control over the response.
+ OnHeader func(key, value []byte) error
+
+ // OnBeforeUpgrade is a callback that will be called before sending a
+ // successful upgrade response.
+ //
+ // Setting OnBeforeUpgrade allows the user to make final application-level
+ // checks and decide whether this connection is allowed to successfully
+ // upgrade to WebSocket.
+ //
+ // It must return either a non-nil HandshakeHeader or an error, never both.
+ //
+ // If the returned error is non-nil then the connection is rejected and the
+ // response is sent with an appropriate HTTP error code and the body set to
+ // the error message.
+ //
+ // RejectConnectionError could be used to get more control over the response.
+ OnBeforeUpgrade func() (header HandshakeHeader, err error)
+}
+
+// Upgrade performs a zero-copy upgrade of the connection to WebSocket. It
+// interprets the given conn as a connection with an incoming HTTP Upgrade
+// request.
+//
+// It is the caller's responsibility to manage I/O timeouts on conn.
+// +// Non-nil error means that request for the WebSocket upgrade is invalid or +// malformed and usually connection should be closed. +// Even when error is non-nil Upgrade will write appropriate response into +// connection in compliance with RFC. +func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { + // headerSeen constants helps to report whether or not some header was seen + // during reading request bytes. + const ( + headerSeenHost = 1 << iota + headerSeenUpgrade + headerSeenConnection + headerSeenSecVersion + headerSeenSecKey + + // headerSeenAll is the value that we expect to receive at the end of + // headers read/parse loop. + headerSeenAll = 0 | + headerSeenHost | + headerSeenUpgrade | + headerSeenConnection | + headerSeenSecVersion | + headerSeenSecKey + ) + + // Prepare I/O buffers. + // TODO(gobwas): make it configurable. + br := pbufio.GetReader(conn, + nonZero(u.ReadBufferSize, DefaultServerReadBufferSize), + ) + bw := pbufio.GetWriter(conn, + nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize), + ) + defer func() { + pbufio.PutReader(br) + pbufio.PutWriter(bw) + }() + + // Read HTTP request line like "GET /ws HTTP/1.1". + rl, err := readLine(br) + if err != nil { + return + } + // Parse request line data like HTTP version, uri and method. + req, err := httpParseRequestLine(rl) + if err != nil { + return + } + + // Prepare stack-based handshake header list. + header := handshakeHeader{ + 0: u.Header, + } + + // Parse and check HTTP request. + // As RFC6455 says: + // The client's opening handshake consists of the following parts. If the + // server, while reading the handshake, finds that the client did not + // send a handshake that matches the description below (note that as per + // [RFC2616], the order of the header fields is not important), including + // but not limited to any violations of the ABNF grammar specified for + // the components of the handshake, the server MUST stop processing the + // client's handshake and return an HTTP response with an appropriate + // error code (such as 400 Bad Request). + // + // See https://tools.ietf.org/html/rfc6455#section-4.2.1 + + // An HTTP/1.1 or higher GET request, including a "Request-URI". + // + // Even if RFC says "1.1 or higher" without mentioning the part of the + // version, we apply it only to minor part. + switch { + case req.major != 1 || req.minor < 1: + // Abort processing the whole request because we do not even know how + // to actually parse it. + err = ErrHandshakeBadProtocol + + case btsToString(req.method) != http.MethodGet: + err = ErrHandshakeBadMethod + + default: + if onRequest := u.OnRequest; onRequest != nil { + err = onRequest(req.uri) + } + } + // Start headers read/parse loop. + var ( + // headerSeen reports which header was seen by setting corresponding + // bit on. + headerSeen byte + + nonce = make([]byte, nonceSize) + ) + for err == nil { + line, e := readLine(br) + if e != nil { + return hs, e + } + if len(line) == 0 { + // Blank line, no more lines to read. 
+ break + } + + k, v, ok := httpParseHeaderLine(line) + if !ok { + err = ErrMalformedRequest + break + } + + switch btsToString(k) { + case headerHostCanonical: + headerSeen |= headerSeenHost + if onHost := u.OnHost; onHost != nil { + err = onHost(v) + } + + case headerUpgradeCanonical: + headerSeen |= headerSeenUpgrade + if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { + err = ErrHandshakeBadUpgrade + } + + case headerConnectionCanonical: + headerSeen |= headerSeenConnection + if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) { + err = ErrHandshakeBadConnection + } + + case headerSecVersionCanonical: + headerSeen |= headerSeenSecVersion + if !bytes.Equal(v, specHeaderValueSecVersion) { + err = ErrHandshakeUpgradeRequired + } + + case headerSecKeyCanonical: + headerSeen |= headerSeenSecKey + if len(v) != nonceSize { + err = ErrHandshakeBadSecKey + } else { + copy(nonce[:], v) + } + + case headerSecProtocolCanonical: + if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) { + var ok bool + if custom != nil { + hs.Protocol, ok = custom(v) + } else { + hs.Protocol, ok = btsSelectProtocol(v, check) + } + if !ok { + err = ErrMalformedRequest + } + } + + case headerSecExtensionsCanonical: + if custom, check := u.ExtensionCustom, u.Extension; custom != nil || check != nil { + var ok bool + if custom != nil { + hs.Extensions, ok = custom(v, hs.Extensions) + } else { + hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check) + } + if !ok { + err = ErrMalformedRequest + } + } + + default: + if onHeader := u.OnHeader; onHeader != nil { + err = onHeader(k, v) + } + } + } + switch { + case err == nil && headerSeen != headerSeenAll: + switch { + case headerSeen&headerSeenHost == 0: + // As RFC2616 says: + // A client MUST include a Host header field in all HTTP/1.1 + // request messages. If the requested URI does not include an + // Internet host name for the service being requested, then the + // Host header field MUST be given with an empty value. An + // HTTP/1.1 proxy MUST ensure that any request message it + // forwards does contain an appropriate Host header field that + // identifies the service being requested by the proxy. All + // Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad + // Request) status code to any HTTP/1.1 request message which + // lacks a Host header field. + err = ErrHandshakeBadHost + case headerSeen&headerSeenUpgrade == 0: + err = ErrHandshakeBadUpgrade + case headerSeen&headerSeenConnection == 0: + err = ErrHandshakeBadConnection + case headerSeen&headerSeenSecVersion == 0: + // In case of empty or not present version we do not send 426 status, + // because it does not meet the ABNF rules of RFC6455: + // + // version = DIGIT | (NZDIGIT DIGIT) | + // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT) + // ; Limited to 0-255 range, with no leading zeros + // + // That is, if version is really invalid – we sent 426 status as above, if it + // not present – it is 400. 
+ err = ErrHandshakeBadSecVersion + case headerSeen&headerSeenSecKey == 0: + err = ErrHandshakeBadSecKey + default: + panic("unknown headers state") + } + + case err == nil && u.OnBeforeUpgrade != nil: + header[1], err = u.OnBeforeUpgrade() + } + if err != nil { + var code int + if rej, ok := err.(*rejectConnectionError); ok { + code = rej.code + header[1] = rej.header + } + if code == 0 { + code = http.StatusInternalServerError + } + httpWriteResponseError(bw, err, code, header.WriteTo) + // Do not store Flush() error to not override already existing one. + bw.Flush() + return + } + + httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo) + err = bw.Flush() + + return +} + +type handshakeHeader [2]HandshakeHeader + +func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) { + for i := 0; i < len(hs) && err == nil; i++ { + if h := hs[i]; h != nil { + var m int64 + m, err = h.WriteTo(w) + n += m + } + } + return n, err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/server_test.s b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/server_test.s new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/stub.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/stub.go deleted file mode 100644 index 0d00bc949fb6..000000000000 --- a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/stub.go +++ /dev/null @@ -1,54 +0,0 @@ -// Code generated by depstubber. DO NOT EDIT. -// This is a simple stub for github.com/gobwas/ws, strictly for use in testing. - -// See the LICENSE file for information about the licensing of the original library. -// Source: github.com/gobwas/ws (exports: Dialer; functions: Dial) - -// Package ws is a stub of github.com/gobwas/ws, generated by depstubber. -package ws - -import ( - bufio "bufio" - context "context" - tls "crypto/tls" - io "io" - net "net" - url "net/url" - time "time" -) - -func Dial(_ context.Context, _ string) (net.Conn, *bufio.Reader, Handshake, error) { - return nil, nil, Handshake{}, nil -} - -type Dialer struct { - ReadBufferSize int - WriteBufferSize int - Timeout time.Duration - Protocols []string - Extensions []interface{} - Header HandshakeHeader - OnStatusError func(int, []byte, io.Reader) - OnHeader func([]byte, []byte) error - NetDial func(context.Context, string, string) (net.Conn, error) - TLSClient func(net.Conn, string) net.Conn - TLSConfig *tls.Config - WrapConn func(net.Conn) net.Conn -} - -func (_ Dialer) Dial(_ context.Context, _ string) (net.Conn, *bufio.Reader, Handshake, error) { - return nil, nil, Handshake{}, nil -} - -func (_ Dialer) Upgrade(_ io.ReadWriter, _ *url.URL) (*bufio.Reader, Handshake, error) { - return nil, Handshake{}, nil -} - -type Handshake struct { - Protocol string - Extensions []interface{} -} - -type HandshakeHeader interface { - WriteTo(_ io.Writer) (int64, error) -} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/util.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/util.go new file mode 100644 index 000000000000..67ad906e5d25 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/util.go @@ -0,0 +1,214 @@ +package ws + +import ( + "bufio" + "bytes" + "fmt" + "reflect" + "unsafe" + + "github.com/gobwas/httphead" +) + +// SelectFromSlice creates accept function that could be used as Protocol/Extension +// select during upgrade. 
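+//
+// For example (illustrative):
+//
+//    u := HTTPUpgrader{
+//        Protocol: SelectFromSlice([]string{"chat", "echo"}),
+//    }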
+func SelectFromSlice(accept []string) func(string) bool { + if len(accept) > 16 { + mp := make(map[string]struct{}, len(accept)) + for _, p := range accept { + mp[p] = struct{}{} + } + return func(p string) bool { + _, ok := mp[p] + return ok + } + } + return func(p string) bool { + for _, ok := range accept { + if p == ok { + return true + } + } + return false + } +} + +// SelectEqual creates accept function that could be used as Protocol/Extension +// select during upgrade. +func SelectEqual(v string) func(string) bool { + return func(p string) bool { + return v == p + } +} + +func strToBytes(str string) (bts []byte) { + s := (*reflect.StringHeader)(unsafe.Pointer(&str)) + b := (*reflect.SliceHeader)(unsafe.Pointer(&bts)) + b.Data = s.Data + b.Len = s.Len + b.Cap = s.Len + return +} + +func btsToString(bts []byte) (str string) { + return *(*string)(unsafe.Pointer(&bts)) +} + +// asciiToInt converts bytes to int. +func asciiToInt(bts []byte) (ret int, err error) { + // ASCII numbers all start with the high-order bits 0011. + // If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those + // bits and interpret them directly as an integer. + var n int + if n = len(bts); n < 1 { + return 0, fmt.Errorf("converting empty bytes to int") + } + for i := 0; i < n; i++ { + if bts[i]&0xf0 != 0x30 { + return 0, fmt.Errorf("%s is not a numeric character", string(bts[i])) + } + ret += int(bts[i]&0xf) * pow(10, n-i-1) + } + return ret, nil +} + +// pow for integers implementation. +// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3 +func pow(a, b int) int { + p := 1 + for b > 0 { + if b&1 != 0 { + p *= a + } + b >>= 1 + a *= a + } + return p +} + +func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) { + a := bytes.IndexByte(bts, sep) + b := bytes.IndexByte(bts[a+1:], sep) + if a == -1 || b == -1 { + return bts, nil, nil + } + b += a + 1 + return bts[:a], bts[a+1 : b], bts[b+1:] +} + +func btrim(bts []byte) []byte { + var i, j int + for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); { + i++ + } + for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); { + j-- + } + return bts[i:j] +} + +func strHasToken(header, token string) (has bool) { + return btsHasToken(strToBytes(header), strToBytes(token)) +} + +func btsHasToken(header, token []byte) (has bool) { + httphead.ScanTokens(header, func(v []byte) bool { + has = bytes.EqualFold(v, token) + return !has + }) + return +} + +const ( + toLower = 'a' - 'A' // for use with OR. + toUpper = ^byte(toLower) // for use with AND. + toLower8 = uint64(toLower) | + uint64(toLower)<<8 | + uint64(toLower)<<16 | + uint64(toLower)<<24 | + uint64(toLower)<<32 | + uint64(toLower)<<40 | + uint64(toLower)<<48 | + uint64(toLower)<<56 +) + +// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except +// that it operates with slice of bytes and modifies it inplace without copying. +func canonicalizeHeaderKey(k []byte) { + upper := true + for i, c := range k { + if upper && 'a' <= c && c <= 'z' { + k[i] &= toUpper + } else if !upper && 'A' <= c && c <= 'Z' { + k[i] |= toLower + } + upper = c == '-' + } +} + +// readLine reads line from br. It reads until '\n' and returns bytes without +// '\n' or '\r\n' at the end. +// It returns err if and only if line does not end in '\n'. Note that read +// bytes returned in any case of error. +// +// It is much like the textproto/Reader.ReadLine() except the thing that it +// returns raw bytes, instead of string. 
That is, it avoids copying bytes read
+// from br.
+//
+// textproto/Reader.ReadLineBytes() also makes a copy of the resulting bytes
+// to be safe with future I/O operations on br.
+//
+// Since we control the I/O operations on br, we do not need to make an
+// additional copy for safety.
+//
+// NOTE: it may return a copied flag to notify that the returned buffer is
+// safe to use.
+func readLine(br *bufio.Reader) ([]byte, error) {
+ var line []byte
+ for {
+ bts, err := br.ReadSlice('\n')
+ if err == bufio.ErrBufferFull {
+ // Copy bytes because the next read will discard them.
+ line = append(line, bts...)
+ continue
+ }
+
+ // Avoid copying in the single-read case.
+ if line == nil {
+ line = bts
+ } else {
+ line = append(line, bts...)
+ }
+
+ if err != nil {
+ return line, err
+ }
+
+ // The line is at least 1 byte long; otherwise bufio.ReadSlice()
+ // would have returned an error.
+ n := len(line)
+
+ // Cut '\n' or '\r\n'.
+ if n > 1 && line[n-2] == '\r' {
+ line = line[:n-2]
+ } else {
+ line = line[:n-1]
+ }
+
+ return line, nil
+ }
+}
+
+func min(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}
+
+func nonZero(a, b int) int {
+ if a != 0 {
+ return a
+ }
+ return b
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/write.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/write.go
new file mode 100644
index 000000000000..94557c696394
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gobwas/ws/write.go
@@ -0,0 +1,104 @@
+package ws
+
+import (
+ "encoding/binary"
+ "io"
+)
+
+// Header size bounds in bytes.
+const (
+ MaxHeaderSize = 14
+ MinHeaderSize = 2
+)
+
+const (
+ bit0 = 0x80
+ bit1 = 0x40
+ bit2 = 0x20
+ bit3 = 0x10
+ bit4 = 0x08
+ bit5 = 0x04
+ bit6 = 0x02
+ bit7 = 0x01
+
+ len7 = int64(125)
+ len16 = int64(^(uint16(0)))
+ len64 = int64(^(uint64(0)) >> 1)
+)
+
+// HeaderSize returns the number of bytes that are needed to encode the given
+// header. It returns -1 if the header is malformed.
+func HeaderSize(h Header) (n int) {
+ switch {
+ case h.Length < 126:
+ n = 2
+ case h.Length <= len16:
+ n = 4
+ case h.Length <= len64:
+ n = 10
+ default:
+ return -1
+ }
+ if h.Masked {
+ n += len(h.Mask)
+ }
+ return n
+}
+
+// WriteHeader writes the binary representation of the header into w.
+func WriteHeader(w io.Writer, h Header) error {
+ // Make a slice of bytes of length 14 that can hold any header.
+ bts := make([]byte, MaxHeaderSize)
+
+ if h.Fin {
+ bts[0] |= bit0
+ }
+ bts[0] |= h.Rsv << 4
+ bts[0] |= byte(h.OpCode)
+
+ var n int
+ switch {
+ case h.Length <= len7:
+ bts[1] = byte(h.Length)
+ n = 2
+
+ case h.Length <= len16:
+ bts[1] = 126
+ binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
+ n = 4
+
+ case h.Length <= len64:
+ bts[1] = 127
+ binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
+ n = 10
+
+ default:
+ return ErrHeaderLengthUnexpected
+ }
+
+ if h.Masked {
+ bts[1] |= bit0
+ n += copy(bts[n:], h.Mask[:])
+ }
+
+ _, err := w.Write(bts[:n])
+
+ return err
+}
+
+// WriteFrame writes the binary representation of the frame into w.
+func WriteFrame(w io.Writer, f Frame) error {
+ err := WriteHeader(w, f.Header)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(f.Payload)
+ return err
+}
+
+// MustWriteFrame is like WriteFrame but panics if the frame can not be written.
+func MustWriteFrame(w io.Writer, f Frame) { + if err := WriteFrame(w, f); err != nil { + panic(err) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/.gitignore b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/.gitignore new file mode 100644 index 000000000000..cd3fcd1ef72a --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/.gitignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +.idea/ +*.iml diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/AUTHORS b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/AUTHORS new file mode 100644 index 000000000000..1931f400682c --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/AUTHORS @@ -0,0 +1,9 @@ +# This is the official list of Gorilla WebSocket authors for copyright +# purposes. +# +# Please keep the list sorted. + +Gary Burd +Google LLC (https://opensource.google.com/) +Joachim Bauch + diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/README.md b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/README.md new file mode 100644 index 000000000000..19aa2e75c824 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/README.md @@ -0,0 +1,64 @@ +# Gorilla WebSocket + +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) + +Gorilla WebSocket is a [Go](http://golang.org/) implementation of the +[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + +### Documentation + +* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) +* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) +* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) + +### Status + +The Gorilla WebSocket package provides a complete and tested implementation of +the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The +package API is stable. + +### Installation + + go get github.com/gorilla/websocket + +### Protocol Compliance + +The Gorilla WebSocket package passes the server tests in the [Autobahn Test +Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn +subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). + +### Gorilla WebSocket compared with other packages + + + + + + + + + + + + + + + + + + +
+|                                         | github.com/gorilla | golang.org/x/net |
+|-----------------------------------------|--------------------|------------------|
+| **RFC 6455 Features**                   |                    |                  |
+| Passes Autobahn Test Suite              | Yes                | No               |
+| Receive fragmented message              | Yes                | No, see note 1   |
+| Send close message                      | Yes                | No               |
+| Send pings and receive pongs            | Yes                | No               |
+| Get the type of a received data message | Yes                | Yes, see note 2  |
+| **Other Features**                      |                    |                  |
+| Compression Extensions                  | Experimental       | No               |
+| Read message using io.Reader            | Yes                | No, see note 3   |
+| Write message using io.WriteCloser      | Yes                | No, see note 3   |
+ +Notes: + +1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). +2. The application can get the type of a received data message by implementing + a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal) + function. +3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries. + Read returns when the input buffer is full or a frame boundary is + encountered. Each call to Write sends a single frame message. The Gorilla + io.Reader and io.WriteCloser operate on a single WebSocket message. + diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client.go new file mode 100644 index 000000000000..962c06a391c2 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client.go @@ -0,0 +1,395 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "strings" + "time" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + +// NewClient creates a new client connection using the given net connection. +// The URL u specifies the host and request URI. Use requestHeader to specify +// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies +// (Cookie). Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etc. +// +// Deprecated: Use Dialer instead. +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, + } + return d.Dial(u.String(), requestHeader) +} + +// A Dialer contains options for connecting to WebSocket server. +type Dialer struct { + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dial is used. + NetDial func(network, addr string) (net.Conn, error) + + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, net.DialContext is used. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // TLSClientConfig specifies the TLS configuration to use with tls.Client. + // If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // HandshakeTimeout specifies the duration for the handshake to complete. 
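+ // A zero value means no timeout is applied; see the
+ // d.HandshakeTimeout check in DialContext below.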
+ HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then a useful default size is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the client's requested subprotocols. + Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar http.CookieJar +} + +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(context.Background(), urlStr, requestHeader) +} + +var errMalformedURL = errors.New("malformed ws or wss URL") + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +// DefaultDialer is a dialer with all fields set to the default values. +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, +} + +// nilDialer is dialer to use when receiver is nil. +var nilDialer = *DefaultDialer + +// DialContext creates a new client connection. Use requestHeader to specify the +// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). +// Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// The context will be used in the request and in the Dialer. +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + if d == nil { + d = &nilDialer + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + + u, err := url.Parse(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. 
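+ // (RFC 6455 section 3 defines ws/wss URIs without a userinfo
+ // component, so a URL like ws://user:pass@example.com is rejected
+ // outright.)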
+ return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req = req.WithContext(ctx) + + // Set the cookies present in the cookie jar of the dialer + if d.Jar != nil { + for _, cookie := range d.Jar.Cookies(u) { + req.AddCookie(cookie) + } + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs + default: + req.Header[k] = vs + } + } + + if d.EnableCompression { + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} + } + + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } + + // Get network dial function. + var netDial func(network, add string) (net.Conn, error) + + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } else { + netDialer := &net.Dialer{} + netDial = func(network, addr string) (net.Conn, error) { + return netDialer.DialContext(ctx, network, addr) + } + } + + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + forwardDial := netDial + netDial = func(network, addr string) (net.Conn, error) { + c, err := forwardDial(network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } + } + + // If needed, wrap the dial function to connect through a proxy. 
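+ // (netDialerFunc and proxy_FromURL are this package's vendored
+ // copies of helpers from golang.org/x/net/proxy.)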
+ if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + if err != nil { + return nil, nil, err + } + netDial = dialer.Dial + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + + netConn, err := netDial("tcp", hostPort) + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + if err != nil { + return nil, nil, err + } + + defer func() { + if netConn != nil { + netConn.Close() + } + }() + + if u.Scheme == "https" { + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + + var err error + if trace != nil { + err = doHandshakeWithTrace(trace, tlsConn, cfg) + } else { + err = doHandshake(tlsConn, cfg) + } + + if err != nil { + return nil, nil, err + } + } + + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + + resp, err := http.ReadResponse(conn.br, req) + if err != nil { + return nil, nil, err + } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } + + for _, ext := range parseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + + netConn.SetDeadline(time.Time{}) + netConn = nil // to avoid close in defer. + return conn, resp, nil +} + +func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client_clone.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client_clone.go new file mode 100644 index 000000000000..4f0d943723a9 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client_clone.go @@ -0,0 +1,16 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package websocket + +import "crypto/tls" + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client_clone_legacy.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client_clone_legacy.go new file mode 100644 index 000000000000..babb007fb414 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/client_clone_legacy.go @@ -0,0 +1,38 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +import "crypto/tls" + +// cloneTLSConfig clones all public fields except the fields +// SessionTicketsDisabled and SessionTicketKey. This avoids copying the +// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a +// config in active use. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/compression.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/compression.go new file mode 100644 index 000000000000..813ffb1e8433 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/compression.go @@ -0,0 +1,148 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" + "sync" +) + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. 
+ "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} +} + +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn.go new file mode 100644 index 000000000000..ca46d2f793c2 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn.go @@ -0,0 +1,1201 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "sync" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents a close message. +type CloseError struct { + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) 
+ case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = &netError{msg: e.Error(), timeout: e.Timeout()} + } + return err +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + +// The Conn type represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. 
+ writePool BufferPool
+ writeBufSize int
+ writeDeadline time.Time
+ writer io.WriteCloser // the current writer returned to the application
+ isWriting bool // for best-effort concurrent write detection
+
+ writeErrMu sync.Mutex
+ writeErr error
+
+ enableWriteCompression bool
+ compressionLevel int
+ newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
+
+ // Read fields
+ reader io.ReadCloser // the current reader returned to the application
+ readErr error
+ br *bufio.Reader
+ // bytes remaining in current frame.
+ // use setReadRemaining to safely update this value and prevent overflow
+ readRemaining int64
+ readFinal bool // true if there are no more frames in the current message.
+ readLength int64 // Message size.
+ readLimit int64 // Maximum message size.
+ readMaskPos int
+ readMaskKey [4]byte
+ handlePong func(string) error
+ handlePing func(string) error
+ handleClose func(int, string) error
+ readErrCount int
+ messageReader *messageReader // the current low-level reader
+
+ readDecompress bool // whether last read frame had RSV1 set
+ newDecompressionReader func(io.Reader) io.ReadCloser
+}
+
+func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
+
+ if br == nil {
+ if readBufferSize == 0 {
+ readBufferSize = defaultReadBufferSize
+ } else if readBufferSize < maxControlFramePayloadSize {
+ // must be large enough for control frame
+ readBufferSize = maxControlFramePayloadSize
+ }
+ br = bufio.NewReaderSize(conn, readBufferSize)
+ }
+
+ if writeBufferSize <= 0 {
+ writeBufferSize = defaultWriteBufferSize
+ }
+ writeBufferSize += maxFrameHeaderSize
+
+ if writeBuf == nil && writeBufferPool == nil {
+ writeBuf = make([]byte, writeBufferSize)
+ }
+
+ mu := make(chan struct{}, 1)
+ mu <- struct{}{}
+ c := &Conn{
+ isServer: isServer,
+ br: br,
+ conn: conn,
+ mu: mu,
+ readFinal: true,
+ writeBuf: writeBuf,
+ writePool: writeBufferPool,
+ writeBufSize: writeBufferSize,
+ enableWriteCompression: true,
+ compressionLevel: defaultCompressionLevel,
+ }
+ c.SetCloseHandler(nil)
+ c.SetPingHandler(nil)
+ c.SetPongHandler(nil)
+ return c
+}
+
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
+// overflows, ErrReadLimit is returned.
+func (c *Conn) setReadRemaining(n int64) error {
+ if n < 0 {
+ return ErrReadLimit
+ }
+
+ c.readRemaining = n
+ return nil
+}
+
+// Subprotocol returns the negotiated protocol for the connection.
+func (c *Conn) Subprotocol() string {
+ return c.subprotocol
+}
+
+// Close closes the underlying network connection without sending or waiting
+// for a close message.
+func (c *Conn) Close() error {
+ return c.conn.Close()
+}
+
+// LocalAddr returns the local network address.
+func (c *Conn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+// RemoteAddr returns the remote network address. 
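newConn above initializes mu as a one-slot buffered channel holding a single token rather than a sync.Mutex: receiving the token locks, sending it back unlocks, and, unlike a mutex, the acquisition can sit in a select next to a timer, which is exactly what WriteControl does further down. A minimal sketch of the pattern:

```go
package main

import (
	"errors"
	"time"
)

// tryLockTimeout acquires a channel-based lock like Conn.mu, but gives up
// after d. A plain sync.Mutex cannot participate in a select this way.
func tryLockTimeout(mu chan struct{}, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-mu: // took the token: locked
		return nil
	case <-timer.C:
		return errors.New("lock timeout")
	}
}

func main() {
	mu := make(chan struct{}, 1)
	mu <- struct{}{} // the lock starts out free
	if err := tryLockTimeout(mu, time.Second); err == nil {
		defer func() { mu <- struct{}{} }() // unlock
		// ... critical section ...
	}
}
```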
+func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { + <-c.mu + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) + } + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return nil +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := 1000 * time.Hour + if !deadline.IsZero() { + d = deadline.Sub(time.Now()) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return err +} + +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. 
+// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return nil, err + } + c.writer = &mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) endMessage(err error) error { + if w.err != nil { + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.endMessage(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.endMessage(err) + } + + if final { + w.endMessage(errWriteClosed) + return nil + } + + // Setup for next frame. 
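Aside from the per-frame reset that follows, the heart of flushFrame is the RFC 6455 §5.2 length encoding: payloads up to 125 bytes fit in the second header byte, larger ones escape to 126 plus a uint16 or 127 plus a uint64, and the header is built backwards from a fixed-size reservation so the payload never moves. The same encoding forwards, as a self-contained sketch:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// encodeHeader builds a FIN frame header for an unmasked (server-to-client)
// frame; client frames would additionally set the mask bit and append a key.
func encodeHeader(opcode byte, n int) []byte {
	b0 := opcode | 0x80 // FIN
	switch {
	case n >= 65536:
		h := make([]byte, 10)
		h[0], h[1] = b0, 127
		binary.BigEndian.PutUint64(h[2:], uint64(n))
		return h
	case n > 125:
		h := make([]byte, 4)
		h[0], h[1] = b0, 126
		binary.BigEndian.PutUint16(h[2:], uint16(n))
		return h
	default:
		return []byte{b0, byte(n)}
	}
}

func main() {
	fmt.Printf("% x\n", encodeHeader(2, 100))   // 82 64
	fmt.Printf("% x\n", encodeHeader(2, 300))   // 82 7e 01 2c
	fmt.Printf("% x\n", encodeHeader(2, 70000)) // 82 7f 00 00 00 00 00 01 11 70
}
```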
+ w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + return w.flushFrame(true, nil) +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return err + } + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + // 1. Skip remainder of previous frame. 
+ + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + mask := p[1]&maskBit != 0 + c.setReadRemaining(int64(p[1] & 0x7f)) + + c.readDecompress = false + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } + + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + return noFrame, c.handleProtocolError("control frame length > 125") + } + if !final { + return noFrame, c.handleProtocolError("control frame not final") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + return noFrame, c.handleProtocolError("message start before final message frame") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + return noFrame, c.handleProtocolError("continuation after final message frame") + } + c.readFinal = final + default: + return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } + } + + // 4. Handle frame masking. + + if mask != c.isServer { + return noFrame, c.handleProtocolError("incorrect mask flag") + } + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + c.setReadRemaining(0) + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. 
+
+ switch frameType {
+ case PongMessage:
+ if err := c.handlePong(string(payload)); err != nil {
+ return noFrame, err
+ }
+ case PingMessage:
+ if err := c.handlePing(string(payload)); err != nil {
+ return noFrame, err
+ }
+ case CloseMessage:
+ closeCode := CloseNoStatusReceived
+ closeText := ""
+ if len(payload) >= 2 {
+ closeCode = int(binary.BigEndian.Uint16(payload))
+ if !isValidReceivedCloseCode(closeCode) {
+ return noFrame, c.handleProtocolError("invalid close code")
+ }
+ closeText = string(payload[2:])
+ if !utf8.ValidString(closeText) {
+ return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
+ }
+ }
+ if err := c.handleClose(closeCode, closeText); err != nil {
+ return noFrame, err
+ }
+ return noFrame, &CloseError{Code: closeCode, Text: closeText}
+ }
+
+ return frameType, nil
+}
+
+func (c *Conn) handleProtocolError(message string) error {
+ c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
+ return errors.New("websocket: " + message)
+}
+
+// NextReader returns the next data message received from the peer. The
+// returned messageType is either TextMessage or BinaryMessage.
+//
+// There can be at most one open reader on a connection. NextReader discards
+// the previous message if the application has not already consumed it.
+//
+// Applications must break out of the application's read loop when this method
+// returns a non-nil error value. Errors returned from this method are
+// permanent. Once this method returns a non-nil error, all subsequent calls to
+// this method return the same error.
+func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
+ // Close previous reader, only relevant for decompression.
+ if c.reader != nil {
+ c.reader.Close()
+ c.reader = nil
+ }
+
+ c.messageReader = nil
+ c.readLength = 0
+
+ for c.readErr == nil {
+ frameType, err := c.advanceFrame()
+ if err != nil {
+ c.readErr = hideTempErr(err)
+ break
+ }
+
+ if frameType == TextMessage || frameType == BinaryMessage {
+ c.messageReader = &messageReader{c}
+ c.reader = c.messageReader
+ if c.readDecompress {
+ c.reader = c.newDecompressionReader(c.reader)
+ }
+ return frameType, c.reader, nil
+ }
+ }
+
+ // Applications that do not handle the error returned from this method can
+ // spin in a tight loop on connection failure. To help application developers
+ // detect this error, panic on repeated reads to the failed connection. 
+ c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.br.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + rem := c.readRemaining + rem -= int64(n) + c.setReadRemaining(rem) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = hideTempErr(err) + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close +// message back to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. +// +// The connection read methods return a CloseError when a close message is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close message back to +// the peer. 
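Taken together, the read methods above give the usual consumption pattern: loop on ReadMessage, treat any error as permanent, and use the IsUnexpectedCloseError helper from earlier in this file to decide what deserves a log line. A sketch, where process is a placeholder supplied by the application:

```go
package example

import (
	"log"

	"github.com/gorilla/websocket"
)

// readLoop drains conn until the first (permanent) read error, logging only
// closes the application did not anticipate.
func readLoop(conn *websocket.Conn, process func([]byte)) {
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err,
				websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				log.Printf("read: %v", err)
			}
			return
		}
		process(msg)
	}
}
```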
+func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
+ if h == nil {
+ h = func(code int, text string) error {
+ message := FormatCloseMessage(code, "")
+ c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
+ return nil
+ }
+ }
+ c.handleClose = h
+}
+
+// PingHandler returns the current ping handler.
+func (c *Conn) PingHandler() func(appData string) error {
+ return c.handlePing
+}
+
+// SetPingHandler sets the handler for ping messages received from the peer.
+// The appData argument to h is the PING message application data. The default
+// ping handler sends a pong to the peer.
+//
+// The handler function is called from the NextReader, ReadMessage and message
+// reader Read methods. The application must read the connection to process
+// ping messages as described in the section on Control Messages above.
+func (c *Conn) SetPingHandler(h func(appData string) error) {
+ if h == nil {
+ h = func(message string) error {
+ err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
+ if err == ErrCloseSent {
+ return nil
+ } else if e, ok := err.(net.Error); ok && e.Temporary() {
+ return nil
+ }
+ return err
+ }
+ }
+ c.handlePing = h
+}
+
+// PongHandler returns the current pong handler.
+func (c *Conn) PongHandler() func(appData string) error {
+ return c.handlePong
+}
+
+// SetPongHandler sets the handler for pong messages received from the peer.
+// The appData argument to h is the PONG message application data. The default
+// pong handler does nothing.
+//
+// The handler function is called from the NextReader, ReadMessage and message
+// reader Read methods. The application must read the connection to process
+// pong messages as described in the section on Control Messages above.
+func (c *Conn) SetPongHandler(h func(appData string) error) {
+ if h == nil {
+ h = func(string) error { return nil }
+ }
+ c.handlePong = h
+}
+
+// UnderlyingConn returns the internal net.Conn. This can be used to make
+// further modifications to connection-specific flags.
+func (c *Conn) UnderlyingConn() net.Conn {
+ return c.conn
+}
+
+// EnableWriteCompression enables and disables write compression of
+// subsequent text and binary messages. This function is a noop if
+// compression was not negotiated with the peer.
+func (c *Conn) EnableWriteCompression(enable bool) {
+ c.enableWriteCompression = enable
+}
+
+// SetCompressionLevel sets the flate compression level for subsequent text and
+// binary messages. This function is a noop if compression was not negotiated
+// with the peer. See the compress/flate package for a description of
+// compression levels.
+func (c *Conn) SetCompressionLevel(level int) error {
+ if !isValidCompressionLevel(level) {
+ return errors.New("websocket: invalid compression level")
+ }
+ c.compressionLevel = level
+ return nil
+}
+
+// FormatCloseMessage formats closeCode and text as a WebSocket close message.
+// An empty message is returned for code CloseNoStatusReceived.
+func FormatCloseMessage(closeCode int, text string) []byte {
+ if closeCode == CloseNoStatusReceived {
+ // Return empty message because it's illegal to send
+ // CloseNoStatusReceived. Return non-nil value in case application
+ // checks for nil. 
+ return []byte{} + } + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn_write.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn_write.go new file mode 100644 index 000000000000..a509a21f87af --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn_write.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package websocket + +import "net" + +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn_write_legacy.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn_write_legacy.go new file mode 100644 index 000000000000..37edaff5a578 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/conn_write_legacy.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +func (c *Conn) writeBufs(bufs ...[]byte) error { + for _, buf := range bufs { + if len(buf) > 0 { + if _, err := c.conn.Write(buf); err != nil { + return err + } + } + } + return nil +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/doc.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/doc.go new file mode 100644 index 000000000000..8db0cef95a29 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/doc.go @@ -0,0 +1,227 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements the WebSocket protocol defined in RFC 6455. +// +// Overview +// +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: +// +// var upgrader = websocket.Upgrader{ +// ReadBufferSize: 1024, +// WriteBufferSize: 1024, +// } +// +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := upgrader.Upgrade(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } +// +// Call the connection's WriteMessage and ReadMessage methods to send and +// receive messages as a slice of bytes. This snippet of code shows how to echo +// messages using these methods: +// +// for { +// messageType, p, err := conn.ReadMessage() +// if err != nil { +// log.Println(err) +// return +// } +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return +// } +// } +// +// In above snippet of code, p is a []byte and messageType is an int with value +// websocket.BinaryMessage or websocket.TextMessage. +// +// An application can also send and receive messages using the io.WriteCloser +// and io.Reader interfaces. 
To send a message, call the connection NextWriter +// method to get an io.WriteCloser, write the message to the writer and close +// the writer when done. To receive a message, call the connection NextReader +// method to get an io.Reader and read until io.EOF is returned. This snippet +// shows how to echo messages using the NextWriter and NextReader methods: +// +// for { +// messageType, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(messageType) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// Data Messages +// +// The WebSocket protocol distinguishes between text and binary data messages. +// Text messages are interpreted as UTF-8 encoded text. The interpretation of +// binary messages is left to the application. +// +// This package uses the TextMessage and BinaryMessage integer constants to +// identify the two data message types. The ReadMessage and NextReader methods +// return the type of the received message. The messageType argument to the +// WriteMessage and NextWriter methods specifies the type of a sent message. +// +// It is the application's responsibility to ensure that text messages are +// valid UTF-8 encoded text. +// +// Control Messages +// +// The WebSocket protocol defines three types of control messages: close, ping +// and pong. Call the connection WriteControl, WriteMessage or NextWriter +// methods to send a control message to the peer. +// +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. +// +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. +// +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. +// +// The application must read the connection to process close, ping and pong +// messages sent from the peer. If the application is not otherwise interested +// in messages from the peer, then the application should start a goroutine to +// read and discard messages from the peer. A simple example is: +// +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } +// +// Concurrency +// +// Connections support one concurrent reader and one concurrent writer. +// +// Applications are responsible for ensuring that no more than one goroutine +// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and +// that no more than one goroutine calls the read methods (NextReader, +// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// concurrently. 
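The control-message advice above (keep reading the connection so the ping, pong and close handlers actually run) is usually paired with a keepalive: the pong handler extends the read deadline, and a ticker goroutine sends the pings that elicit those pongs. A common sketch; the three durations are illustrative, not library defaults:

```go
package example

import (
	"time"

	"github.com/gorilla/websocket"
)

const (
	pongWait   = 60 * time.Second    // assumed read deadline; tune per application
	pingPeriod = (pongWait * 9) / 10 // ping a little more often than the wait
	writeWait  = 10 * time.Second
)

// keepAlive arms the pong handler and starts the ping ticker. The caller
// must still run a read loop so control messages are processed.
func keepAlive(c *websocket.Conn) {
	c.SetReadDeadline(time.Now().Add(pongWait))
	c.SetPongHandler(func(string) error {
		return c.SetReadDeadline(time.Now().Add(pongWait))
	})
	go func() {
		t := time.NewTicker(pingPeriod)
		defer t.Stop()
		for range t.C {
			if err := c.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
				return
			}
		}
	}()
}
```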
+//
+// The Close and WriteControl methods can be called concurrently with all other
+// methods.
+//
+// Origin Considerations
+//
+// Web browsers allow JavaScript applications to open a WebSocket connection to
+// any host. It's up to the server to enforce an origin policy using the Origin
+// request header sent by the browser.
+//
+// The Upgrader calls the function specified in the CheckOrigin field to check
+// the origin. If the CheckOrigin function returns false, then the Upgrade
+// method fails the WebSocket handshake with HTTP status 403.
+//
+// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
+// the handshake if the Origin request header is present and the Origin host is
+// not equal to the Host request header.
+//
+// The deprecated package-level Upgrade function does not perform origin
+// checking. The application is responsible for checking the Origin header
+// before calling the Upgrade function.
+//
+// Buffers
+//
+// Connections buffer network input and output to reduce the number
+// of system calls when reading or writing messages.
+//
+// Write buffers are also used for constructing WebSocket frames. See RFC 6455,
+// Section 5 for a discussion of message framing. A WebSocket frame header is
+// written to the network each time a write buffer is flushed to the network.
+// Decreasing the size of the write buffer can increase the amount of framing
+// overhead on the connection.
+//
+// The buffer sizes in bytes are specified by the ReadBufferSize and
+// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default
+// size of 4096 when a buffer size field is set to zero. The Upgrader reuses
+// buffers created by the HTTP server when a buffer size field is set to zero.
+// The HTTP server buffers have a size of 4096 at the time of this writing.
+//
+// The buffer sizes do not limit the size of a message that can be read or
+// written by a connection.
+//
+// Buffers are held for the lifetime of the connection by default. If the
+// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the
+// write buffer only when writing a message.
+//
+// Applications should tune the buffer sizes to balance memory use and
+// performance. Increasing the buffer size uses more memory, but can reduce the
+// number of system calls to read or write the network. In the case of writing,
+// increasing the buffer size can reduce the number of frame headers written to
+// the network.
+//
+// Some guidelines for setting buffer parameters are:
+//
+// Limit the buffer sizes to the maximum expected message size. Buffers larger
+// than the largest message do not provide any benefit.
+//
+// Depending on the distribution of message sizes, setting the buffer size to
+// a value less than the maximum expected message size can greatly reduce memory
+// use with a small impact on performance. Here's an example: If 99% of the
+// messages are smaller than 256 bytes and the maximum message size is 512
+// bytes, then a buffer size of 256 bytes will result in 1.01 times as many
+// system calls as a buffer size of 512 bytes. The memory savings is 50%.
+//
+// A write buffer pool is useful when the application has a modest number of
+// writes over a large number of connections. When buffers are pooled, a larger
+// buffer size has a reduced impact on total memory use and has the benefit of
+// reducing system calls and frame overhead. 
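The pooling guidance above translates directly into configuration; the BufferPool doc earlier in the diff notes that *sync.Pool satisfies the interface, and one pool should be shared per WriteBufferSize value. For example:

```go
package example

import (
	"net/http"
	"sync"

	"github.com/gorilla/websocket"
)

// One pool shared by all connections with the same WriteBufferSize, as the
// guidance above recommends for many-connection, few-write workloads.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	WriteBufferPool: &sync.Pool{},
}

func handler(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return
	}
	defer conn.Close()
	// ... use conn ...
}
```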
+// +// Compression EXPERIMENTAL +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. +// +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// } +// +// If compression was successfully negotiated with the connection's peer, any +// message received in compressed form will be automatically decompressed. +// All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(false) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. +package websocket diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/go.mod b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/go.mod new file mode 100644 index 000000000000..1a7afd5028a7 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/go.mod @@ -0,0 +1,3 @@ +module github.com/gorilla/websocket + +go 1.12 diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/join.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/join.go new file mode 100644 index 000000000000..c64f8c82901a --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/join.go @@ -0,0 +1,42 @@ +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "io" + "strings" +) + +// JoinMessages concatenates received messages to create a single io.Reader. +// The string term is appended to each message. The returned reader does not +// support concurrent calls to the Read method. +func JoinMessages(c *Conn, term string) io.Reader { + return &joinReader{c: c, term: term} +} + +type joinReader struct { + c *Conn + term string + r io.Reader +} + +func (r *joinReader) Read(p []byte) (int, error) { + if r.r == nil { + var err error + _, r.r, err = r.c.NextReader() + if err != nil { + return 0, err + } + if r.term != "" { + r.r = io.MultiReader(r.r, strings.NewReader(r.term)) + } + } + n, err := r.r.Read(p) + if err == io.EOF { + err = nil + r.r = nil + } + return n, err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/json.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/json.go new file mode 100644 index 000000000000..dc2c1f6415ff --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/json.go @@ -0,0 +1,60 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "encoding/json" + "io" +) + +// WriteJSON writes the JSON encoding of v as a message. +// +// Deprecated: Use c.WriteJSON instead. 
+func WriteJSON(c *Conn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v as a message. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *Conn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// Deprecated: Use c.ReadJSON instead. +func ReadJSON(c *Conn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *Conn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/mask.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/mask.go new file mode 100644 index 000000000000..577fce9efd72 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/mask.go @@ -0,0 +1,54 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// +build !appengine + +package websocket + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/mask_safe.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/mask_safe.go new file mode 100644 index 000000000000..2aac060e52e7 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/mask_safe.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. 
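mask.go above unrolls the XOR to word width for speed; the appengine build that follows keeps the portable per-byte loop. Either way, masking is an involution: applying the same key twice restores the input, which the worked example from RFC 6455 §5.7 confirms:

```go
package main

import (
	"bytes"
	"fmt"
)

// maskBytes is the portable per-byte form of the masking loop; the vendored
// mask.go performs the same XOR a machine word at a time.
func maskBytes(key [4]byte, pos int, b []byte) int {
	for i := range b {
		b[i] ^= key[pos&3]
		pos++
	}
	return pos & 3
}

func main() {
	key := [4]byte{0x37, 0xfa, 0x21, 0x3d} // key from the RFC 6455 §5.7 example
	msg := []byte("Hello")
	maskBytes(key, 0, msg)
	fmt.Printf("% x\n", msg) // 7f 9f 4d 51 58, as in the RFC
	maskBytes(key, 0, msg)
	fmt.Println(bytes.Equal(msg, []byte("Hello"))) // true
}
```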
+ +// +build appengine + +package websocket + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/prepared.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/prepared.go new file mode 100644 index 000000000000..c854225e9676 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/prepared.go @@ -0,0 +1,102 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. 
+ mu := make(chan struct{}, 1)
+ mu <- struct{}{}
+ var nc prepareConn
+ c := &Conn{
+ conn: &nc,
+ mu: mu,
+ isServer: key.isServer,
+ compressionLevel: key.compressionLevel,
+ enableWriteCompression: true,
+ writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
+ }
+ if key.compress {
+ c.newCompressionWriter = compressNoContextTakeover
+ }
+ err = c.WriteMessage(pm.messageType, pm.data)
+ frame.data = nc.buf.Bytes()
+ })
+ return pm.messageType, frame.data, err
+}
+
+type prepareConn struct {
+ buf bytes.Buffer
+ net.Conn
+}
+
+func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
+func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/proxy.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/proxy.go
new file mode 100644
index 000000000000..e87a8c9f0c96
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/proxy.go
@@ -0,0 +1,77 @@
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+ "bufio"
+ "encoding/base64"
+ "errors"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+type netDialerFunc func(network, addr string) (net.Conn, error)
+
+func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
+ return fn(network, addr)
+}
+
+func init() {
+ proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
+ return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
+ })
+}
+
+type httpProxyDialer struct {
+ proxyURL *url.URL
+ forwardDial func(network, addr string) (net.Conn, error)
+}
+
+func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
+ hostPort, _ := hostPortNoPort(hpd.proxyURL)
+ conn, err := hpd.forwardDial(network, hostPort)
+ if err != nil {
+ return nil, err
+ }
+
+ connectHeader := make(http.Header)
+ if user := hpd.proxyURL.User; user != nil {
+ proxyUser := user.Username()
+ if proxyPassword, passwordSet := user.Password(); passwordSet {
+ credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
+ connectHeader.Set("Proxy-Authorization", "Basic "+credential)
+ }
+ }
+
+ connectReq := &http.Request{
+ Method: "CONNECT",
+ URL: &url.URL{Opaque: addr},
+ Host: addr,
+ Header: connectHeader,
+ }
+
+ if err := connectReq.Write(conn); err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ // Read response. It's OK to use and discard buffered reader here because
+ // the remote server does not speak until spoken to.
+ br := bufio.NewReader(conn)
+ resp, err := http.ReadResponse(br, connectReq)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ if resp.StatusCode != 200 {
+ conn.Close()
+ f := strings.SplitN(resp.Status, " ", 2)
+ return nil, errors.New(f[1])
+ }
+ return conn, nil
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/server.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/server.go
new file mode 100644
index 000000000000..887d558918c7
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/server.go
@@ -0,0 +1,363 @@
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
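The PreparedMessage machinery above exists for fan-out: the frame bytes (including permessage-deflate compression, when negotiated) are computed once per unique set of connection options rather than once per write. A sketch of the intended broadcast pattern against the package's public API; the `broadcast` helper and its `conns` argument are illustrative, not part of the library:

```go
package main

import (
	"log"

	"github.com/gorilla/websocket"
)

// broadcast sends one payload to every connection. The wire frame is
// prepared once per unique set of connection options instead of once
// per WriteMessage call.
func broadcast(conns []*websocket.Conn, payload []byte) {
	pm, err := websocket.NewPreparedMessage(websocket.TextMessage, payload)
	if err != nil {
		log.Println("prepare:", err)
		return
	}
	for _, c := range conns {
		if err := c.WritePreparedMessage(pm); err != nil {
			log.Println("write:", err)
		}
	}
}

func main() {
	broadcast(nil, []byte(`{"event":"tick"}`)) // no-op without live connections
}
```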
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+// HandshakeError describes an error with the handshake from the peer.
+type HandshakeError struct {
+ message string
+}
+
+func (e HandshakeError) Error() string { return e.message }
+
+// Upgrader specifies parameters for upgrading an HTTP connection to a
+// WebSocket connection.
+type Upgrader struct {
+ // HandshakeTimeout specifies the duration for the handshake to complete.
+ HandshakeTimeout time.Duration
+
+ // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
+ // size is zero, then buffers allocated by the HTTP server are used. The
+ // I/O buffer sizes do not limit the size of the messages that can be sent
+ // or received.
+ ReadBufferSize, WriteBufferSize int
+
+ // WriteBufferPool is a pool of buffers for write operations. If the value
+ // is not set, then write buffers are allocated to the connection for the
+ // lifetime of the connection.
+ //
+ // A pool is most useful when the application has a modest volume of writes
+ // across a large number of connections.
+ //
+ // Applications should use a single pool for each unique value of
+ // WriteBufferSize.
+ WriteBufferPool BufferPool
+
+ // Subprotocols specifies the server's supported protocols in order of
+ // preference. If this field is not nil, then the Upgrade method negotiates a
+ // subprotocol by selecting the first match in this list with a protocol
+ // requested by the client. If there's no match, then no protocol is
+ // negotiated (the Sec-Websocket-Protocol header is not included in the
+ // handshake response).
+ Subprotocols []string
+
+ // Error specifies the function for generating HTTP error responses. If Error
+ // is nil, then http.Error is used to generate the HTTP response.
+ Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
+
+ // CheckOrigin returns true if the request Origin header is acceptable. If
+ // CheckOrigin is nil, then a safe default is used: return false if the
+ // Origin request header is present and the origin host is not equal to
+ // request Host header.
+ //
+ // A CheckOrigin function should carefully validate the request origin to
+ // prevent cross-site request forgery.
+ CheckOrigin func(r *http.Request) bool
+
+ // EnableCompression specifies whether the server should attempt to negotiate
+ // per message compression (RFC 7692). Setting this value to true does not
+ // guarantee that compression will be supported. Currently only "no context
+ // takeover" modes are supported.
+ EnableCompression bool
+}
+
+func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
+ err := HandshakeError{reason}
+ if u.Error != nil {
+ u.Error(w, r, status, err)
+ } else {
+ w.Header().Set("Sec-Websocket-Version", "13")
+ http.Error(w, http.StatusText(status), status)
+ }
+ return nil, err
+}
+
+// checkSameOrigin returns true if the origin is not set or is equal to the request host.
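For orientation, a minimal server wiring the Upgrader above into an echo handler might look like the following sketch; the endpoint, port, and the single trusted origin in `CheckOrigin` are placeholder choices:

```go
package main

import (
	"log"
	"net/http"

	"github.com/gorilla/websocket"
)

var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Accept only a known origin instead of the default same-host check.
	CheckOrigin: func(r *http.Request) bool {
		return r.Header.Get("Origin") == "https://app.example.com"
	},
}

func echo(w http.ResponseWriter, r *http.Request) {
	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("upgrade:", err) // Upgrade has already sent an HTTP error reply
		return
	}
	defer c.Close()
	for {
		mt, msg, err := c.ReadMessage()
		if err != nil {
			return
		}
		if err := c.WriteMessage(mt, msg); err != nil {
			return
		}
	}
}

func main() {
	http.HandleFunc("/ws", echo)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```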
+func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return equalASCIIFold(u.Host, r.Host) +} + +func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(r) + for _, serverProtocol := range u.Subprotocols { + for _, clientProtocol := range clientProtocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return responseHeader.Get("Sec-Websocket-Protocol") + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// application negotiated subprotocol (Sec-WebSocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + const badHandshake = "websocket: the client is not using the websocket protocol: " + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") + } + + if r.Method != "GET" { + return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(r) { + return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") + } + + challengeKey := r.Header.Get("Sec-Websocket-Key") + if challengeKey == "" { + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") + } + + subprotocol := u.selectSubprotocol(r, responseHeader) + + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + + h, ok := w.(http.Hijacker) + if !ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") + } + var brw *bufio.ReadWriter + netConn, brw, err := h.Hijack() + if err != nil { + return u.returnError(w, r, http.StatusInternalServerError, err.Error()) + } + + if brw.Reader.Buffered() > 0 { + netConn.Close() + return nil, errors.New("websocket: client sent data before handshake is complete") + } + + var br *bufio.Reader + if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { + // Reuse hijacked buffered reader as connection reader. 
+ br = brw.Reader + } + + buf := bufioWriterBuffer(netConn, brw.Writer) + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) + c.subprotocol = subprotocol + + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(c.writeBuf) > len(p) { + p = c.writeBuf + } + p = p[:0] + + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) + p = append(p, computeAcceptKey(challengeKey)...) + p = append(p, "\r\n"...) + if c.subprotocol != "" { + p = append(p, "Sec-WebSocket-Protocol: "...) + p = append(p, c.subprotocol...) + p = append(p, "\r\n"...) + } + if compress { + p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } + for k, vs := range responseHeader { + if k == "Sec-Websocket-Protocol" { + continue + } + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + // Clear deadlines set by HTTP server. + netConn.SetDeadline(time.Time{}) + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + } + if _, err = netConn.Write(p); err != nil { + netConn.Close() + return nil, err + } + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Time{}) + } + + return c, nil +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// Deprecated: Use websocket.Upgrader instead. +// +// Upgrade does not perform origin checking. The application is responsible for +// checking the Origin header before calling Upgrade. An example implementation +// of the same origin policy check is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", http.StatusForbidden) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. 
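One detail in the handshake writer above is worth calling out: header values supplied via `responseHeader` are copied byte by byte, and any control byte (value 31 or below, which includes CR and LF) is replaced with a space, so a caller-supplied value cannot inject a line break and split the response. A self-contained sketch of just that sanitization step (`sanitizeHeaderValue` is an illustrative stand-in, not a function in the package):

```go
package main

import "fmt"

// sanitizeHeaderValue mirrors the byte loop in Upgrade above: any control
// byte (<= 31), including CR and LF, becomes a space so the value cannot
// terminate the header line early.
func sanitizeHeaderValue(v string) []byte {
	p := make([]byte, 0, len(v))
	for i := 0; i < len(v); i++ {
		b := v[i]
		if b <= 31 {
			b = ' '
		}
		p = append(p, b)
	}
	return p
}

func main() {
	evil := "ok\r\nSet-Cookie: session=stolen"
	fmt.Printf("%q\n", sanitizeHeaderValue(evil))
	// "ok  Set-Cookie: session=stolen" - the injected CRLF is neutralized
}
```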
+func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
+ u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
+ u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
+ // don't return errors to maintain backwards compatibility
+ }
+ u.CheckOrigin = func(r *http.Request) bool {
+ // allow all connections by default
+ return true
+ }
+ return u.Upgrade(w, r, responseHeader)
+}
+
+// Subprotocols returns the subprotocols requested by the client in the
+// Sec-Websocket-Protocol header.
+func Subprotocols(r *http.Request) []string {
+ h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
+ if h == "" {
+ return nil
+ }
+ protocols := strings.Split(h, ",")
+ for i := range protocols {
+ protocols[i] = strings.TrimSpace(protocols[i])
+ }
+ return protocols
+}
+
+// IsWebSocketUpgrade returns true if the client requested upgrade to the
+// WebSocket protocol.
+func IsWebSocketUpgrade(r *http.Request) bool {
+ return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
+ tokenListContainsValue(r.Header, "Upgrade", "websocket")
+}
+
+// bufioReaderSize returns the size of a bufio.Reader.
+func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
+ // This code assumes that peek on a reset reader returns
+ // bufio.Reader.buf[:0].
+ // TODO: Use bufio.Reader.Size() after Go 1.10
+ br.Reset(originalReader)
+ if p, err := br.Peek(0); err == nil {
+ return cap(p)
+ }
+ return 0
+}
+
+// writeHook is an io.Writer that records the last slice passed to it via
+// io.Writer.Write.
+type writeHook struct {
+ p []byte
+}
+
+func (wh *writeHook) Write(p []byte) (int, error) {
+ wh.p = p
+ return len(p), nil
+}
+
+// bufioWriterBuffer grabs the buffer from a bufio.Writer.
+func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
+ // This code assumes that bufio.Writer.buf[:1] is passed to the
+ // bufio.Writer's underlying writer.
+ var wh writeHook
+ bw.Reset(&wh)
+ bw.WriteByte(0)
+ bw.Flush()
+
+ bw.Reset(originalWriter)
+
+ return wh.p[:cap(wh.p)]
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/stub.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/stub.go
deleted file mode 100644
index 0be1589cca96..000000000000
--- a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/stub.go
+++ /dev/null
@@ -1,135 +0,0 @@
-// Code generated by depstubber. DO NOT EDIT.
-// This is a simple stub for github.com/gorilla/websocket, strictly for use in testing.

-// See the LICENSE file for information about the licensing of the original library.
-// Source: github.com/gorilla/websocket (exports: Dialer; functions: )

-// Package websocket is a stub of github.com/gorilla/websocket, generated by depstubber.
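The writeHook trick above also deserves a note: after resetting a `bufio.Writer` onto the hook, a one-byte write plus flush hands the hook `buf[:1]`, a slice that aliases the writer's entire internal buffer, so `wh.p[:cap(wh.p)]` recovers that buffer for reuse. A minimal demonstration under the same assumption about `bufio.Writer`'s flush behavior:

```go
package main

import (
	"bufio"
	"fmt"
)

// lastWrite records the most recent slice passed to Write, like the
// writeHook above.
type lastWrite struct{ p []byte }

func (w *lastWrite) Write(p []byte) (int, error) { w.p = p; return len(p), nil }

func main() {
	var wh lastWrite
	bw := bufio.NewWriterSize(&wh, 512)

	// Writing one byte and flushing passes buf[:1] to the hook; that
	// slice aliases the bufio.Writer's whole internal buffer.
	bw.WriteByte(0)
	bw.Flush()

	fmt.Println(len(wh.p), cap(wh.p)) // 1 512
}
```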
-package websocket - -import ( - context "context" - tls "crypto/tls" - io "io" - net "net" - http "net/http" - url "net/url" - time "time" -) - -type BufferPool interface { - Get() interface{} - Put(_ interface{}) -} - -type Conn struct{} - -func (_ *Conn) Close() error { - return nil -} - -func (_ *Conn) CloseHandler() func(int, string) error { - return nil -} - -func (_ *Conn) EnableWriteCompression(_ bool) {} - -func (_ *Conn) LocalAddr() net.Addr { - return nil -} - -func (_ *Conn) NextReader() (int, io.Reader, error) { - return 0, nil, nil -} - -func (_ *Conn) NextWriter(_ int) (io.WriteCloser, error) { - return nil, nil -} - -func (_ *Conn) PingHandler() func(string) error { - return nil -} - -func (_ *Conn) PongHandler() func(string) error { - return nil -} - -func (_ *Conn) ReadJSON(_ interface{}) error { - return nil -} - -func (_ *Conn) ReadMessage() (int, []byte, error) { - return 0, nil, nil -} - -func (_ *Conn) RemoteAddr() net.Addr { - return nil -} - -func (_ *Conn) SetCloseHandler(_ func(int, string) error) {} - -func (_ *Conn) SetCompressionLevel(_ int) error { - return nil -} - -func (_ *Conn) SetPingHandler(_ func(string) error) {} - -func (_ *Conn) SetPongHandler(_ func(string) error) {} - -func (_ *Conn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (_ *Conn) SetReadLimit(_ int64) {} - -func (_ *Conn) SetWriteDeadline(_ time.Time) error { - return nil -} - -func (_ *Conn) Subprotocol() string { - return "" -} - -func (_ *Conn) UnderlyingConn() net.Conn { - return nil -} - -func (_ *Conn) WriteControl(_ int, _ []byte, _ time.Time) error { - return nil -} - -func (_ *Conn) WriteJSON(_ interface{}) error { - return nil -} - -func (_ *Conn) WriteMessage(_ int, _ []byte) error { - return nil -} - -func (_ *Conn) WritePreparedMessage(_ *PreparedMessage) error { - return nil -} - -type Dialer struct { - NetDial func(string, string) (net.Conn, error) - NetDialContext func(context.Context, string, string) (net.Conn, error) - Proxy func(*http.Request) (*url.URL, error) - TLSClientConfig *tls.Config - HandshakeTimeout time.Duration - ReadBufferSize int - WriteBufferSize int - WriteBufferPool BufferPool - Subprotocols []string - EnableCompression bool - Jar http.CookieJar -} - -func (_ *Dialer) Dial(_ string, _ http.Header) (*Conn, *http.Response, error) { - return nil, nil, nil -} - -func (_ *Dialer) DialContext(_ context.Context, _ string, _ http.Header) (*Conn, *http.Response, error) { - return nil, nil, nil -} - -type PreparedMessage struct{} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/trace.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/trace.go new file mode 100644 index 000000000000..834f122a00db --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/trace.go @@ -0,0 +1,19 @@ +// +build go1.8 + +package websocket + +import ( + "crypto/tls" + "net/http/httptrace" +) + +func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + if trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(tlsConn, cfg) + if trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + return err +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/trace_17.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/trace_17.go new file mode 100644 index 000000000000..77d05a0b5748 --- /dev/null +++ 
b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/trace_17.go @@ -0,0 +1,12 @@ +// +build !go1.8 + +package websocket + +import ( + "crypto/tls" + "net/http/httptrace" +) + +func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + return doHandshake(tlsConn, cfg) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/util.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/util.go new file mode 100644 index 000000000000..7bf2f66c6747 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/util.go @@ -0,0 +1,283 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" + "unicode/utf8" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Token octets per RFC 2616. +var isTokenOctet = [256]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +// skipSpace returns a slice of the string s with all leading RFC 2616 linear +// whitespace removed. +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if b := s[i]; b != ' ' && b != '\t' { + break + } + } + return s[i:] +} + +// nextToken returns the leading RFC 2616 token of s and the string following +// the token. +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if !isTokenOctet[s[i]] { + break + } + } + return s[:i], s[i:] +} + +// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 +// and the string following the token or quoted string. 
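computeAcceptKey above implements the RFC 6455 challenge: the server returns base64(SHA-1(client key + fixed GUID)) in `Sec-WebSocket-Accept`. The worked example from the RFC reproduces directly:

```go
package main

import (
	"crypto/sha1"
	"encoding/base64"
	"fmt"
)

func main() {
	const keyGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	// Example Sec-WebSocket-Key from RFC 6455, section 1.3.
	challenge := "dGhlIHNhbXBsZSBub25jZQ=="
	h := sha1.New()
	h.Write([]byte(challenge))
	h.Write([]byte(keyGUID))
	fmt.Println(base64.StdEncoding.EncodeToString(h.Sum(nil)))
	// Output: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
}
```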
+func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// equalASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains a token equal to value with ASCII case folding. +func tokenListContainsValue(header http.Header, name string, value string) bool { +headers: + for _, s := range header[name] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + continue headers + } + if equalASCIIFold(t, value) { + return true + } + if s == "" { + continue headers + } + s = s[1:] + } + } + return false +} + +// parseExtensions parses WebSocket extensions from a header. +func parseExtensions(header http.Header) []map[string]string { + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/x_net_proxy.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/x_net_proxy.go new file mode 100644 index 000000000000..2e668f6b8821 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/gorilla/websocket/x_net_proxy.go @@ -0,0 +1,473 @@ +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. 
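The `proxy_*` code below is `golang.org/x/net/proxy` bundled into the package under a prefix to avoid an external dependency. Against the real package, the equivalent public API looks like this sketch (the proxy address is a hypothetical local endpoint, so the dial is expected to fail unless one is actually listening):

```go
package main

import (
	"log"
	"net/url"

	"golang.org/x/net/proxy"
)

func main() {
	u, err := url.Parse("socks5://127.0.0.1:1080") // placeholder SOCKS5 endpoint
	if err != nil {
		log.Fatal(err)
	}
	// FromURL picks a dialer for the URL's scheme; Direct is the
	// no-proxy fallback used for forwarding.
	dialer, err := proxy.FromURL(u, proxy.Direct)
	if err != nil {
		log.Fatal(err)
	}
	conn, err := dialer.Dial("tcp", "example.com:80")
	if err != nil {
		log.Fatal(err) // expected without a running proxy
	}
	conn.Close()
}
```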
+//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy + +// Package proxy provides support for a variety of protocols to proxy network +// data. +// + +package websocket + +import ( + "errors" + "io" + "net" + "net/url" + "os" + "strconv" + "strings" + "sync" +) + +type proxy_direct struct{} + +// Direct is a direct proxy: one that makes network connections directly. +var proxy_Direct = proxy_direct{} + +func (proxy_direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type proxy_PerHost struct { + def, bypass proxy_Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { + return &proxy_PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *proxy_PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *proxy_PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. 
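PerHost's bypass rules (literal IPs, CIDR ranges, `*.zone` suffixes, and plain host names) are easiest to see with a dialer that merely reports which route was selected. A sketch against the public `golang.org/x/net/proxy` API; `namedDialer` is illustrative and never opens a connection:

```go
package main

import (
	"errors"
	"fmt"
	"net"

	"golang.org/x/net/proxy"
)

// namedDialer never connects; its error reports which route was chosen.
type namedDialer string

func (d namedDialer) Dial(network, addr string) (net.Conn, error) {
	return nil, errors.New(string(d) + " would dial " + addr)
}

func main() {
	ph := proxy.NewPerHost(namedDialer("proxy"), namedDialer("direct"))
	ph.AddFromString("localhost,10.0.0.0/8,*.internal.example.com")

	for _, addr := range []string{
		"example.com:80",             // no rule matches: default dialer
		"10.1.2.3:80",                // inside 10.0.0.0/8: bypass dialer
		"db.internal.example.com:53", // zone match: bypass dialer
	} {
		_, err := ph.Dial("tcp", addr)
		fmt.Println(err)
	}
}
```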
+func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *proxy_PerHost) AddZone(zone string) { + if strings.HasSuffix(zone, ".") { + zone = zone[:len(zone)-1] + } + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *proxy_PerHost) AddHost(host string) { + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + p.bypassHosts = append(p.bypassHosts, host) +} + +// A Dialer is a means to establish a connection. +type proxy_Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type proxy_Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy related variables in +// the environment. +func proxy_FromEnvironment() proxy_Dialer { + allProxy := proxy_allProxyEnv.Get() + if len(allProxy) == 0 { + return proxy_Direct + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return proxy_Direct + } + proxy, err := proxy_FromURL(proxyURL, proxy_Direct) + if err != nil { + return proxy_Direct + } + + noProxy := proxy_noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := proxy_NewPerHost(proxy, proxy_Direct) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { + if proxy_proxySchemes == nil { + proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) + } + proxy_proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { + var auth *proxy_Auth + if u.User != nil { + auth = new(proxy_Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5": + return proxy_SOCKS5("tcp", u.Host, auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxy_proxySchemes != nil { + if f, ok := proxy_proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + proxy_allProxyEnv = &proxy_envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + proxy_noProxyEnv = &proxy_envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). 
+// (Borrowed from net/http/transport.go) +type proxy_envOnce struct { + names []string + once sync.Once + val string +} + +func (e *proxy_envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *proxy_envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address +// with an optional username and password. See RFC 1928 and RFC 1929. +func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { + s := &proxy_socks5{ + network: network, + addr: addr, + forward: forward, + } + if auth != nil { + s.user = auth.User + s.password = auth.Password + } + + return s, nil +} + +type proxy_socks5 struct { + user, password string + network, addr string + forward proxy_Dialer +} + +const proxy_socks5Version = 5 + +const ( + proxy_socks5AuthNone = 0 + proxy_socks5AuthPassword = 2 +) + +const proxy_socks5Connect = 1 + +const ( + proxy_socks5IP4 = 1 + proxy_socks5Domain = 3 + proxy_socks5IP6 = 4 +) + +var proxy_socks5Errors = []string{ + "", + "general failure", + "connection forbidden", + "network unreachable", + "host unreachable", + "connection refused", + "TTL expired", + "command not supported", + "address type not supported", +} + +// Dial connects to the address addr on the given network via the SOCKS5 proxy. +func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) + } + + conn, err := s.forward.Dial(s.network, s.addr) + if err != nil { + return nil, err + } + if err := s.connect(conn, addr); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +// connect takes an existing connection to a socks5 proxy server, +// and commands the server to extend that connection to target, +// which must be a canonical address with a host and port. 
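connect below speaks raw RFC 1928: first a greeting advertising the supported authentication methods, then a CONNECT request whose destination is an IPv4/IPv6 literal or a length-prefixed domain name followed by a big-endian port. A sketch of the exact bytes for an unauthenticated CONNECT, with no network required:

```go
package main

import "fmt"

func main() {
	host, port := "example.com", 443

	// Greeting: version 5, one auth method offered, method 0 (no auth).
	greeting := []byte{5, 1, 0}
	fmt.Printf("greeting: % x\n", greeting)

	// CONNECT request: version, command 1 (CONNECT), reserved 0, address
	// type 3 (domain), length-prefixed host, then the port in big endian.
	req := []byte{5, 1, 0, 3, byte(len(host))}
	req = append(req, host...)
	req = append(req, byte(port>>8), byte(port))
	fmt.Printf("connect:  % x\n", req)
}
```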
+func (s *proxy_socks5) connect(conn net.Conn, target string) error { + host, portStr, err := net.SplitHostPort(target) + if err != nil { + return err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return errors.New("proxy: failed to parse port number: " + portStr) + } + if port < 1 || port > 0xffff { + return errors.New("proxy: port number out of range: " + portStr) + } + + // the size here is just an estimate + buf := make([]byte, 0, 6+len(host)) + + buf = append(buf, proxy_socks5Version) + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { + buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) + } else { + buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) + } + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + if buf[0] != 5 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + } + if buf[1] == 0xff { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + } + + // See RFC 1929 + if buf[1] == proxy_socks5AuthPassword { + buf = buf[:0] + buf = append(buf, 1 /* password protocol version */) + buf = append(buf, uint8(len(s.user))) + buf = append(buf, s.user...) + buf = append(buf, uint8(len(s.password))) + buf = append(buf, s.password...) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if buf[1] != 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + } + } + + buf = buf[:0] + buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) + + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, proxy_socks5IP4) + ip = ip4 + } else { + buf = append(buf, proxy_socks5IP6) + } + buf = append(buf, ip...) + } else { + if len(host) > 255 { + return errors.New("proxy: destination host name too long: " + host) + } + buf = append(buf, proxy_socks5Domain) + buf = append(buf, byte(len(host))) + buf = append(buf, host...) 
+ } + buf = append(buf, byte(port>>8), byte(port)) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:4]); err != nil { + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + failure := "unknown error" + if int(buf[1]) < len(proxy_socks5Errors) { + failure = proxy_socks5Errors[buf[1]] + } + + if len(failure) > 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + } + + bytesToDiscard := 0 + switch buf[3] { + case proxy_socks5IP4: + bytesToDiscard = net.IPv4len + case proxy_socks5IP6: + bytesToDiscard = net.IPv6len + case proxy_socks5Domain: + _, err := io.ReadFull(conn, buf[:1]) + if err != nil { + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + bytesToDiscard = int(buf[0]) + default: + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + } + + if cap(buf) < bytesToDiscard { + buf = make([]byte, bytesToDiscard) + } else { + buf = buf[:bytesToDiscard] + } + if _, err := io.ReadFull(conn, buf); err != nil { + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + // Also need to discard the port number + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + return nil +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/LICENSE b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/LICENSE new file mode 100644 index 000000000000..1eb75ef68e44 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. +Copyright (c) 2019 Klaus Post. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/deflate.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/deflate.go
new file mode 100644
index 000000000000..2b101d26b25a
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/deflate.go
@@ -0,0 +1,819 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright (c) 2015 Klaus Post
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package flate
+
+import (
+ "fmt"
+ "io"
+ "math"
+)
+
+const (
+ NoCompression = 0
+ BestSpeed = 1
+ BestCompression = 9
+ DefaultCompression = -1
+
+ // HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman
+ // entropy encoding. This mode is useful in compressing data that has
+ // already been compressed with an LZ style algorithm (e.g. Snappy or LZ4)
+ // that lacks an entropy encoder. Compression gains are achieved when
+ // certain bytes in the input stream occur more frequently than others.
+ //
+ // Note that HuffmanOnly produces a compressed output that is
+ // RFC 1951 compliant. That is, any valid DEFLATE decompressor will
+ // continue to be able to decompress this output.
+ HuffmanOnly = -2
+ ConstantCompression = HuffmanOnly // compatibility alias.
+
+ logWindowSize = 15
+ windowSize = 1 << logWindowSize
+ windowMask = windowSize - 1
+ logMaxOffsetSize = 15 // Standard DEFLATE
+ minMatchLength = 4 // The smallest match that the compressor looks for
+ maxMatchLength = 258 // The longest match for the compressor
+ minOffsetSize = 1 // The shortest offset that makes any sense
+
+ // The maximum number of tokens we put into a single flat block, just to
+ // stop things from getting too large.
+ maxFlateBlockTokens = 1 << 14
+ maxStoreBlockSize = 65535
+ hashBits = 17 // After 17 performance degrades
+ hashSize = 1 << hashBits
+ hashMask = (1 << hashBits) - 1
+ hashShift = (hashBits + minMatchLength - 1) / minMatchLength
+ maxHashOffset = 1 << 24
+
+ skipNever = math.MaxInt32
+
+ debugDeflate = false
+)
+
+type compressionLevel struct {
+ good, lazy, nice, chain, fastSkipHashing, level int
+}
+
+// Compression levels have been rebalanced from zlib deflate defaults
+// to give a bigger spread in speed and compression.
+// See https://blog.klauspost.com/rebalancing-deflate-compression-levels/
+var levels = []compressionLevel{
+ {}, // 0
+ // Level 1-6 uses specialized algorithm - values not used
+ {0, 0, 0, 0, 0, 1},
+ {0, 0, 0, 0, 0, 2},
+ {0, 0, 0, 0, 0, 3},
+ {0, 0, 0, 0, 0, 4},
+ {0, 0, 0, 0, 0, 5},
+ {0, 0, 0, 0, 0, 6},
+ // Levels 7-9 use increasingly more lazy matching
+ // and increasingly stringent conditions for "good enough".
+ {8, 8, 24, 16, skipNever, 7},
+ {10, 16, 24, 64, skipNever, 8},
+ {32, 258, 258, 4096, skipNever, 9},
+}
+
+// advancedState contains state for the advanced levels, with bigger hash tables, etc.
+type advancedState struct { + // deflate state + length int + offset int + hash uint32 + maxInsertIndex int + ii uint16 // position of last match, intended to overflow to reset. + + // Input hash chains + // hashHead[hashValue] contains the largest inputIndex with the specified hash value + // If hashHead[hashValue] is within the current window, then + // hashPrev[hashHead[hashValue] & windowMask] contains the previous index + // with the same hash value. + chainHead int + hashHead [hashSize]uint32 + hashPrev [windowSize]uint32 + hashOffset int + + // input window: unprocessed data is window[index:windowEnd] + index int + hashMatch [maxMatchLength + minMatchLength]uint32 +} + +type compressor struct { + compressionLevel + + w *huffmanBitWriter + + // compression algorithm + fill func(*compressor, []byte) int // copy data to window + step func(*compressor) // process window + sync bool // requesting flush + + window []byte + windowEnd int + blockStart int // window index where current tokens start + byteAvailable bool // if true, still need to process window[index-1]. + err error + + // queued output tokens + tokens tokens + fast fastEnc + state *advancedState +} + +func (d *compressor) fillDeflate(b []byte) int { + s := d.state + if s.index >= 2*windowSize-(minMatchLength+maxMatchLength) { + // shift the window by windowSize + copy(d.window[:], d.window[windowSize:2*windowSize]) + s.index -= windowSize + d.windowEnd -= windowSize + if d.blockStart >= windowSize { + d.blockStart -= windowSize + } else { + d.blockStart = math.MaxInt32 + } + s.hashOffset += windowSize + if s.hashOffset > maxHashOffset { + delta := s.hashOffset - 1 + s.hashOffset -= delta + s.chainHead -= delta + // Iterate over slices instead of arrays to avoid copying + // the entire table onto the stack (Issue #18625). + for i, v := range s.hashPrev[:] { + if int(v) > delta { + s.hashPrev[i] = uint32(int(v) - delta) + } else { + s.hashPrev[i] = 0 + } + } + for i, v := range s.hashHead[:] { + if int(v) > delta { + s.hashHead[i] = uint32(int(v) - delta) + } else { + s.hashHead[i] = 0 + } + } + } + } + n := copy(d.window[d.windowEnd:], b) + d.windowEnd += n + return n +} + +func (d *compressor) writeBlock(tok *tokens, index int, eof bool) error { + if index > 0 || eof { + var window []byte + if d.blockStart <= index { + window = d.window[d.blockStart:index] + } + d.blockStart = index + d.w.writeBlock(tok, eof, window) + return d.w.err + } + return nil +} + +// writeBlockSkip writes the current block and uses the number of tokens +// to determine if the block should be stored on no matches, or +// only huffman encoded. +func (d *compressor) writeBlockSkip(tok *tokens, index int, eof bool) error { + if index > 0 || eof { + if d.blockStart <= index { + window := d.window[d.blockStart:index] + // If we removed less than a 64th of all literals + // we huffman compress the block. + if int(tok.n) > len(window)-int(tok.n>>6) { + d.w.writeBlockHuff(eof, window, d.sync) + } else { + // Write a dynamic huffman block. + d.w.writeBlockDynamic(tok, eof, window, d.sync) + } + } else { + d.w.writeBlock(tok, eof, nil) + } + d.blockStart = index + return d.w.err + } + return nil +} + +// fillWindow will fill the current window with the supplied +// dictionary and calculate all hashes. +// This is much faster than doing a full encode. +// Should only be used after a start/reset. +func (d *compressor) fillWindow(b []byte) { + // Do not fill window if we are in store-only or huffman mode. 
+ if d.level <= 0 { + return + } + if d.fast != nil { + // encode the last data, but discard the result + if len(b) > maxMatchOffset { + b = b[len(b)-maxMatchOffset:] + } + d.fast.Encode(&d.tokens, b) + d.tokens.Reset() + return + } + s := d.state + // If we are given too much, cut it. + if len(b) > windowSize { + b = b[len(b)-windowSize:] + } + // Add all to window. + n := copy(d.window[d.windowEnd:], b) + + // Calculate 256 hashes at the time (more L1 cache hits) + loops := (n + 256 - minMatchLength) / 256 + for j := 0; j < loops; j++ { + startindex := j * 256 + end := startindex + 256 + minMatchLength - 1 + if end > n { + end = n + } + tocheck := d.window[startindex:end] + dstSize := len(tocheck) - minMatchLength + 1 + + if dstSize <= 0 { + continue + } + + dst := s.hashMatch[:dstSize] + bulkHash4(tocheck, dst) + var newH uint32 + for i, val := range dst { + di := i + startindex + newH = val & hashMask + // Get previous value with the same hash. + // Our chain should point to the previous value. + s.hashPrev[di&windowMask] = s.hashHead[newH] + // Set the head of the hash chain to us. + s.hashHead[newH] = uint32(di + s.hashOffset) + } + s.hash = newH + } + // Update window information. + d.windowEnd += n + s.index = n +} + +// Try to find a match starting at index whose length is greater than prevSize. +// We only look at chainCount possibilities before giving up. +// pos = s.index, prevHead = s.chainHead-s.hashOffset, prevLength=minMatchLength-1, lookahead +func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) { + minMatchLook := maxMatchLength + if lookahead < minMatchLook { + minMatchLook = lookahead + } + + win := d.window[0 : pos+minMatchLook] + + // We quit when we get a match that's at least nice long + nice := len(win) - pos + if d.nice < nice { + nice = d.nice + } + + // If we've got a match that's good enough, only look in 1/4 the chain. + tries := d.chain + length = prevLength + if length >= d.good { + tries >>= 2 + } + + wEnd := win[pos+length] + wPos := win[pos:] + minIndex := pos - windowSize + + for i := prevHead; tries > 0; tries-- { + if wEnd == win[i+length] { + n := matchLen(win[i:i+minMatchLook], wPos) + + if n > length && (n > minMatchLength || pos-i <= 4096) { + length = n + offset = pos - i + ok = true + if n >= nice { + // The match is good enough that we don't try to find a better one. + break + } + wEnd = win[pos+n] + } + } + if i == minIndex { + // hashPrev[i & windowMask] has already been overwritten, so stop now. + break + } + i = int(d.state.hashPrev[i&windowMask]) - d.state.hashOffset + if i < minIndex || i < 0 { + break + } + } + return +} + +func (d *compressor) writeStoredBlock(buf []byte) error { + if d.w.writeStoredHeader(len(buf), false); d.w.err != nil { + return d.w.err + } + d.w.writeBytes(buf) + return d.w.err +} + +// hash4 returns a hash representation of the first 4 bytes +// of the supplied slice. +// The caller must ensure that len(b) >= 4. 
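hash4 (defined next) and bulkHash4 must agree at every position; the bulk variant rolls the 4-byte big-endian window one byte at a time instead of reloading it. A standalone sketch of that equivalence; since `hash4u` itself is not part of this diff, the multiplicative constant below is an assumption for illustration only:

```go
package main

import "fmt"

const hashBits = 17

// hash4u stands in for the helper assumed by hash4/bulkHash4; the
// constant here is illustrative, not taken from the diff.
func hash4u(u uint32, h uint8) uint32 {
	return (u * 2654435761) >> (32 - h)
}

func hash4(b []byte) uint32 {
	return hash4u(uint32(b[3])|uint32(b[2])<<8|uint32(b[1])<<16|uint32(b[0])<<24, hashBits)
}

func main() {
	data := []byte("abracadabra!")
	// Roll the 4-byte big-endian window one byte at a time, as bulkHash4
	// does, and confirm it matches recomputing hash4 from scratch.
	hb := uint32(data[3]) | uint32(data[2])<<8 | uint32(data[1])<<16 | uint32(data[0])<<24
	for i := 0; i+4 <= len(data); i++ {
		if i > 0 {
			hb = (hb << 8) | uint32(data[i+3])
		}
		fmt.Println(i, hash4u(hb, hashBits) == hash4(data[i:i+4])) // true at every position
	}
}
```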
+func hash4(b []byte) uint32 { + b = b[:4] + return hash4u(uint32(b[3])|uint32(b[2])<<8|uint32(b[1])<<16|uint32(b[0])<<24, hashBits) +} + +// bulkHash4 will compute hashes using the same +// algorithm as hash4 +func bulkHash4(b []byte, dst []uint32) { + if len(b) < 4 { + return + } + hb := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + dst[0] = hash4u(hb, hashBits) + end := len(b) - 4 + 1 + for i := 1; i < end; i++ { + hb = (hb << 8) | uint32(b[i+3]) + dst[i] = hash4u(hb, hashBits) + } +} + +func (d *compressor) initDeflate() { + d.window = make([]byte, 2*windowSize) + d.byteAvailable = false + d.err = nil + if d.state == nil { + return + } + s := d.state + s.index = 0 + s.hashOffset = 1 + s.length = minMatchLength - 1 + s.offset = 0 + s.hash = 0 + s.chainHead = -1 +} + +// deflateLazy is the same as deflate, but with d.fastSkipHashing == skipNever, +// meaning it always has lazy matching on. +func (d *compressor) deflateLazy() { + s := d.state + // Sanity enables additional runtime tests. + // It's intended to be used during development + // to supplement the currently ad-hoc unit tests. + const sanity = debugDeflate + + if d.windowEnd-s.index < minMatchLength+maxMatchLength && !d.sync { + return + } + + s.maxInsertIndex = d.windowEnd - (minMatchLength - 1) + if s.index < s.maxInsertIndex { + s.hash = hash4(d.window[s.index : s.index+minMatchLength]) + } + + for { + if sanity && s.index > d.windowEnd { + panic("index > windowEnd") + } + lookahead := d.windowEnd - s.index + if lookahead < minMatchLength+maxMatchLength { + if !d.sync { + return + } + if sanity && s.index > d.windowEnd { + panic("index > windowEnd") + } + if lookahead == 0 { + // Flush current output block if any. + if d.byteAvailable { + // There is still one pending token that needs to be flushed + d.tokens.AddLiteral(d.window[s.index-1]) + d.byteAvailable = false + } + if d.tokens.n > 0 { + if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil { + return + } + d.tokens.Reset() + } + return + } + } + if s.index < s.maxInsertIndex { + // Update the hash + s.hash = hash4(d.window[s.index : s.index+minMatchLength]) + ch := s.hashHead[s.hash&hashMask] + s.chainHead = int(ch) + s.hashPrev[s.index&windowMask] = ch + s.hashHead[s.hash&hashMask] = uint32(s.index + s.hashOffset) + } + prevLength := s.length + prevOffset := s.offset + s.length = minMatchLength - 1 + s.offset = 0 + minIndex := s.index - windowSize + if minIndex < 0 { + minIndex = 0 + } + + if s.chainHead-s.hashOffset >= minIndex && lookahead > prevLength && prevLength < d.lazy { + if newLength, newOffset, ok := d.findMatch(s.index, s.chainHead-s.hashOffset, minMatchLength-1, lookahead); ok { + s.length = newLength + s.offset = newOffset + } + } + if prevLength >= minMatchLength && s.length <= prevLength { + // There was a match at the previous step, and the current match is + // not better. Output the previous match. + d.tokens.AddMatch(uint32(prevLength-3), uint32(prevOffset-minOffsetSize)) + + // Insert in the hash table all strings up to the end of the match. + // index and index-1 are already inserted. If there is not enough + // lookahead, the last two strings are not inserted into the hash + // table. 
+ var newIndex int
+ newIndex = s.index + prevLength - 1
+ // Calculate missing hashes
+ end := newIndex
+ if end > s.maxInsertIndex {
+ end = s.maxInsertIndex
+ }
+ end += minMatchLength - 1
+ startindex := s.index + 1
+ if startindex > s.maxInsertIndex {
+ startindex = s.maxInsertIndex
+ }
+ tocheck := d.window[startindex:end]
+ dstSize := len(tocheck) - minMatchLength + 1
+ if dstSize > 0 {
+ dst := s.hashMatch[:dstSize]
+ bulkHash4(tocheck, dst)
+ var newH uint32
+ for i, val := range dst {
+ di := i + startindex
+ newH = val & hashMask
+ // Get previous value with the same hash.
+ // Our chain should point to the previous value.
+ s.hashPrev[di&windowMask] = s.hashHead[newH]
+ // Set the head of the hash chain to us.
+ s.hashHead[newH] = uint32(di + s.hashOffset)
+ }
+ s.hash = newH
+ }
+
+ s.index = newIndex
+ d.byteAvailable = false
+ s.length = minMatchLength - 1
+ if d.tokens.n == maxFlateBlockTokens {
+ // The block includes the current character
+ if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
+ return
+ }
+ d.tokens.Reset()
+ }
+ } else {
+ // Reset, if we got a match this run.
+ if s.length >= minMatchLength {
+ s.ii = 0
+ }
+ // We have a byte waiting. Emit it.
+ if d.byteAvailable {
+ s.ii++
+ d.tokens.AddLiteral(d.window[s.index-1])
+ if d.tokens.n == maxFlateBlockTokens {
+ if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
+ return
+ }
+ d.tokens.Reset()
+ }
+ s.index++
+
+ // If we have a long run of no matches, skip additional bytes
+ // Resets when s.ii overflows after 64KB.
+ if s.ii > 31 {
+ n := int(s.ii >> 5)
+ for j := 0; j < n; j++ {
+ if s.index >= d.windowEnd-1 {
+ break
+ }
+
+ d.tokens.AddLiteral(d.window[s.index-1])
+ if d.tokens.n == maxFlateBlockTokens {
+ if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
+ return
+ }
+ d.tokens.Reset()
+ }
+ s.index++
+ }
+ // Flush last byte
+ d.tokens.AddLiteral(d.window[s.index-1])
+ d.byteAvailable = false
+ // s.length = minMatchLength - 1 // not needed, since s.ii is reset above, so it should never be > minMatchLength
+ if d.tokens.n == maxFlateBlockTokens {
+ if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
+ return
+ }
+ d.tokens.Reset()
+ }
+ }
+ } else {
+ s.index++
+ d.byteAvailable = true
+ }
+ }
+ }
+}
+
+func (d *compressor) store() {
+ if d.windowEnd > 0 && (d.windowEnd == maxStoreBlockSize || d.sync) {
+ d.err = d.writeStoredBlock(d.window[:d.windowEnd])
+ d.windowEnd = 0
+ }
+}
+
+// fillBlock will fill the buffer with data for huffman-only compression.
+// The number of bytes copied is returned.
+func (d *compressor) fillBlock(b []byte) int {
+ n := copy(d.window[d.windowEnd:], b)
+ d.windowEnd += n
+ return n
+}
+
+// storeHuff will compress and store the currently added data,
+// if enough has been accumulated or we are at the end of the stream.
+// Any error that occurred will be in d.err
+func (d *compressor) storeHuff() {
+ if d.windowEnd < len(d.window) && !d.sync || d.windowEnd == 0 {
+ return
+ }
+ d.w.writeBlockHuff(false, d.window[:d.windowEnd], d.sync)
+ d.err = d.w.err
+ d.windowEnd = 0
+}
+
+// storeFast will compress and store the currently added data,
+// if enough has been accumulated or we are at the end of the stream.
+// Any error that occurred will be in d.err
+func (d *compressor) storeFast() {
+ // We only compress if we have maxStoreBlockSize.
+ if d.windowEnd < len(d.window) {
+ if !d.sync {
+ return
+ }
+ // Handle extremely small sizes.
+		if d.windowEnd < 128 {
+			if d.windowEnd == 0 {
+				return
+			}
+			if d.windowEnd <= 32 {
+				d.err = d.writeStoredBlock(d.window[:d.windowEnd])
+			} else {
+				d.w.writeBlockHuff(false, d.window[:d.windowEnd], true)
+				d.err = d.w.err
+			}
+			d.tokens.Reset()
+			d.windowEnd = 0
+			d.fast.Reset()
+			return
+		}
+	}
+
+	d.fast.Encode(&d.tokens, d.window[:d.windowEnd])
+	// If we made zero matches, store the block as is.
+	if d.tokens.n == 0 {
+		d.err = d.writeStoredBlock(d.window[:d.windowEnd])
+		// If we removed less than 1/16th, huffman compress the block.
+	} else if int(d.tokens.n) > d.windowEnd-(d.windowEnd>>4) {
+		d.w.writeBlockHuff(false, d.window[:d.windowEnd], d.sync)
+		d.err = d.w.err
+	} else {
+		d.w.writeBlockDynamic(&d.tokens, false, d.window[:d.windowEnd], d.sync)
+		d.err = d.w.err
+	}
+	d.tokens.Reset()
+	d.windowEnd = 0
+}
+
+// write will add input bytes to the stream.
+// Unless an error occurs all bytes will be consumed.
+func (d *compressor) write(b []byte) (n int, err error) {
+	if d.err != nil {
+		return 0, d.err
+	}
+	n = len(b)
+	for len(b) > 0 {
+		d.step(d)
+		b = b[d.fill(d, b):]
+		if d.err != nil {
+			return 0, d.err
+		}
+	}
+	return n, d.err
+}
+
+func (d *compressor) syncFlush() error {
+	d.sync = true
+	if d.err != nil {
+		return d.err
+	}
+	d.step(d)
+	if d.err == nil {
+		d.w.writeStoredHeader(0, false)
+		d.w.flush()
+		d.err = d.w.err
+	}
+	d.sync = false
+	return d.err
+}
+
+func (d *compressor) init(w io.Writer, level int) (err error) {
+	d.w = newHuffmanBitWriter(w)
+
+	switch {
+	case level == NoCompression:
+		d.window = make([]byte, maxStoreBlockSize)
+		d.fill = (*compressor).fillBlock
+		d.step = (*compressor).store
+	case level == ConstantCompression:
+		d.w.logNewTablePenalty = 4
+		d.window = make([]byte, maxStoreBlockSize)
+		d.fill = (*compressor).fillBlock
+		d.step = (*compressor).storeHuff
+	case level == DefaultCompression:
+		level = 5
+		fallthrough
+	case level >= 1 && level <= 6:
+		d.w.logNewTablePenalty = 6
+		d.fast = newFastEnc(level)
+		d.window = make([]byte, maxStoreBlockSize)
+		d.fill = (*compressor).fillBlock
+		d.step = (*compressor).storeFast
+	case 7 <= level && level <= 9:
+		d.w.logNewTablePenalty = 10
+		d.state = &advancedState{}
+		d.compressionLevel = levels[level]
+		d.initDeflate()
+		d.fill = (*compressor).fillDeflate
+		d.step = (*compressor).deflateLazy
+	default:
+		return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
+	}
+	d.level = level
+	return nil
+}
+
+// reset the state of the compressor.
+func (d *compressor) reset(w io.Writer) {
+	d.w.reset(w)
+	d.sync = false
+	d.err = nil
+	// We only need to reset a few things for Snappy.
+	if d.fast != nil {
+		d.fast.Reset()
+		d.windowEnd = 0
+		d.tokens.Reset()
+		return
+	}
+	switch d.compressionLevel.chain {
+	case 0:
+		// level was NoCompression or ConstantCompression.
+ d.windowEnd = 0 + default: + s := d.state + s.chainHead = -1 + for i := range s.hashHead { + s.hashHead[i] = 0 + } + for i := range s.hashPrev { + s.hashPrev[i] = 0 + } + s.hashOffset = 1 + s.index, d.windowEnd = 0, 0 + d.blockStart, d.byteAvailable = 0, false + d.tokens.Reset() + s.length = minMatchLength - 1 + s.offset = 0 + s.hash = 0 + s.ii = 0 + s.maxInsertIndex = 0 + } +} + +func (d *compressor) close() error { + if d.err != nil { + return d.err + } + d.sync = true + d.step(d) + if d.err != nil { + return d.err + } + if d.w.writeStoredHeader(0, true); d.w.err != nil { + return d.w.err + } + d.w.flush() + d.w.reset(nil) + return d.w.err +} + +// NewWriter returns a new Writer compressing data at the given level. +// Following zlib, levels range from 1 (BestSpeed) to 9 (BestCompression); +// higher levels typically run slower but compress more. +// Level 0 (NoCompression) does not attempt any compression; it only adds the +// necessary DEFLATE framing. +// Level -1 (DefaultCompression) uses the default compression level. +// Level -2 (ConstantCompression) will use Huffman compression only, giving +// a very fast compression for all types of input, but sacrificing considerable +// compression efficiency. +// +// If level is in the range [-2, 9] then the error returned will be nil. +// Otherwise the error returned will be non-nil. +func NewWriter(w io.Writer, level int) (*Writer, error) { + var dw Writer + if err := dw.d.init(w, level); err != nil { + return nil, err + } + return &dw, nil +} + +// NewWriterDict is like NewWriter but initializes the new +// Writer with a preset dictionary. The returned Writer behaves +// as if the dictionary had been written to it without producing +// any compressed output. The compressed data written to w +// can only be decompressed by a Reader initialized with the +// same dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { + zw, err := NewWriter(w, level) + if err != nil { + return nil, err + } + zw.d.fillWindow(dict) + zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method. + return zw, err +} + +// A Writer takes data written to it and writes the compressed +// form of that data to an underlying writer (see NewWriter). +type Writer struct { + d compressor + dict []byte +} + +// Write writes data to w, which will eventually write the +// compressed form of data to its underlying writer. +func (w *Writer) Write(data []byte) (n int, err error) { + return w.d.write(data) +} + +// Flush flushes any pending data to the underlying writer. +// It is useful mainly in compressed network protocols, to ensure that +// a remote reader has enough data to reconstruct a packet. +// Flush does not return until the data has been written. +// Calling Flush when there is no pending data still causes the Writer +// to emit a sync marker of at least 4 bytes. +// If the underlying writer returns an error, Flush returns that error. +// +// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH. +func (w *Writer) Flush() error { + // For more about flushing: + // http://www.bolet.org/~pornin/deflate-flush.html + return w.d.syncFlush() +} + +// Close flushes and closes the writer. +func (w *Writer) Close() error { + return w.d.close() +} + +// Reset discards the writer's state and makes it equivalent to +// the result of NewWriter or NewWriterDict called with dst +// and w's level and dictionary. 
+func (w *Writer) Reset(dst io.Writer) { + if len(w.dict) > 0 { + // w was created with NewWriterDict + w.d.reset(dst) + if dst != nil { + w.d.fillWindow(w.dict) + } + } else { + // w was created with NewWriter + w.d.reset(dst) + } +} + +// ResetDict discards the writer's state and makes it equivalent to +// the result of NewWriter or NewWriterDict called with dst +// and w's level, but sets a specific dictionary. +func (w *Writer) ResetDict(dst io.Writer, dict []byte) { + w.dict = dict + w.d.reset(dst) + w.d.fillWindow(w.dict) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/dict_decoder.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/dict_decoder.go new file mode 100644 index 000000000000..71c75a065ea7 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/dict_decoder.go @@ -0,0 +1,184 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// dictDecoder implements the LZ77 sliding dictionary as used in decompression. +// LZ77 decompresses data through sequences of two forms of commands: +// +// * Literal insertions: Runs of one or more symbols are inserted into the data +// stream as is. This is accomplished through the writeByte method for a +// single symbol, or combinations of writeSlice/writeMark for multiple symbols. +// Any valid stream must start with a literal insertion if no preset dictionary +// is used. +// +// * Backward copies: Runs of one or more symbols are copied from previously +// emitted data. Backward copies come as the tuple (dist, length) where dist +// determines how far back in the stream to copy from and length determines how +// many bytes to copy. Note that it is valid for the length to be greater than +// the distance. Since LZ77 uses forward copies, that situation is used to +// perform a form of run-length encoding on repeated runs of symbols. +// The writeCopy and tryWriteCopy are used to implement this command. +// +// For performance reasons, this implementation performs little to no sanity +// checks about the arguments. As such, the invariants documented for each +// method call must be respected. +type dictDecoder struct { + hist []byte // Sliding window history + + // Invariant: 0 <= rdPos <= wrPos <= len(hist) + wrPos int // Current output position in buffer + rdPos int // Have emitted hist[:rdPos] already + full bool // Has a full window length been written yet? +} + +// init initializes dictDecoder to have a sliding window dictionary of the given +// size. If a preset dict is provided, it will initialize the dictionary with +// the contents of dict. +func (dd *dictDecoder) init(size int, dict []byte) { + *dd = dictDecoder{hist: dd.hist} + + if cap(dd.hist) < size { + dd.hist = make([]byte, size) + } + dd.hist = dd.hist[:size] + + if len(dict) > len(dd.hist) { + dict = dict[len(dict)-len(dd.hist):] + } + dd.wrPos = copy(dd.hist, dict) + if dd.wrPos == len(dd.hist) { + dd.wrPos = 0 + dd.full = true + } + dd.rdPos = dd.wrPos +} + +// histSize reports the total amount of historical data in the dictionary. +func (dd *dictDecoder) histSize() int { + if dd.full { + return len(dd.hist) + } + return dd.wrPos +} + +// availRead reports the number of bytes that can be flushed by readFlush. 
+func (dd *dictDecoder) availRead() int { + return dd.wrPos - dd.rdPos +} + +// availWrite reports the available amount of output buffer space. +func (dd *dictDecoder) availWrite() int { + return len(dd.hist) - dd.wrPos +} + +// writeSlice returns a slice of the available buffer to write data to. +// +// This invariant will be kept: len(s) <= availWrite() +func (dd *dictDecoder) writeSlice() []byte { + return dd.hist[dd.wrPos:] +} + +// writeMark advances the writer pointer by cnt. +// +// This invariant must be kept: 0 <= cnt <= availWrite() +func (dd *dictDecoder) writeMark(cnt int) { + dd.wrPos += cnt +} + +// writeByte writes a single byte to the dictionary. +// +// This invariant must be kept: 0 < availWrite() +func (dd *dictDecoder) writeByte(c byte) { + dd.hist[dd.wrPos] = c + dd.wrPos++ +} + +// writeCopy copies a string at a given (dist, length) to the output. +// This returns the number of bytes copied and may be less than the requested +// length if the available space in the output buffer is too small. +// +// This invariant must be kept: 0 < dist <= histSize() +func (dd *dictDecoder) writeCopy(dist, length int) int { + dstBase := dd.wrPos + dstPos := dstBase + srcPos := dstPos - dist + endPos := dstPos + length + if endPos > len(dd.hist) { + endPos = len(dd.hist) + } + + // Copy non-overlapping section after destination position. + // + // This section is non-overlapping in that the copy length for this section + // is always less than or equal to the backwards distance. This can occur + // if a distance refers to data that wraps-around in the buffer. + // Thus, a backwards copy is performed here; that is, the exact bytes in + // the source prior to the copy is placed in the destination. + if srcPos < 0 { + srcPos += len(dd.hist) + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:]) + srcPos = 0 + } + + // Copy possibly overlapping section before destination position. + // + // This section can overlap if the copy length for this section is larger + // than the backwards distance. This is allowed by LZ77 so that repeated + // strings can be succinctly represented using (dist, length) pairs. + // Thus, a forwards copy is performed here; that is, the bytes copied is + // possibly dependent on the resulting bytes in the destination as the copy + // progresses along. This is functionally equivalent to the following: + // + // for i := 0; i < endPos-dstPos; i++ { + // dd.hist[dstPos+i] = dd.hist[srcPos+i] + // } + // dstPos = endPos + // + for dstPos < endPos { + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos]) + } + + dd.wrPos = dstPos + return dstPos - dstBase +} + +// tryWriteCopy tries to copy a string at a given (distance, length) to the +// output. This specialized version is optimized for short distances. +// +// This method is designed to be inlined for performance reasons. +// +// This invariant must be kept: 0 < dist <= histSize() +func (dd *dictDecoder) tryWriteCopy(dist, length int) int { + dstPos := dd.wrPos + endPos := dstPos + length + if dstPos < dist || endPos > len(dd.hist) { + return 0 + } + dstBase := dstPos + srcPos := dstPos - dist + + // Copy possibly overlapping section before destination position. +loop: + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos]) + if dstPos < endPos { + goto loop // Avoid for-loop so that this function can be inlined + } + + dd.wrPos = dstPos + return dstPos - dstBase +} + +// readFlush returns a slice of the historical buffer that is ready to be +// emitted to the user. 
The data returned by readFlush must be fully consumed +// before calling any other dictDecoder methods. +func (dd *dictDecoder) readFlush() []byte { + toRead := dd.hist[dd.rdPos:dd.wrPos] + dd.rdPos = dd.wrPos + if dd.wrPos == len(dd.hist) { + dd.wrPos, dd.rdPos = 0, 0 + dd.full = true + } + return toRead +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/fast_encoder.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/fast_encoder.go new file mode 100644 index 000000000000..6d4c1e98bc5f --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/fast_encoder.go @@ -0,0 +1,254 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Modified for deflate by Klaus Post (c) 2015. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "fmt" + "math/bits" +) + +type fastEnc interface { + Encode(dst *tokens, src []byte) + Reset() +} + +func newFastEnc(level int) fastEnc { + switch level { + case 1: + return &fastEncL1{fastGen: fastGen{cur: maxStoreBlockSize}} + case 2: + return &fastEncL2{fastGen: fastGen{cur: maxStoreBlockSize}} + case 3: + return &fastEncL3{fastGen: fastGen{cur: maxStoreBlockSize}} + case 4: + return &fastEncL4{fastGen: fastGen{cur: maxStoreBlockSize}} + case 5: + return &fastEncL5{fastGen: fastGen{cur: maxStoreBlockSize}} + case 6: + return &fastEncL6{fastGen: fastGen{cur: maxStoreBlockSize}} + default: + panic("invalid level specified") + } +} + +const ( + tableBits = 15 // Bits used in the table + tableSize = 1 << tableBits // Size of the table + tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32. + baseMatchOffset = 1 // The smallest match offset + baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5 + maxMatchOffset = 1 << 15 // The largest match offset + + bTableBits = 17 // Bits used in the big tables + bTableSize = 1 << bTableBits // Size of the table + allocHistory = maxStoreBlockSize * 10 // Size to preallocate for history. + bufferReset = (1 << 31) - allocHistory - maxStoreBlockSize - 1 // Reset the buffer offset when reaching this. +) + +const ( + prime3bytes = 506832829 + prime4bytes = 2654435761 + prime5bytes = 889523592379 + prime6bytes = 227718039650203 + prime7bytes = 58295818150454627 + prime8bytes = 0xcf1bbcdcb7a56463 +) + +func load32(b []byte, i int) uint32 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:4] + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load64(b []byte, i int) uint64 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:8] + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +func load3232(b []byte, i int32) uint32 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:4] + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load6432(b []byte, i int32) uint64 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. 
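+	// (Editor's note, not in the original source: this is the standard Go
+	// bounds-check-elimination idiom. After b = b[i:] and b = b[:8] the
+	// compiler can prove indices 0..7 are in range, so the eight reads
+	// below typically compile to one 64-bit load on little-endian targets.)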
+	b = b[i:]
+	b = b[:8]
+	return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
+		uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
+}
+
+func hash(u uint32) uint32 {
+	return (u * 0x1e35a7bd) >> tableShift
+}
+
+type tableEntry struct {
+	offset int32
+}
+
+// fastGen maintains the table for matches,
+// and the previous byte block for level 2.
+// This is the generic implementation.
+type fastGen struct {
+	hist []byte
+	cur  int32
+}
+
+func (e *fastGen) addBlock(src []byte) int32 {
+	// check if we have space already
+	if len(e.hist)+len(src) > cap(e.hist) {
+		if cap(e.hist) == 0 {
+			e.hist = make([]byte, 0, allocHistory)
+		} else {
+			if cap(e.hist) < maxMatchOffset*2 {
+				panic("unexpected buffer size")
+			}
+			// Move down
+			offset := int32(len(e.hist)) - maxMatchOffset
+			copy(e.hist[0:maxMatchOffset], e.hist[offset:])
+			e.cur += offset
+			e.hist = e.hist[:maxMatchOffset]
+		}
+	}
+	s := int32(len(e.hist))
+	e.hist = append(e.hist, src...)
+	return s
+}
+
+// hash4u returns the hash of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <32.
+func hash4u(u uint32, h uint8) uint32 {
+	return (u * prime4bytes) >> ((32 - h) & 31)
+}
+
+type tableEntryPrev struct {
+	Cur  tableEntry
+	Prev tableEntry
+}
+
+// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <32.
+func hash4x64(u uint64, h uint8) uint32 {
+	return (uint32(u) * prime4bytes) >> ((32 - h) & 31)
+}
+
+// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash7(u uint64, h uint8) uint32 {
+	return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63))
+}
+
+// hash8 returns the hash of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash8(u uint64, h uint8) uint32 {
+	return uint32((u * prime8bytes) >> ((64 - h) & 63))
+}
+
+// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash6(u uint64, h uint8) uint32 {
+	return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63))
+}
+
+// matchlen will return the match length between offsets s and t in src.
+// The maximum length returned is maxMatchLength - 4.
+// It is assumed that s > t, that t >= 0 and s < len(src).
+func (e *fastGen) matchlen(s, t int32, src []byte) int32 {
+	if debugDecode {
+		if t >= s {
+			panic(fmt.Sprint("t >=s:", t, s))
+		}
+		if int(s) >= len(src) {
+			panic(fmt.Sprint("s >= len(src):", s, len(src)))
+		}
+		if t < 0 {
+			panic(fmt.Sprint("t < 0:", t))
+		}
+		if s-t > maxMatchOffset {
+			panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")"))
+		}
+	}
+	s1 := int(s) + maxMatchLength - 4
+	if s1 > len(src) {
+		s1 = len(src)
+	}
+
+	// Extend the match to be as long as possible.
+	return int32(matchLen(src[s:s1], src[t:]))
+}
+
+// matchlenLong will return the match length between offsets s and t in src.
+// It is assumed that s > t, that t >= 0 and s < len(src).
+func (e *fastGen) matchlenLong(s, t int32, src []byte) int32 { + if debugDecode { + if t >= s { + panic(fmt.Sprint("t >=s:", t, s)) + } + if int(s) >= len(src) { + panic(fmt.Sprint("s >= len(src):", s, len(src))) + } + if t < 0 { + panic(fmt.Sprint("t < 0:", t)) + } + if s-t > maxMatchOffset { + panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")")) + } + } + // Extend the match to be as long as possible. + return int32(matchLen(src[s:], src[t:])) +} + +// Reset the encoding table. +func (e *fastGen) Reset() { + if cap(e.hist) < allocHistory { + e.hist = make([]byte, 0, allocHistory) + } + // We offset current position so everything will be out of reach. + // If we are above the buffer reset it will be cleared anyway since len(hist) == 0. + if e.cur <= bufferReset { + e.cur += maxMatchOffset + int32(len(e.hist)) + } + e.hist = e.hist[:0] +} + +// matchLen returns the maximum length. +// 'a' must be the shortest of the two. +func matchLen(a, b []byte) int { + b = b[:len(a)] + var checked int + if len(a) > 4 { + // Try 4 bytes first + if diff := load32(a, 0) ^ load32(b, 0); diff != 0 { + return bits.TrailingZeros32(diff) >> 3 + } + // Switch to 8 byte matching. + checked = 4 + a = a[4:] + b = b[4:] + for len(a) >= 8 { + b = b[:len(a)] + if diff := load64(a, 0) ^ load64(b, 0); diff != 0 { + return checked + (bits.TrailingZeros64(diff) >> 3) + } + checked += 8 + a = a[8:] + b = b[8:] + } + } + b = b[:len(a)] + for i := range a { + if a[i] != b[i] { + return int(i) + checked + } + } + return len(a) + checked +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/gen_inflate.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/gen_inflate.go new file mode 100644 index 000000000000..c74a95fe7f6f --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/gen_inflate.go @@ -0,0 +1,274 @@ +// +build generate + +//go:generate go run $GOFILE && gofmt -w inflate_gen.go + +package main + +import ( + "os" + "strings" +) + +func main() { + f, err := os.Create("inflate_gen.go") + if err != nil { + panic(err) + } + defer f.Close() + types := []string{"*bytes.Buffer", "*bytes.Reader", "*bufio.Reader", "*strings.Reader"} + names := []string{"BytesBuffer", "BytesReader", "BufioReader", "StringsReader"} + imports := []string{"bytes", "bufio", "io", "strings", "math/bits"} + f.WriteString(`// Code generated by go generate gen_inflate.go. DO NOT EDIT. + +package flate + +import ( +`) + + for _, imp := range imports { + f.WriteString("\t\"" + imp + "\"\n") + } + f.WriteString(")\n\n") + + template := ` + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) $FUNCNAME$() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.($TYPE$) + moreBits := func() error { + c, err := fr.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil + } + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. 
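+	// (Editor's note, not in the original source: the inlined huffSym below
+	// does a two-level table lookup. The low bits of the bit buffer index
+	// f.hl.chunks; codes longer than huffmanChunkBits spill into f.hl.links.
+	// Each chunk packs the symbol in its high bits and the code length in
+	// its low bits:
+	//
+	//	chunk := f.hl.chunks[b&(huffmanNumChunks-1)]
+	//	n := uint(chunk & huffmanCountMask)  // code length in bits
+	//	v := int(chunk >> huffmanValueShift) // decoded symbol
+	// )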
+	{
+		var v int
+		{
+			// Inlined v, err := f.huffSym(f.hl)
+			// Since a huffmanDecoder can be empty or be composed of a degenerate tree
+			// with single element, huffSym must error on these two edge cases. In both
+			// cases, the chunks slice will be 0 for the invalid sequence, leading it to
+			// satisfy the n == 0 check below.
+			n := uint(f.hl.maxRead)
+			// Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
+			// but is smart enough to keep local variables in registers, so use nb and b,
+			// inline call to moreBits and reassign b,nb back to f on return.
+			nb, b := f.nb, f.b
+			for {
+				for nb < n {
+					c, err := fr.ReadByte()
+					if err != nil {
+						f.b = b
+						f.nb = nb
+						f.err = noEOF(err)
+						return
+					}
+					f.roffset++
+					b |= uint32(c) << (nb & 31)
+					nb += 8
+				}
+				chunk := f.hl.chunks[b&(huffmanNumChunks-1)]
+				n = uint(chunk & huffmanCountMask)
+				if n > huffmanChunkBits {
+					chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask]
+					n = uint(chunk & huffmanCountMask)
+				}
+				if n <= nb {
+					if n == 0 {
+						f.b = b
+						f.nb = nb
+						if debugDecode {
+							fmt.Println("huffsym: n==0")
+						}
+						f.err = CorruptInputError(f.roffset)
+						return
+					}
+					f.b = b >> (n & 31)
+					f.nb = nb - n
+					v = int(chunk >> huffmanValueShift)
+					break
+				}
+			}
+		}
+
+		var n uint // number of bits extra
+		var length int
+		var err error
+		switch {
+		case v < 256:
+			f.dict.writeByte(byte(v))
+			if f.dict.availWrite() == 0 {
+				f.toRead = f.dict.readFlush()
+				f.step = (*decompressor).$FUNCNAME$
+				f.stepState = stateInit
+				return
+			}
+			goto readLiteral
+		case v == 256:
+			f.finishBlock()
+			return
+		// otherwise, reference to older data
+		case v < 265:
+			length = v - (257 - 3)
+			n = 0
+		case v < 269:
+			length = v*2 - (265*2 - 11)
+			n = 1
+		case v < 273:
+			length = v*4 - (269*4 - 19)
+			n = 2
+		case v < 277:
+			length = v*8 - (273*8 - 35)
+			n = 3
+		case v < 281:
+			length = v*16 - (277*16 - 67)
+			n = 4
+		case v < 285:
+			length = v*32 - (281*32 - 131)
+			n = 5
+		case v < maxNumLit:
+			length = 258
+			n = 0
+		default:
+			if debugDecode {
+				fmt.Println(v, ">= maxNumLit")
+			}
+			f.err = CorruptInputError(f.roffset)
+			return
+		}
+		if n > 0 {
+			for f.nb < n {
+				if err = moreBits(); err != nil {
+					if debugDecode {
+						fmt.Println("morebits n>0:", err)
+					}
+					f.err = err
+					return
+				}
+			}
+			length += int(f.b & uint32(1<<n-1))
+			f.b >>= n
+			f.nb -= n
+		}
+
+		var dist int
+		if f.hd == nil {
+			for f.nb < 5 {
+				if err = moreBits(); err != nil {
+					if debugDecode {
+						fmt.Println("morebits f.nb<5:", err)
+					}
+					f.err = err
+					return
+				}
+			}
+			dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
+			f.b >>= 5
+			f.nb -= 5
+		} else {
+			if dist, err = f.huffSym(f.hd); err != nil {
+				if debugDecode {
+					fmt.Println("huffsym:", err)
+				}
+				f.err = err
+				return
+			}
+		}
+
+		switch {
+		case dist < 4:
+			dist++
+		case dist < maxNumDist:
+			nb := uint(dist-2) >> 1
+			// have 1 bit in bottom of dist, need nb more.
+			extra := (dist & 1) << nb
+			for f.nb < nb {
+				if err = moreBits(); err != nil {
+					if debugDecode {
+						fmt.Println("morebits f.nb<nb:", err)
+					}
+					f.err = err
+					return
+				}
+			}
+			extra |= int(f.b & uint32(1<<nb-1))
+			f.b >>= nb
+			f.nb -= nb
+			dist = 1<<(nb+1) + 1 + extra
+		default:
+			if debugDecode {
+				fmt.Println("dist too big:", dist, maxNumDist)
+			}
+			f.err = CorruptInputError(f.roffset)
+			return
+		}
+
+		// No check on length; encoding can be prescient.
+ if dist > f.dict.histSize() { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, dist + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).$FUNCNAME$ // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +` + for i, t := range types { + s := strings.Replace(template, "$FUNCNAME$", "huffman"+names[i], -1) + s = strings.Replace(s, "$TYPE$", t, -1) + f.WriteString(s) + } + f.WriteString("func (f *decompressor) huffmanBlockDecoder() func() {\n") + f.WriteString("\tswitch f.r.(type) {\n") + for i, t := range types { + f.WriteString("\t\tcase " + t + ":\n") + f.WriteString("\t\t\treturn f.huffman" + names[i] + "\n") + } + f.WriteString("\t\tdefault:\n") + f.WriteString("\t\t\treturn f.huffmanBlockGeneric") + f.WriteString("\t}\n}\n") +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go new file mode 100644 index 000000000000..53fe1d06e25a --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go @@ -0,0 +1,911 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "io" +) + +const ( + // The largest offset code. + offsetCodeCount = 30 + + // The special code used to mark the end of a block. + endBlockMarker = 256 + + // The first length code. + lengthCodesStart = 257 + + // The number of codegen codes. + codegenCodeCount = 19 + badCode = 255 + + // bufferFlushSize indicates the buffer size + // after which bytes are flushed to the writer. + // Should preferably be a multiple of 6, since + // we accumulate 6 bytes between writes to the buffer. + bufferFlushSize = 240 + + // bufferSize is the actual output byte buffer size. + // It must have additional headroom for a flush + // which can contain up to 8 bytes. + bufferSize = bufferFlushSize + 8 +) + +// The number of extra bits needed by length code X - LENGTH_CODES_START. +var lengthExtraBits = [32]int8{ + /* 257 */ 0, 0, 0, + /* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, + /* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, + /* 280 */ 4, 5, 5, 5, 5, 0, +} + +// The length indicated by length code X - LENGTH_CODES_START. +var lengthBase = [32]uint8{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, + 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, + 64, 80, 96, 112, 128, 160, 192, 224, 255, +} + +// offset code word extra bits. 
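+// (Editor's note, not in the original source: a match distance is encoded as
+// an offset code plus extra bits, with offsetBase holding distance-1 values.
+// For example, offset code 4 has base 0x000004 and 1 extra bit, covering
+// distances 5..6, while code 29 has base 0x006000 and 13 extra bits,
+// covering distances 24577..32768.)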
+var offsetExtraBits = [64]int8{
+	0, 0, 0, 0, 1, 1, 2, 2, 3, 3,
+	4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
+	9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
+	/* extended window */
+	14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20,
+}
+
+var offsetBase = [64]uint32{
+	/* normal deflate */
+	0x000000, 0x000001, 0x000002, 0x000003, 0x000004,
+	0x000006, 0x000008, 0x00000c, 0x000010, 0x000018,
+	0x000020, 0x000030, 0x000040, 0x000060, 0x000080,
+	0x0000c0, 0x000100, 0x000180, 0x000200, 0x000300,
+	0x000400, 0x000600, 0x000800, 0x000c00, 0x001000,
+	0x001800, 0x002000, 0x003000, 0x004000, 0x006000,
+
+	/* extended window */
+	0x008000, 0x00c000, 0x010000, 0x018000, 0x020000,
+	0x030000, 0x040000, 0x060000, 0x080000, 0x0c0000,
+	0x100000, 0x180000, 0x200000, 0x300000,
+}
+
+// The odd order in which the codegen code sizes are written.
+var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
+
+type huffmanBitWriter struct {
+	// writer is the underlying writer.
+	// Do not use it directly; use the write method, which ensures
+	// that Write errors are sticky.
+	writer io.Writer
+
+	// Data waiting to be written is bytes[0:nbytes]
+	// and then the low nbits of bits.
+	bits            uint64
+	nbits           uint16
+	nbytes          uint8
+	literalEncoding *huffmanEncoder
+	offsetEncoding  *huffmanEncoder
+	codegenEncoding *huffmanEncoder
+	err             error
+	lastHeader      int
+	// logNewTablePenalty adds a size penalty of newSize>>logNewTablePenalty
+	// when estimating the cost of a new table. At 0 the estimate is doubled,
+	// strongly favoring reuse of the previous table.
+	logNewTablePenalty uint
+	lastHuffMan        bool
+	bytes              [256]byte
+	literalFreq        [lengthCodesStart + 32]uint16
+	offsetFreq         [32]uint16
+	codegenFreq        [codegenCodeCount]uint16
+
+	// codegen must have an extra space for the final symbol.
+	codegen [literalCount + offsetCodeCount + 1]uint8
+}
+
+// Huffman reuse.
+//
+// The huffmanBitWriter supports reusing huffman tables and thereby combining block sections.
+//
+// This is controlled by several variables:
+//
+// If lastHeader is non-zero the Huffman table can be reused.
+// This also indicates that a Huffman table has been generated that can output all
+// possible symbols.
+// It also indicates that an EOB has not yet been emitted, so if a new table is generated
+// an EOB with the previous table must be written.
+//
+// If lastHuffMan is set, a table for outputting literals has been generated and offsets are invalid.
+//
+// An incoming block estimates the output size of a new table by calculating the
+// optimal size and adding a penalty ('logNewTablePenalty').
+// A Huffman table is not optimal, which is why we add a penalty, and generating a new table
+// is slower both for compression and decompression.
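+//
+// (Editor's sketch, not part of the original source: the reuse decision in
+// writeBlockDynamic amounts to comparing two bit-size estimates, roughly:
+//
+//	newSize := w.lastHeader + tokens.EstimatedBits()
+//	newSize += newSize >> w.logNewTablePenalty // penalty for a new table
+//	reuseSize := w.dynamicReuseSize(w.literalEncoding, w.offsetEncoding) +
+//		w.extraBitSize()
+//	// build a new table only if newSize < reuseSize
+//
+// so with logNewTablePenalty == 4 a new table must look roughly 6% smaller
+// than plain reuse before it is generated.)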
+ +func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { + return &huffmanBitWriter{ + writer: w, + literalEncoding: newHuffmanEncoder(literalCount), + codegenEncoding: newHuffmanEncoder(codegenCodeCount), + offsetEncoding: newHuffmanEncoder(offsetCodeCount), + } +} + +func (w *huffmanBitWriter) reset(writer io.Writer) { + w.writer = writer + w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil + w.lastHeader = 0 + w.lastHuffMan = false +} + +func (w *huffmanBitWriter) canReuse(t *tokens) (offsets, lits bool) { + offsets, lits = true, true + a := t.offHist[:offsetCodeCount] + b := w.offsetFreq[:len(a)] + for i := range a { + if b[i] == 0 && a[i] != 0 { + offsets = false + break + } + } + + a = t.extraHist[:literalCount-256] + b = w.literalFreq[256:literalCount] + b = b[:len(a)] + for i := range a { + if b[i] == 0 && a[i] != 0 { + lits = false + break + } + } + if lits { + a = t.litHist[:] + b = w.literalFreq[:len(a)] + for i := range a { + if b[i] == 0 && a[i] != 0 { + lits = false + break + } + } + } + return +} + +func (w *huffmanBitWriter) flush() { + if w.err != nil { + w.nbits = 0 + return + } + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + n := w.nbytes + for w.nbits != 0 { + w.bytes[n] = byte(w.bits) + w.bits >>= 8 + if w.nbits > 8 { // Avoid underflow + w.nbits -= 8 + } else { + w.nbits = 0 + } + n++ + } + w.bits = 0 + w.write(w.bytes[:n]) + w.nbytes = 0 +} + +func (w *huffmanBitWriter) write(b []byte) { + if w.err != nil { + return + } + _, w.err = w.writer.Write(b) +} + +func (w *huffmanBitWriter) writeBits(b int32, nb uint16) { + w.bits |= uint64(b) << (w.nbits & 63) + w.nbits += nb + if w.nbits >= 48 { + w.writeOutBits() + } +} + +func (w *huffmanBitWriter) writeBytes(bytes []byte) { + if w.err != nil { + return + } + n := w.nbytes + if w.nbits&7 != 0 { + w.err = InternalError("writeBytes with unfinished bits") + return + } + for w.nbits != 0 { + w.bytes[n] = byte(w.bits) + w.bits >>= 8 + w.nbits -= 8 + n++ + } + if n != 0 { + w.write(w.bytes[:n]) + } + w.nbytes = 0 + w.write(bytes) +} + +// RFC 1951 3.2.7 specifies a special run-length encoding for specifying +// the literal and offset lengths arrays (which are concatenated into a single +// array). This method generates that run-length encoding. +// +// The result is written into the codegen array, and the frequencies +// of each code is written into the codegenFreq array. +// Codes 0-15 are single byte codes. Codes 16-18 are followed by additional +// information. Code badCode is an end marker +// +// numLiterals The number of literals in literalEncoding +// numOffsets The number of offsets in offsetEncoding +// litenc, offenc The literal and offset encoder to use +func (w *huffmanBitWriter) generateCodegen(numLiterals int, numOffsets int, litEnc, offEnc *huffmanEncoder) { + for i := range w.codegenFreq { + w.codegenFreq[i] = 0 + } + // Note that we are using codegen both as a temporary variable for holding + // a copy of the frequencies, and as the place where we put the result. + // This is fine because the output is always shorter than the input used + // so far. + codegen := w.codegen[:] // cache + // Copy the concatenated code sizes to codegen. Put a marker at the end. 
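+	// (Editor's note, not in the original source: per RFC 1951 section
+	// 3.2.7 the run-length codes emitted below are:
+	//
+	//	16: copy the previous code length 3-6 times (2 extra bits)
+	//	17: repeat a zero length 3-10 times (3 extra bits)
+	//	18: repeat a zero length 11-138 times (7 extra bits)
+	//
+	// so a run of 8 zero lengths, for instance, becomes code 17 with extra
+	// value 8-3 == 5.)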
+	cgnl := codegen[:numLiterals]
+	for i := range cgnl {
+		cgnl[i] = uint8(litEnc.codes[i].len)
+	}
+
+	cgnl = codegen[numLiterals : numLiterals+numOffsets]
+	for i := range cgnl {
+		cgnl[i] = uint8(offEnc.codes[i].len)
+	}
+	codegen[numLiterals+numOffsets] = badCode
+
+	size := codegen[0]
+	count := 1
+	outIndex := 0
+	for inIndex := 1; size != badCode; inIndex++ {
+		// INVARIANT: We have seen "count" copies of size that have not yet
+		// had output generated for them.
+		nextSize := codegen[inIndex]
+		if nextSize == size {
+			count++
+			continue
+		}
+		// We need to generate codegen indicating "count" of size.
+		if size != 0 {
+			codegen[outIndex] = size
+			outIndex++
+			w.codegenFreq[size]++
+			count--
+			for count >= 3 {
+				n := 6
+				if n > count {
+					n = count
+				}
+				codegen[outIndex] = 16
+				outIndex++
+				codegen[outIndex] = uint8(n - 3)
+				outIndex++
+				w.codegenFreq[16]++
+				count -= n
+			}
+		} else {
+			for count >= 11 {
+				n := 138
+				if n > count {
+					n = count
+				}
+				codegen[outIndex] = 18
+				outIndex++
+				codegen[outIndex] = uint8(n - 11)
+				outIndex++
+				w.codegenFreq[18]++
+				count -= n
+			}
+			if count >= 3 {
+				// count >= 3 && count <= 10
+				codegen[outIndex] = 17
+				outIndex++
+				codegen[outIndex] = uint8(count - 3)
+				outIndex++
+				w.codegenFreq[17]++
+				count = 0
+			}
+		}
+		count--
+		for ; count >= 0; count-- {
+			codegen[outIndex] = size
+			outIndex++
+			w.codegenFreq[size]++
+		}
+		// Set up invariant for next time through the loop.
+		size = nextSize
+		count = 1
+	}
+	// Marker indicating the end of the codegen.
+	codegen[outIndex] = badCode
+}
+
+func (w *huffmanBitWriter) codegens() int {
+	numCodegens := len(w.codegenFreq)
+	for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
+		numCodegens--
+	}
+	return numCodegens
+}
+
+func (w *huffmanBitWriter) headerSize() (size, numCodegens int) {
+	numCodegens = len(w.codegenFreq)
+	for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
+		numCodegens--
+	}
+	return 3 + 5 + 5 + 4 + (3 * numCodegens) +
+		w.codegenEncoding.bitLength(w.codegenFreq[:]) +
+		int(w.codegenFreq[16])*2 +
+		int(w.codegenFreq[17])*3 +
+		int(w.codegenFreq[18])*7, numCodegens
+}
+
+// dynamicReuseSize returns the size of dynamically encoded data in bits
+// when reusing the supplied literal and offset encoders.
+func (w *huffmanBitWriter) dynamicReuseSize(litEnc, offEnc *huffmanEncoder) (size int) {
+	size = litEnc.bitLength(w.literalFreq[:]) +
+		offEnc.bitLength(w.offsetFreq[:])
+	return size
+}
+
+// dynamicSize returns the size of dynamically encoded data in bits.
+func (w *huffmanBitWriter) dynamicSize(litEnc, offEnc *huffmanEncoder, extraBits int) (size, numCodegens int) {
+	header, numCodegens := w.headerSize()
+	size = header +
+		litEnc.bitLength(w.literalFreq[:]) +
+		offEnc.bitLength(w.offsetFreq[:]) +
+		extraBits
+	return size, numCodegens
+}
+
+// extraBitSize will return the number of bits that will be written
+// as "extra" bits on matches.
+func (w *huffmanBitWriter) extraBitSize() int {
+	total := 0
+	for i, n := range w.literalFreq[257:literalCount] {
+		total += int(n) * int(lengthExtraBits[i&31])
+	}
+	for i, n := range w.offsetFreq[:offsetCodeCount] {
+		total += int(n) * int(offsetExtraBits[i&31])
+	}
+	return total
+}
+
+// fixedSize returns the size of fixed-table encoded data in bits.
+func (w *huffmanBitWriter) fixedSize(extraBits int) int {
+	return 3 +
+		fixedLiteralEncoding.bitLength(w.literalFreq[:]) +
+		fixedOffsetEncoding.bitLength(w.offsetFreq[:]) +
+		extraBits
+}
+
+// storedSize calculates the stored size, including header.
+// The function returns the size in bits and whether the block +// fits inside a single block. +func (w *huffmanBitWriter) storedSize(in []byte) (int, bool) { + if in == nil { + return 0, false + } + if len(in) <= maxStoreBlockSize { + return (len(in) + 5) * 8, true + } + return 0, false +} + +func (w *huffmanBitWriter) writeCode(c hcode) { + // The function does not get inlined if we "& 63" the shift. + w.bits |= uint64(c.code) << w.nbits + w.nbits += c.len + if w.nbits >= 48 { + w.writeOutBits() + } +} + +// writeOutBits will write bits to the buffer. +func (w *huffmanBitWriter) writeOutBits() { + bits := w.bits + w.bits >>= 48 + w.nbits -= 48 + n := w.nbytes + w.bytes[n] = byte(bits) + w.bytes[n+1] = byte(bits >> 8) + w.bytes[n+2] = byte(bits >> 16) + w.bytes[n+3] = byte(bits >> 24) + w.bytes[n+4] = byte(bits >> 32) + w.bytes[n+5] = byte(bits >> 40) + n += 6 + if n >= bufferFlushSize { + if w.err != nil { + n = 0 + return + } + w.write(w.bytes[:n]) + n = 0 + } + w.nbytes = n +} + +// Write the header of a dynamic Huffman block to the output stream. +// +// numLiterals The number of literals specified in codegen +// numOffsets The number of offsets specified in codegen +// numCodegens The number of codegens used in codegen +func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) { + if w.err != nil { + return + } + var firstBits int32 = 4 + if isEof { + firstBits = 5 + } + w.writeBits(firstBits, 3) + w.writeBits(int32(numLiterals-257), 5) + w.writeBits(int32(numOffsets-1), 5) + w.writeBits(int32(numCodegens-4), 4) + + for i := 0; i < numCodegens; i++ { + value := uint(w.codegenEncoding.codes[codegenOrder[i]].len) + w.writeBits(int32(value), 3) + } + + i := 0 + for { + var codeWord = uint32(w.codegen[i]) + i++ + if codeWord == badCode { + break + } + w.writeCode(w.codegenEncoding.codes[codeWord]) + + switch codeWord { + case 16: + w.writeBits(int32(w.codegen[i]), 2) + i++ + case 17: + w.writeBits(int32(w.codegen[i]), 3) + i++ + case 18: + w.writeBits(int32(w.codegen[i]), 7) + i++ + } + } +} + +// writeStoredHeader will write a stored header. +// If the stored block is only used for EOF, +// it is replaced with a fixed huffman block. +func (w *huffmanBitWriter) writeStoredHeader(length int, isEof bool) { + if w.err != nil { + return + } + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + + // To write EOF, use a fixed encoding block. 10 bits instead of 5 bytes. + if length == 0 && isEof { + w.writeFixedHeader(isEof) + // EOB: 7 bits, value: 0 + w.writeBits(0, 7) + w.flush() + return + } + + var flag int32 + if isEof { + flag = 1 + } + w.writeBits(flag, 3) + w.flush() + w.writeBits(int32(length), 16) + w.writeBits(int32(^uint16(length)), 16) +} + +func (w *huffmanBitWriter) writeFixedHeader(isEof bool) { + if w.err != nil { + return + } + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + + // Indicate that we are a fixed Huffman block + var value int32 = 2 + if isEof { + value = 3 + } + w.writeBits(value, 3) +} + +// writeBlock will write a block of tokens with the smallest encoding. +// The original input can be supplied, and if the huffman encoded data +// is larger than the original bytes, the data will be written as a +// stored block. +// If the input is nil, the tokens will always be Huffman encoded. 
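+// (Editor's note, not in the original source: writeBlock computes the exact
+// bit cost of all three DEFLATE block encodings, stored, fixed Huffman and
+// dynamic Huffman including its codegen header, and emits whichever is
+// smallest.)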
+func (w *huffmanBitWriter) writeBlock(tokens *tokens, eof bool, input []byte) { + if w.err != nil { + return + } + + tokens.AddEOB() + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + numLiterals, numOffsets := w.indexTokens(tokens, false) + w.generate(tokens) + var extraBits int + storedSize, storable := w.storedSize(input) + if storable { + extraBits = w.extraBitSize() + } + + // Figure out smallest code. + // Fixed Huffman baseline. + var literalEncoding = fixedLiteralEncoding + var offsetEncoding = fixedOffsetEncoding + var size = w.fixedSize(extraBits) + + // Dynamic Huffman? + var numCodegens int + + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + dynamicSize, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, extraBits) + + if dynamicSize < size { + size = dynamicSize + literalEncoding = w.literalEncoding + offsetEncoding = w.offsetEncoding + } + + // Stored bytes? + if storable && storedSize < size { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + // Huffman. + if literalEncoding == fixedLiteralEncoding { + w.writeFixedHeader(eof) + } else { + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + } + + // Write the tokens. + w.writeTokens(tokens.Slice(), literalEncoding.codes, offsetEncoding.codes) +} + +// writeBlockDynamic encodes a block using a dynamic Huffman table. +// This should be used if the symbols used have a disproportionate +// histogram distribution. +// If input is supplied and the compression savings are below 1/16th of the +// input size the block is stored. +func (w *huffmanBitWriter) writeBlockDynamic(tokens *tokens, eof bool, input []byte, sync bool) { + if w.err != nil { + return + } + + sync = sync || eof + if sync { + tokens.AddEOB() + } + + // We cannot reuse pure huffman table, and must mark as EOF. + if (w.lastHuffMan || eof) && w.lastHeader > 0 { + // We will not try to reuse. + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + w.lastHuffMan = false + } + if !sync { + tokens.Fill() + } + numLiterals, numOffsets := w.indexTokens(tokens, !sync) + + var size int + // Check if we should reuse. + if w.lastHeader > 0 { + // Estimate size for using a new table. + // Use the previous header size as the best estimate. + newSize := w.lastHeader + tokens.EstimatedBits() + newSize += newSize >> w.logNewTablePenalty + + // The estimated size is calculated as an optimal table. + // We add a penalty to make it more realistic and re-use a bit more. + reuseSize := w.dynamicReuseSize(w.literalEncoding, w.offsetEncoding) + w.extraBitSize() + + // Check if a new table is better. + if newSize < reuseSize { + // Write the EOB we owe. + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + size = newSize + w.lastHeader = 0 + } else { + size = reuseSize + } + // Check if we get a reasonable size decrease. + if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + w.lastHeader = 0 + return + } + } + + // We want a new block/table + if w.lastHeader == 0 { + w.generate(tokens) + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. 
+ w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + var numCodegens int + size, numCodegens = w.dynamicSize(w.literalEncoding, w.offsetEncoding, w.extraBitSize()) + // Store bytes, if we don't get a reasonable improvement. + if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + w.lastHeader = 0 + return + } + + // Write Huffman table. + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + w.lastHeader, _ = w.headerSize() + w.lastHuffMan = false + } + + if sync { + w.lastHeader = 0 + } + // Write the tokens. + w.writeTokens(tokens.Slice(), w.literalEncoding.codes, w.offsetEncoding.codes) +} + +// indexTokens indexes a slice of tokens, and updates +// literalFreq and offsetFreq, and generates literalEncoding +// and offsetEncoding. +// The number of literal and offset tokens is returned. +func (w *huffmanBitWriter) indexTokens(t *tokens, filled bool) (numLiterals, numOffsets int) { + copy(w.literalFreq[:], t.litHist[:]) + copy(w.literalFreq[256:], t.extraHist[:]) + copy(w.offsetFreq[:], t.offHist[:offsetCodeCount]) + + if t.n == 0 { + return + } + if filled { + return maxNumLit, maxNumDist + } + // get the number of literals + numLiterals = len(w.literalFreq) + for w.literalFreq[numLiterals-1] == 0 { + numLiterals-- + } + // get the number of offsets + numOffsets = len(w.offsetFreq) + for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 { + numOffsets-- + } + if numOffsets == 0 { + // We haven't found a single match. If we want to go with the dynamic encoding, + // we should count at least one offset to be sure that the offset huffman tree could be encoded. + w.offsetFreq[0] = 1 + numOffsets = 1 + } + return +} + +func (w *huffmanBitWriter) generate(t *tokens) { + w.literalEncoding.generate(w.literalFreq[:literalCount], 15) + w.offsetEncoding.generate(w.offsetFreq[:offsetCodeCount], 15) +} + +// writeTokens writes a slice of tokens to the output. +// codes for literal and offset encoding must be supplied. +func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) { + if w.err != nil { + return + } + if len(tokens) == 0 { + return + } + + // Only last token should be endBlockMarker. + var deferEOB bool + if tokens[len(tokens)-1] == endBlockMarker { + tokens = tokens[:len(tokens)-1] + deferEOB = true + } + + // Create slices up to the next power of two to avoid bounds checks. 
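+	// (Editor's note, not in the original source: len(lengths) and len(offs)
+	// are fixed at 32, so indexing with lengthCode&31 and offsetCode&31 is
+	// provably in range and the compiler can drop the bounds checks in this
+	// hot loop.)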
+ lits := leCodes[:256] + offs := oeCodes[:32] + lengths := leCodes[lengthCodesStart:] + lengths = lengths[:32] + for _, t := range tokens { + if t < matchType { + w.writeCode(lits[t.literal()]) + continue + } + + // Write the length + length := t.length() + lengthCode := lengthCode(length) + if false { + w.writeCode(lengths[lengthCode&31]) + } else { + // inlined + c := lengths[lengthCode&31] + w.bits |= uint64(c.code) << (w.nbits & 63) + w.nbits += c.len + if w.nbits >= 48 { + w.writeOutBits() + } + } + + extraLengthBits := uint16(lengthExtraBits[lengthCode&31]) + if extraLengthBits > 0 { + extraLength := int32(length - lengthBase[lengthCode&31]) + w.writeBits(extraLength, extraLengthBits) + } + // Write the offset + offset := t.offset() + offsetCode := offsetCode(offset) + if false { + w.writeCode(offs[offsetCode&31]) + } else { + // inlined + c := offs[offsetCode&31] + w.bits |= uint64(c.code) << (w.nbits & 63) + w.nbits += c.len + if w.nbits >= 48 { + w.writeOutBits() + } + } + extraOffsetBits := uint16(offsetExtraBits[offsetCode&63]) + if extraOffsetBits > 0 { + extraOffset := int32(offset - offsetBase[offsetCode&63]) + w.writeBits(extraOffset, extraOffsetBits) + } + } + if deferEOB { + w.writeCode(leCodes[endBlockMarker]) + } +} + +// huffOffset is a static offset encoder used for huffman only encoding. +// It can be reused since we will not be encoding offset values. +var huffOffset *huffmanEncoder + +func init() { + w := newHuffmanBitWriter(nil) + w.offsetFreq[0] = 1 + huffOffset = newHuffmanEncoder(offsetCodeCount) + huffOffset.generate(w.offsetFreq[:offsetCodeCount], 15) +} + +// writeBlockHuff encodes a block of bytes as either +// Huffman encoded literals or uncompressed bytes if the +// results only gains very little from compression. +func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte, sync bool) { + if w.err != nil { + return + } + + // Clear histogram + for i := range w.literalFreq[:] { + w.literalFreq[i] = 0 + } + if !w.lastHuffMan { + for i := range w.offsetFreq[:] { + w.offsetFreq[i] = 0 + } + } + + // Add everything as literals + // We have to estimate the header size. + // Assume header is around 70 bytes: + // https://stackoverflow.com/a/25454430 + const guessHeaderSizeBits = 70 * 8 + estBits, estExtra := histogramSize(input, w.literalFreq[:], !eof && !sync) + estBits += w.lastHeader + 15 + if w.lastHeader == 0 { + estBits += guessHeaderSizeBits + } + estBits += estBits >> w.logNewTablePenalty + + // Store bytes, if we don't get a reasonable improvement. + ssize, storable := w.storedSize(input) + if storable && ssize < estBits { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + if w.lastHeader > 0 { + reuseSize := w.literalEncoding.bitLength(w.literalFreq[:256]) + estBits += estExtra + + if estBits < reuseSize { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + } + + const numLiterals = endBlockMarker + 1 + const numOffsets = 1 + if w.lastHeader == 0 { + w.literalFreq[endBlockMarker] = 1 + w.literalEncoding.generate(w.literalFreq[:numLiterals], 15) + + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, huffOffset) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + numCodegens := w.codegens() + + // Huffman. 
+		w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
+		w.lastHuffMan = true
+		w.lastHeader, _ = w.headerSize()
+	}
+
+	encoding := w.literalEncoding.codes[:257]
+	for _, t := range input {
+		// Bitwriting inlined, ~30% speedup
+		c := encoding[t]
+		w.bits |= uint64(c.code) << ((w.nbits) & 63)
+		w.nbits += c.len
+		if w.nbits >= 48 {
+			bits := w.bits
+			w.bits >>= 48
+			w.nbits -= 48
+			n := w.nbytes
+			w.bytes[n] = byte(bits)
+			w.bytes[n+1] = byte(bits >> 8)
+			w.bytes[n+2] = byte(bits >> 16)
+			w.bytes[n+3] = byte(bits >> 24)
+			w.bytes[n+4] = byte(bits >> 32)
+			w.bytes[n+5] = byte(bits >> 40)
+			n += 6
+			if n >= bufferFlushSize {
+				if w.err != nil {
+					n = 0
+					return
+				}
+				w.write(w.bytes[:n])
+				n = 0
+			}
+			w.nbytes = n
+		}
+	}
+	if eof || sync {
+		w.writeCode(encoding[endBlockMarker])
+		w.lastHeader = 0
+		w.lastHuffMan = false
+	}
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_code.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_code.go
new file mode 100644
index 000000000000..4c39a3018711
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_code.go
@@ -0,0 +1,363 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package flate
+
+import (
+	"math"
+	"math/bits"
+)
+
+const (
+	maxBitsLimit = 16
+	// number of valid literals
+	literalCount = 286
+)
+
+// hcode is a huffman code with a bit code and bit length.
+type hcode struct {
+	code, len uint16
+}
+
+type huffmanEncoder struct {
+	codes     []hcode
+	freqcache []literalNode
+	bitCount  [17]int32
+}
+
+type literalNode struct {
+	literal uint16
+	freq    uint16
+}
+
+// A levelInfo describes the state of the constructed tree for a given depth.
+type levelInfo struct {
+	// Our level. for better printing
+	level int32
+
+	// The frequency of the last node at this level
+	lastFreq int32
+
+	// The frequency of the next character to add to this level
+	nextCharFreq int32
+
+	// The frequency of the next pair (from level below) to add to this level.
+	// Only valid if the "needed" value of the next lower level is 0.
+	nextPairFreq int32
+
+	// The number of chains remaining to generate for this level before moving
+	// up to the next level
+	needed int32
+}
+
+// set sets the code and length of an hcode.
+func (h *hcode) set(code uint16, length uint16) {
+	h.len = length
+	h.code = code
+}
+
+func reverseBits(number uint16, bitLength byte) uint16 {
+	return bits.Reverse16(number << ((16 - bitLength) & 15))
+}
+
+func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxUint16} }
+
+func newHuffmanEncoder(size int) *huffmanEncoder {
+	// Make capacity to next power of two.
+	c := uint(bits.Len32(uint32(size - 1)))
+	return &huffmanEncoder{codes: make([]hcode, size, 1<<c)}
+}
+
+// Generates a HuffmanCode corresponding to the fixed literal table
+func generateFixedLiteralEncoding() *huffmanEncoder {
+	h := newHuffmanEncoder(literalCount)
+	codes := h.codes
+	var ch uint16
+	for ch = 0; ch < literalCount; ch++ {
+		var bits uint16
+		var size uint16
+		switch {
+		case ch < 144:
+			// size 8, 00110000 .. 10111111
+			bits = ch + 48
+			size = 8
+		case ch < 256:
+			// size 9, 110010000 .. 111111111
+			bits = ch + 400 - 144
+			size = 9
+		case ch < 280:
+			// size 7, 0000000 .. 0010111
+			bits = ch - 256
+			size = 7
+		default:
+			// size 8, 11000000 .. 11000111
+			bits = ch + 192 - 280
+			size = 8
+		}
+		codes[ch] = hcode{code: reverseBits(bits, byte(size)), len: size}
+	}
+	return h
+}
+
+func generateFixedOffsetEncoding() *huffmanEncoder {
+	h := newHuffmanEncoder(30)
+	codes := h.codes
+	for ch := range codes {
+		codes[ch] = hcode{code: reverseBits(uint16(ch), 5), len: 5}
+	}
+	return h
+}
+
+var fixedLiteralEncoding = generateFixedLiteralEncoding()
+var fixedOffsetEncoding = generateFixedOffsetEncoding()
+
+func (h *huffmanEncoder) bitLength(freq []uint16) int {
+	var total int
+	for i, f := range freq {
+		if f != 0 {
+			total += int(f) * int(h.codes[i].len)
+		}
+	}
+	return total
+}
+
+// Return the number of literals assigned to each bit size in the Huffman encoding
+//
+// This method is only called when list.length >= 3
+// The cases of 0, 1, and 2 literals are handled by special case code.
+//
+// list An array of the literals with non-zero frequencies
+// and their associated frequencies. The array is in order of increasing
+// frequency, and has as its last element a special element with frequency
+// MaxInt32
+// maxBits The maximum number of bits that should be used to encode any literal.
+// Must be less than 16.
+// return An integer array in which array[i] indicates the number of literals
+// that should be encoded in i bits.
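+//
+// (Editor's example, not in the original source: for frequencies
+// {a:1, b:1, c:2, d:4} an optimal code assigns d 1 bit, c 2 bits and a, b
+// 3 bits each, so the returned counts are [0 1 1 2]: no 0-bit codes, one
+// 1-bit code, one 2-bit code and two 3-bit codes.)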
+func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
+	if maxBits >= maxBitsLimit {
+		panic("flate: maxBits too large")
+	}
+	n := int32(len(list))
+	list = list[0 : n+1]
+	list[n] = maxNode()
+
+	// The tree can't have greater depth than n - 1, no matter what. This
+	// saves a little bit of work in some small cases
+	if maxBits > n-1 {
+		maxBits = n - 1
+	}
+
+	// Create information about each of the levels.
+	// A bogus "Level 0" whose sole purpose is so that
+	// level1.prev.needed==0. This makes level1.nextPairFreq
+	// be a legitimate value that never gets chosen.
+	var levels [maxBitsLimit]levelInfo
+	// leafCounts[i] counts the number of literals at the left
+	// of ancestors of the rightmost node at level i.
+	// leafCounts[i][j] is the number of literals at the left
+	// of the level j ancestor.
+	var leafCounts [maxBitsLimit][maxBitsLimit]int32
+
+	for level := int32(1); level <= maxBits; level++ {
+		// For every level, the first two items are the first two characters.
+		// We initialize the levels as if we had already figured this out.
+		levels[level] = levelInfo{
+			level:        level,
+			lastFreq:     int32(list[1].freq),
+			nextCharFreq: int32(list[2].freq),
+			nextPairFreq: int32(list[0].freq) + int32(list[1].freq),
+		}
+		leafCounts[level][level] = 2
+		if level == 1 {
+			levels[level].nextPairFreq = math.MaxInt32
+		}
+	}
+
+	// We need a total of 2*n - 2 items at top level and have already generated 2.
+	levels[maxBits].needed = 2*n - 4
+
+	level := maxBits
+	for {
+		l := &levels[level]
+		if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
+			// We've run out of both leaves and pairs.
+			// End all calculations for this level.
+			// To make sure we never come back to this level or any lower level,
+			// set nextPairFreq impossibly large.
+			l.needed = 0
+			levels[level+1].nextPairFreq = math.MaxInt32
+			level++
+			continue
+		}
+
+		prevFreq := l.lastFreq
+		if l.nextCharFreq < l.nextPairFreq {
+			// The next item on this row is a leaf node.
+			n := leafCounts[level][level] + 1
+			l.lastFreq = l.nextCharFreq
+			// Lower leafCounts are the same as the previous node.
+			leafCounts[level][level] = n
+			e := list[n]
+			if e.literal < math.MaxUint16 {
+				l.nextCharFreq = int32(e.freq)
+			} else {
+				l.nextCharFreq = math.MaxInt32
+			}
+		} else {
+			// The next item on this row is a pair from the previous row.
+			// nextPairFreq isn't valid until we generate two
+			// more values in the level below
+			l.lastFreq = l.nextPairFreq
+			// Take leaf counts from the lower level, except counts[level] remains the same.
+			copy(leafCounts[level][:level], leafCounts[level-1][:level])
+			levels[l.level-1].needed = 2
+		}
+
+		if l.needed--; l.needed == 0 {
+			// We've done everything we need to do for this level.
+			// Continue calculating one level up. Fill in nextPairFreq
+			// of that level with the sum of the two nodes we've just calculated on
+			// this level.
+			if l.level == maxBits {
+				// All done!
+				break
+			}
+			levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq
+			level++
+		} else {
+			// If we stole from below, move down temporarily to replenish it.
+			for levels[level-1].needed > 0 {
+				level--
+			}
+		}
+	}
+
+	// Something is wrong if at the end, the top level is null or hasn't used
+	// all of the leaves.
+ if leafCounts[maxBits][maxBits] != n { + panic("leafCounts[maxBits][maxBits] != n") + } + + bitCount := h.bitCount[:maxBits+1] + bits := 1 + counts := &leafCounts[maxBits] + for level := maxBits; level > 0; level-- { + // chain.leafCount gives the number of literals requiring at least "bits" + // bits to encode. + bitCount[bits] = counts[level] - counts[level-1] + bits++ + } + return bitCount +} + +// Look at the leaves and assign them a bit count and an encoding as specified +// in RFC 1951 3.2.2 +func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) { + code := uint16(0) + for n, bits := range bitCount { + code <<= 1 + if n == 0 || bits == 0 { + continue + } + // The literals list[len(list)-bits] .. list[len(list)-bits] + // are encoded using "bits" bits, and get the values + // code, code + 1, .... The code values are + // assigned in literal order (not frequency order). + chunk := list[len(list)-int(bits):] + + sortByLiteral(chunk) + for _, node := range chunk { + h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)} + code++ + } + list = list[0 : len(list)-int(bits)] + } +} + +// Update this Huffman Code object to be the minimum code for the specified frequency count. +// +// freq An array of frequencies, in which frequency[i] gives the frequency of literal i. +// maxBits The maximum number of bits to use for any literal. +func (h *huffmanEncoder) generate(freq []uint16, maxBits int32) { + if h.freqcache == nil { + // Allocate a reusable buffer with the longest possible frequency table. + // Possible lengths are codegenCodeCount, offsetCodeCount and literalCount. + // The largest of these is literalCount, so we allocate for that case. + h.freqcache = make([]literalNode, literalCount+1) + } + list := h.freqcache[:len(freq)+1] + // Number of non-zero literals + count := 0 + // Set list to be the set of all non-zero literals and their frequencies + for i, f := range freq { + if f != 0 { + list[count] = literalNode{uint16(i), f} + count++ + } else { + list[count] = literalNode{} + h.codes[i].len = 0 + } + } + list[len(freq)] = literalNode{} + + list = list[:count] + if count <= 2 { + // Handle the small cases here, because they are awkward for the general case code. With + // two or fewer literals, everything has bit length 1. + for i, node := range list { + // "list" is in order of increasing literal value. + h.codes[node.literal].set(uint16(i), 1) + } + return + } + sortByFreq(list) + + // Get the number of literals for each bit count + bitCount := h.bitCounts(list, maxBits) + // And do the assignment + h.assignEncodingAndSize(bitCount, list) +} + +func atLeastOne(v float32) float32 { + if v < 1 { + return 1 + } + return v +} + +// histogramSize accumulates a histogram of b in h. +// An estimated size in bits is returned. +// Unassigned values are assigned '1' in the histogram. +// len(h) must be >= 256, and h's elements must be all zeroes. 
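histogramSize, implemented below, approximates this Shannon bound with a fast log (`mFastLog2`) and float32 arithmetic. A plain float64 rendition of the same bound, a standalone sketch where `shannonBits` is an illustrative name and not part of the package, looks like this:

```go
package main

import (
	"fmt"
	"math"
)

// shannonBits returns the information-theoretic lower bound, in bits, for
// coding b symbol-by-symbol: the sum over bytes of -log2(p(byte)).
func shannonBits(b []byte) float64 {
	var hist [256]int
	for _, c := range b {
		hist[c]++
	}
	total := float64(len(b))
	bits := 0.0
	for _, n := range hist {
		if n > 0 {
			bits += -math.Log2(float64(n)/total) * float64(n)
		}
	}
	return bits
}

func main() {
	fmt.Printf("%.1f bits\n", shannonBits([]byte("aaaaabbbc"))) // ~12.2 bits for 9 bytes
}
```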
+func histogramSize(b []byte, h []uint16, fill bool) (int, int) { + h = h[:256] + for _, t := range b { + h[t]++ + } + invTotal := 1.0 / float32(len(b)) + shannon := float32(0.0) + var extra float32 + if fill { + oneBits := atLeastOne(-mFastLog2(invTotal)) + for i, v := range h[:] { + if v > 0 { + n := float32(v) + shannon += atLeastOne(-mFastLog2(n*invTotal)) * n + } else { + h[i] = 1 + extra += oneBits + } + } + } else { + for _, v := range h[:] { + if v > 0 { + n := float32(v) + shannon += atLeastOne(-mFastLog2(n*invTotal)) * n + } + } + } + + return int(shannon + 0.99), int(extra + 0.99) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_sortByFreq.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_sortByFreq.go new file mode 100644 index 000000000000..207780299007 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_sortByFreq.go @@ -0,0 +1,178 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// Sort sorts data. +// It makes one call to data.Len to determine n, and O(n*log(n)) calls to +// data.Less and data.Swap. The sort is not guaranteed to be stable. +func sortByFreq(data []literalNode) { + n := len(data) + quickSortByFreq(data, 0, n, maxDepth(n)) +} + +func quickSortByFreq(data []literalNode, a, b, maxDepth int) { + for b-a > 12 { // Use ShellSort for slices <= 12 elements + if maxDepth == 0 { + heapSort(data, a, b) + return + } + maxDepth-- + mlo, mhi := doPivotByFreq(data, a, b) + // Avoiding recursion on the larger subproblem guarantees + // a stack depth of at most lg(b-a). + if mlo-a < b-mhi { + quickSortByFreq(data, a, mlo, maxDepth) + a = mhi // i.e., quickSortByFreq(data, mhi, b) + } else { + quickSortByFreq(data, mhi, b, maxDepth) + b = mlo // i.e., quickSortByFreq(data, a, mlo) + } + } + if b-a > 1 { + // Do ShellSort pass with gap 6 + // It could be written in this simplified form cause b-a <= 12 + for i := a + 6; i < b; i++ { + if data[i].freq == data[i-6].freq && data[i].literal < data[i-6].literal || data[i].freq < data[i-6].freq { + data[i], data[i-6] = data[i-6], data[i] + } + } + insertionSortByFreq(data, a, b) + } +} + +// siftDownByFreq implements the heap property on data[lo, hi). +// first is an offset into the array where the root of the heap lies. +func siftDownByFreq(data []literalNode, lo, hi, first int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && (data[first+child].freq == data[first+child+1].freq && data[first+child].literal < data[first+child+1].literal || data[first+child].freq < data[first+child+1].freq) { + child++ + } + if data[first+root].freq == data[first+child].freq && data[first+root].literal > data[first+child].literal || data[first+root].freq > data[first+child].freq { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} +func doPivotByFreq(data []literalNode, lo, hi int) (midlo, midhi int) { + m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow. + if hi-lo > 40 { + // Tukey's ``Ninther,'' median of three medians of three. 
+ s := (hi - lo) / 8 + medianOfThreeSortByFreq(data, lo, lo+s, lo+2*s) + medianOfThreeSortByFreq(data, m, m-s, m+s) + medianOfThreeSortByFreq(data, hi-1, hi-1-s, hi-1-2*s) + } + medianOfThreeSortByFreq(data, lo, m, hi-1) + + // Invariants are: + // data[lo] = pivot (set up by ChoosePivot) + // data[lo < i < a] < pivot + // data[a <= i < b] <= pivot + // data[b <= i < c] unexamined + // data[c <= i < hi-1] > pivot + // data[hi-1] >= pivot + pivot := lo + a, c := lo+1, hi-1 + + for ; a < c && (data[a].freq == data[pivot].freq && data[a].literal < data[pivot].literal || data[a].freq < data[pivot].freq); a++ { + } + b := a + for { + for ; b < c && (data[pivot].freq == data[b].freq && data[pivot].literal > data[b].literal || data[pivot].freq > data[b].freq); b++ { // data[b] <= pivot + } + for ; b < c && (data[pivot].freq == data[c-1].freq && data[pivot].literal < data[c-1].literal || data[pivot].freq < data[c-1].freq); c-- { // data[c-1] > pivot + } + if b >= c { + break + } + // data[b] > pivot; data[c-1] <= pivot + data[b], data[c-1] = data[c-1], data[b] + b++ + c-- + } + // If hi-c<3 then there are duplicates (by property of median of nine). + // Let's be a bit more conservative, and set border to 5. + protect := hi-c < 5 + if !protect && hi-c < (hi-lo)/4 { + // Lets test some points for equality to pivot + dups := 0 + if data[pivot].freq == data[hi-1].freq && data[pivot].literal > data[hi-1].literal || data[pivot].freq > data[hi-1].freq { // data[hi-1] = pivot + data[c], data[hi-1] = data[hi-1], data[c] + c++ + dups++ + } + if data[b-1].freq == data[pivot].freq && data[b-1].literal > data[pivot].literal || data[b-1].freq > data[pivot].freq { // data[b-1] = pivot + b-- + dups++ + } + // m-lo = (hi-lo)/2 > 6 + // b-lo > (hi-lo)*3/4-1 > 8 + // ==> m < b ==> data[m] <= pivot + if data[m].freq == data[pivot].freq && data[m].literal > data[pivot].literal || data[m].freq > data[pivot].freq { // data[m] = pivot + data[m], data[b-1] = data[b-1], data[m] + b-- + dups++ + } + // if at least 2 points are equal to pivot, assume skewed distribution + protect = dups > 1 + } + if protect { + // Protect against a lot of duplicates + // Add invariant: + // data[a <= i < b] unexamined + // data[b <= i < c] = pivot + for { + for ; a < b && (data[b-1].freq == data[pivot].freq && data[b-1].literal > data[pivot].literal || data[b-1].freq > data[pivot].freq); b-- { // data[b] == pivot + } + for ; a < b && (data[a].freq == data[pivot].freq && data[a].literal < data[pivot].literal || data[a].freq < data[pivot].freq); a++ { // data[a] < pivot + } + if a >= b { + break + } + // data[a] == pivot; data[b-1] < pivot + data[a], data[b-1] = data[b-1], data[a] + a++ + b-- + } + } + // Swap pivot into middle + data[pivot], data[b-1] = data[b-1], data[pivot] + return b - 1, c +} + +// Insertion sort +func insertionSortByFreq(data []literalNode, a, b int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && (data[j].freq == data[j-1].freq && data[j].literal < data[j-1].literal || data[j].freq < data[j-1].freq); j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// quickSortByFreq, loosely following Bentley and McIlroy, +// ``Engineering a Sort Function,'' SP&E November 1993. + +// medianOfThreeSortByFreq moves the median of the three values data[m0], data[m1], data[m2] into data[m1]. 
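Every comparison in this file inlines one compound predicate. Spelled out as a standalone sketch (with a local copy of the literalNode shape; `lessByFreq` is a name invented here), it is simply frequency-then-literal ordering:

```go
package main

import "fmt"

// literalNode mirrors the package's (literal, freq) pair.
type literalNode struct {
	literal uint16
	freq    uint16
}

// lessByFreq is the ordering the sortByFreq routines above inline as long
// boolean expressions: primarily by frequency, ties broken by literal value.
func lessByFreq(a, b literalNode) bool {
	if a.freq == b.freq {
		return a.literal < b.literal
	}
	return a.freq < b.freq
}

func main() {
	a := literalNode{literal: 'x', freq: 3}
	b := literalNode{literal: 'y', freq: 3}
	fmt.Println(lessByFreq(a, b), lessByFreq(b, a)) // true false
}
```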
+func medianOfThreeSortByFreq(data []literalNode, m1, m0, m2 int) { + // sort 3 elements + if data[m1].freq == data[m0].freq && data[m1].literal < data[m0].literal || data[m1].freq < data[m0].freq { + data[m1], data[m0] = data[m0], data[m1] + } + // data[m0] <= data[m1] + if data[m2].freq == data[m1].freq && data[m2].literal < data[m1].literal || data[m2].freq < data[m1].freq { + data[m2], data[m1] = data[m1], data[m2] + // data[m0] <= data[m2] && data[m1] < data[m2] + if data[m1].freq == data[m0].freq && data[m1].literal < data[m0].literal || data[m1].freq < data[m0].freq { + data[m1], data[m0] = data[m0], data[m1] + } + } + // now data[m0] <= data[m1] <= data[m2] +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go new file mode 100644 index 000000000000..93f1aea109e1 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go @@ -0,0 +1,201 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// Sort sorts data. +// It makes one call to data.Len to determine n, and O(n*log(n)) calls to +// data.Less and data.Swap. The sort is not guaranteed to be stable. +func sortByLiteral(data []literalNode) { + n := len(data) + quickSort(data, 0, n, maxDepth(n)) +} + +func quickSort(data []literalNode, a, b, maxDepth int) { + for b-a > 12 { // Use ShellSort for slices <= 12 elements + if maxDepth == 0 { + heapSort(data, a, b) + return + } + maxDepth-- + mlo, mhi := doPivot(data, a, b) + // Avoiding recursion on the larger subproblem guarantees + // a stack depth of at most lg(b-a). + if mlo-a < b-mhi { + quickSort(data, a, mlo, maxDepth) + a = mhi // i.e., quickSort(data, mhi, b) + } else { + quickSort(data, mhi, b, maxDepth) + b = mlo // i.e., quickSort(data, a, mlo) + } + } + if b-a > 1 { + // Do ShellSort pass with gap 6 + // It could be written in this simplified form cause b-a <= 12 + for i := a + 6; i < b; i++ { + if data[i].literal < data[i-6].literal { + data[i], data[i-6] = data[i-6], data[i] + } + } + insertionSort(data, a, b) + } +} +func heapSort(data []literalNode, a, b int) { + first := a + lo := 0 + hi := b - a + + // Build heap with greatest element at top. + for i := (hi - 1) / 2; i >= 0; i-- { + siftDown(data, i, hi, first) + } + + // Pop elements, largest first, into end of data. + for i := hi - 1; i >= 0; i-- { + data[first], data[first+i] = data[first+i], data[first] + siftDown(data, lo, i, first) + } +} + +// siftDown implements the heap property on data[lo, hi). +// first is an offset into the array where the root of the heap lies. +func siftDown(data []literalNode, lo, hi, first int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && data[first+child].literal < data[first+child+1].literal { + child++ + } + if data[first+root].literal > data[first+child].literal { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} +func doPivot(data []literalNode, lo, hi int) (midlo, midhi int) { + m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow. + if hi-lo > 40 { + // Tukey's ``Ninther,'' median of three medians of three. 
+ s := (hi - lo) / 8 + medianOfThree(data, lo, lo+s, lo+2*s) + medianOfThree(data, m, m-s, m+s) + medianOfThree(data, hi-1, hi-1-s, hi-1-2*s) + } + medianOfThree(data, lo, m, hi-1) + + // Invariants are: + // data[lo] = pivot (set up by ChoosePivot) + // data[lo < i < a] < pivot + // data[a <= i < b] <= pivot + // data[b <= i < c] unexamined + // data[c <= i < hi-1] > pivot + // data[hi-1] >= pivot + pivot := lo + a, c := lo+1, hi-1 + + for ; a < c && data[a].literal < data[pivot].literal; a++ { + } + b := a + for { + for ; b < c && data[pivot].literal > data[b].literal; b++ { // data[b] <= pivot + } + for ; b < c && data[pivot].literal < data[c-1].literal; c-- { // data[c-1] > pivot + } + if b >= c { + break + } + // data[b] > pivot; data[c-1] <= pivot + data[b], data[c-1] = data[c-1], data[b] + b++ + c-- + } + // If hi-c<3 then there are duplicates (by property of median of nine). + // Let's be a bit more conservative, and set border to 5. + protect := hi-c < 5 + if !protect && hi-c < (hi-lo)/4 { + // Lets test some points for equality to pivot + dups := 0 + if data[pivot].literal > data[hi-1].literal { // data[hi-1] = pivot + data[c], data[hi-1] = data[hi-1], data[c] + c++ + dups++ + } + if data[b-1].literal > data[pivot].literal { // data[b-1] = pivot + b-- + dups++ + } + // m-lo = (hi-lo)/2 > 6 + // b-lo > (hi-lo)*3/4-1 > 8 + // ==> m < b ==> data[m] <= pivot + if data[m].literal > data[pivot].literal { // data[m] = pivot + data[m], data[b-1] = data[b-1], data[m] + b-- + dups++ + } + // if at least 2 points are equal to pivot, assume skewed distribution + protect = dups > 1 + } + if protect { + // Protect against a lot of duplicates + // Add invariant: + // data[a <= i < b] unexamined + // data[b <= i < c] = pivot + for { + for ; a < b && data[b-1].literal > data[pivot].literal; b-- { // data[b] == pivot + } + for ; a < b && data[a].literal < data[pivot].literal; a++ { // data[a] < pivot + } + if a >= b { + break + } + // data[a] == pivot; data[b-1] < pivot + data[a], data[b-1] = data[b-1], data[a] + a++ + b-- + } + } + // Swap pivot into middle + data[pivot], data[b-1] = data[b-1], data[pivot] + return b - 1, c +} + +// Insertion sort +func insertionSort(data []literalNode, a, b int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && data[j].literal < data[j-1].literal; j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// maxDepth returns a threshold at which quicksort should switch +// to heapsort. It returns 2*ceil(lg(n+1)). +func maxDepth(n int) int { + var depth int + for i := n; i > 0; i >>= 1 { + depth++ + } + return depth * 2 +} + +// medianOfThree moves the median of the three values data[m0], data[m1], data[m2] into data[m1]. 
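maxDepth above is the classic introsort cut-off: counting the halvings of n gives ceil(lg(n+1)), and once quicksort has recursed roughly twice that without finishing, heapSort takes over, bounding the worst case at O(n log n). A tiny standalone probe of the formula (`maxDepthRef` is a restated copy for illustration):

```go
package main

import "fmt"

// maxDepthRef mirrors maxDepth above: 2*ceil(lg(n+1)) recursion budget
// before quickSort bails out to heapSort.
func maxDepthRef(n int) int {
	var depth int
	for i := n; i > 0; i >>= 1 {
		depth++
	}
	return depth * 2
}

func main() {
	for _, n := range []int{12, 100, 10000, 1 << 20} {
		fmt.Printf("n=%-8d depth budget=%d\n", n, maxDepthRef(n))
	}
}
```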
+func medianOfThree(data []literalNode, m1, m0, m2 int) { + // sort 3 elements + if data[m1].literal < data[m0].literal { + data[m1], data[m0] = data[m0], data[m1] + } + // data[m0] <= data[m1] + if data[m2].literal < data[m1].literal { + data[m2], data[m1] = data[m1], data[m2] + // data[m0] <= data[m2] && data[m1] < data[m2] + if data[m1].literal < data[m0].literal { + data[m1], data[m0] = data[m0], data[m1] + } + } + // now data[m0] <= data[m1] <= data[m2] +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/inflate.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/inflate.go new file mode 100644 index 000000000000..7f175a4ec26e --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/inflate.go @@ -0,0 +1,1000 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package flate implements the DEFLATE compressed data format, described in +// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file +// formats. +package flate + +import ( + "bufio" + "fmt" + "io" + "math/bits" + "strconv" + "sync" +) + +const ( + maxCodeLen = 16 // max length of Huffman code + maxCodeLenMask = 15 // mask for max length of Huffman code + // The next three numbers come from the RFC section 3.2.7, with the + // additional proviso in section 3.2.5 which implies that distance codes + // 30 and 31 should never occur in compressed data. + maxNumLit = 286 + maxNumDist = 30 + numCodes = 19 // number of codes in Huffman meta-code + + debugDecode = false +) + +// Initialize the fixedHuffmanDecoder only once upon first use. +var fixedOnce sync.Once +var fixedHuffmanDecoder huffmanDecoder + +// A CorruptInputError reports the presence of corrupt input at a given offset. +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "flate: corrupt input before offset " + strconv.FormatInt(int64(e), 10) +} + +// An InternalError reports an error in the flate code itself. +type InternalError string + +func (e InternalError) Error() string { return "flate: internal error: " + string(e) } + +// A ReadError reports an error encountered while reading input. +// +// Deprecated: No longer returned. +type ReadError struct { + Offset int64 // byte offset where error occurred + Err error // error returned by underlying Read +} + +func (e *ReadError) Error() string { + return "flate: read error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error() +} + +// A WriteError reports an error encountered while writing output. +// +// Deprecated: No longer returned. +type WriteError struct { + Offset int64 // byte offset where error occurred + Err error // error returned by underlying Write +} + +func (e *WriteError) Error() string { + return "flate: write error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error() +} + +// Resetter resets a ReadCloser returned by NewReader or NewReaderDict to +// to switch to a new underlying Reader. This permits reusing a ReadCloser +// instead of allocating a new one. +type Resetter interface { + // Reset discards any buffered data and resets the Resetter as if it was + // newly initialized with the given reader. + Reset(r io.Reader, dict []byte) error +} + +// The data structure for decoding Huffman tables is based on that of +// zlib. 
There is a lookup table of a fixed bit width (huffmanChunkBits), +// For codes smaller than the table width, there are multiple entries +// (each combination of trailing bits has the same value). For codes +// larger than the table width, the table contains a link to an overflow +// table. The width of each entry in the link table is the maximum code +// size minus the chunk width. +// +// Note that you can do a lookup in the table even without all bits +// filled. Since the extra bits are zero, and the DEFLATE Huffman codes +// have the property that shorter codes come before longer ones, the +// bit length estimate in the result is a lower bound on the actual +// number of bits. +// +// See the following: +// http://www.gzip.org/algorithm.txt + +// chunk & 15 is number of bits +// chunk >> 4 is value, including table link + +const ( + huffmanChunkBits = 9 + huffmanNumChunks = 1 << huffmanChunkBits + huffmanCountMask = 15 + huffmanValueShift = 4 +) + +type huffmanDecoder struct { + maxRead int // the maximum number of bits we can read and not overread + chunks *[huffmanNumChunks]uint16 // chunks as described above + links [][]uint16 // overflow links + linkMask uint32 // mask the width of the link table +} + +// Initialize Huffman decoding tables from array of code lengths. +// Following this function, h is guaranteed to be initialized into a complete +// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a +// degenerate case where the tree has only a single symbol with length 1. Empty +// trees are permitted. +func (h *huffmanDecoder) init(lengths []int) bool { + // Sanity enables additional runtime tests during Huffman + // table construction. It's intended to be used during + // development to supplement the currently ad-hoc unit tests. + const sanity = false + + if h.chunks == nil { + h.chunks = &[huffmanNumChunks]uint16{} + } + if h.maxRead != 0 { + *h = huffmanDecoder{chunks: h.chunks, links: h.links} + } + + // Count number of codes of each length, + // compute maxRead and max length. + var count [maxCodeLen]int + var min, max int + for _, n := range lengths { + if n == 0 { + continue + } + if min == 0 || n < min { + min = n + } + if n > max { + max = n + } + count[n&maxCodeLenMask]++ + } + + // Empty tree. The decompressor.huffSym function will fail later if the tree + // is used. Technically, an empty tree is only valid for the HDIST tree and + // not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree + // is guaranteed to fail since it will attempt to use the tree to decode the + // codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is + // guaranteed to fail later since the compressed data section must be + // composed of at least one symbol (the end-of-block marker). + if max == 0 { + return true + } + + code := 0 + var nextcode [maxCodeLen]int + for i := min; i <= max; i++ { + code <<= 1 + nextcode[i&maxCodeLenMask] = code + code += count[i&maxCodeLenMask] + } + + // Check that the coding is complete (i.e., that we've + // assigned all 2-to-the-max possible bit sequences). + // Exception: To be compatible with zlib, we also need to + // accept degenerate single-code codings. See also + // TestDegenerateHuffmanCoding. 
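To make the chunk encoding described near the top of this file concrete ("chunk & 15 is number of bits, chunk >> 4 is value"), here is a toy 2-bit direct-lookup decoder for the canonical code A=0, B=10, C=11. It is a sketch with invented names, not the package's real table builder, but it uses the same layout and the same "replicate across trailing bits" fill:

```go
package main

import "fmt"

func main() {
	// Toy 2-bit chunk table for the canonical code A=0, B=10, C=11,
	// read LSB-first as in DEFLATE. Chunk layout matches the comment
	// above: low 4 bits = code length, remaining bits = decoded symbol.
	const tableBits = 2
	var table [1 << tableBits]uint16
	set := func(firstBits, nbits, sym uint16) {
		// Every table index whose low nbits equal firstBits decodes to sym.
		for idx := firstBits; idx < 1<<tableBits; idx += 1 << nbits {
			table[idx] = sym<<4 | nbits
		}
	}
	set(0b0, 1, 0)  // 'A': code 0
	set(0b01, 2, 1) // 'B': code 10, first-read bit in the low position
	set(0b11, 2, 2) // 'C': code 11

	stream, n := uint32(0b000111), 6 // the bits for C, B, A, A packed LSB-first
	for n > 0 {
		chunk := table[stream&(1<<tableBits-1)]
		bits := int(chunk & 15)
		fmt.Printf("symbol %c, %d bits\n", 'A'+byte(chunk>>4), bits)
		stream >>= uint(bits)
		n -= bits
	}
}
```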
+ if code != 1<<uint(max) && !(code == 1 && max == 1) {
+ if debugDecode {
+ fmt.Println("coding failed, code, max:", code, max, code == 1<<uint(max), max == 1)
+ }
+ return false
+ }
+
+ h.maxRead = min
+ chunks := h.chunks[:]
+ for i := range chunks {
+ chunks[i] = 0
+ }
+
+ if max > huffmanChunkBits {
+ numLinks := 1 << (uint(max) - huffmanChunkBits)
+ h.linkMask = uint32(numLinks - 1)
+
+ // create link tables
+ link := nextcode[huffmanChunkBits+1] >> 1
+ if cap(h.links) < huffmanNumChunks-link {
+ h.links = make([][]uint16, huffmanNumChunks-link)
+ } else {
+ h.links = h.links[:huffmanNumChunks-link]
+ }
+ for j := uint(link); j < huffmanNumChunks; j++ {
+ reverse := int(bits.Reverse16(uint16(j)))
+ reverse >>= uint(16 - huffmanChunkBits)
+ off := j - uint(link)
+ if sanity && h.chunks[reverse] != 0 {
+ panic("impossible: overwriting existing chunk")
+ }
+ h.chunks[reverse] = uint16(off<<huffmanValueShift | (huffmanChunkBits + 1))
+ h.links[off] = make([]uint16, numLinks)
+ }
+ } else {
+ h.links = h.links[:0]
+ }
+
+ for i, n := range lengths {
+ if n == 0 {
+ continue
+ }
+ code := nextcode[n]
+ nextcode[n]++
+ chunk := uint16(i<<huffmanValueShift | n)
+ reverse := int(bits.Reverse16(uint16(code)))
+ reverse >>= uint(16 - n)
+ if n <= huffmanChunkBits {
+ for off := reverse; off < len(h.chunks); off += 1 << uint(n) {
+ // We should never need to overwrite
+ // an existing chunk. Also, 0 is
+ // never a valid chunk, because the
+ // lower 4 "count" bits should be
+ // between 1 and 15.
+ if sanity && h.chunks[off] != 0 {
+ panic("impossible: overwriting existing chunk")
+ }
+ h.chunks[off] = chunk
+ }
+ } else {
+ j := reverse & (huffmanNumChunks - 1)
+ if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 {
+ // Longer codes should have been
+ // associated with a link table above.
+ panic("impossible: not an indirect chunk")
+ }
+ value := h.chunks[j] >> huffmanValueShift
+ linktab := h.links[value]
+ reverse >>= huffmanChunkBits
+ for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) {
+ if sanity && linktab[off] != 0 {
+ panic("impossible: overwriting existing chunk")
+ }
+ linktab[off] = chunk
+ }
+ }
+ }
+
+ if sanity {
+ // Above we've sanity checked that we never overwrote
+ // an existing entry. Here we additionally check that
+ // we filled the tables completely.
+ for i, chunk := range h.chunks {
+ if chunk == 0 {
+ // As an exception, in the degenerate
+ // single-code case, we allow odd
+ // chunks to be missing.
+ if code == 1 && i%2 == 1 {
+ continue
+ }
+ panic("impossible: missing chunk")
+ }
+ }
+ for _, linktab := range h.links {
+ for _, chunk := range linktab {
+ if chunk == 0 {
+ panic("impossible: missing chunk")
+ }
+ }
+ }
+ }
+
+ return true
+}
+
+// The actual read interface needed by NewReader.
+// If the passed in io.Reader does not also have ReadByte,
+// the NewReader will introduce its own buffering.
+type Reader interface {
+ io.Reader
+ io.ByteReader
+}
+
+// Decompress state.
+type decompressor struct {
+ // Input source.
+ r Reader
+ roffset int64
+
+ // Input bits, in top of b.
+ b uint32
+ nb uint
+
+ // Huffman decoders for literal/length, distance.
+ h1, h2 huffmanDecoder
+
+ // Length arrays used to define Huffman codes.
+ bits *[maxNumLit + maxNumDist]int
+ codebits *[numCodes]int
+
+ // Output history, buffer.
+ dict dictDecoder
+
+ // Temporary buffer (avoids repeated allocation).
+ buf [4]byte
+
+ // Next step in the decompression,
+ // and decompression state.
+ step func(*decompressor) + stepState int + final bool + err error + toRead []byte + hl, hd *huffmanDecoder + copyLen int + copyDist int +} + +func (f *decompressor) nextBlock() { + for f.nb < 1+2 { + if f.err = f.moreBits(); f.err != nil { + return + } + } + f.final = f.b&1 == 1 + f.b >>= 1 + typ := f.b & 3 + f.b >>= 2 + f.nb -= 1 + 2 + switch typ { + case 0: + f.dataBlock() + case 1: + // compressed, fixed Huffman tables + f.hl = &fixedHuffmanDecoder + f.hd = nil + f.huffmanBlockDecoder()() + case 2: + // compressed, dynamic Huffman tables + if f.err = f.readHuffman(); f.err != nil { + break + } + f.hl = &f.h1 + f.hd = &f.h2 + f.huffmanBlockDecoder()() + default: + // 3 is reserved. + if debugDecode { + fmt.Println("reserved data block encountered") + } + f.err = CorruptInputError(f.roffset) + } +} + +func (f *decompressor) Read(b []byte) (int, error) { + for { + if len(f.toRead) > 0 { + n := copy(b, f.toRead) + f.toRead = f.toRead[n:] + if len(f.toRead) == 0 { + return n, f.err + } + return n, nil + } + if f.err != nil { + return 0, f.err + } + f.step(f) + if f.err != nil && len(f.toRead) == 0 { + f.toRead = f.dict.readFlush() // Flush what's left in case of error + } + } +} + +// Support the io.WriteTo interface for io.Copy and friends. +func (f *decompressor) WriteTo(w io.Writer) (int64, error) { + total := int64(0) + flushed := false + for { + if len(f.toRead) > 0 { + n, err := w.Write(f.toRead) + total += int64(n) + if err != nil { + f.err = err + return total, err + } + if n != len(f.toRead) { + return total, io.ErrShortWrite + } + f.toRead = f.toRead[:0] + } + if f.err != nil && flushed { + if f.err == io.EOF { + return total, nil + } + return total, f.err + } + if f.err == nil { + f.step(f) + } + if len(f.toRead) == 0 && f.err != nil && !flushed { + f.toRead = f.dict.readFlush() // Flush what's left in case of error + flushed = true + } + } +} + +func (f *decompressor) Close() error { + if f.err == io.EOF { + return nil + } + return f.err +} + +// RFC 1951 section 3.2.7. +// Compression with dynamic Huffman codes + +var codeOrder = [...]int{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} + +func (f *decompressor) readHuffman() error { + // HLIT[5], HDIST[5], HCLEN[4]. + for f.nb < 5+5+4 { + if err := f.moreBits(); err != nil { + return err + } + } + nlit := int(f.b&0x1F) + 257 + if nlit > maxNumLit { + if debugDecode { + fmt.Println("nlit > maxNumLit", nlit) + } + return CorruptInputError(f.roffset) + } + f.b >>= 5 + ndist := int(f.b&0x1F) + 1 + if ndist > maxNumDist { + if debugDecode { + fmt.Println("ndist > maxNumDist", ndist) + } + return CorruptInputError(f.roffset) + } + f.b >>= 5 + nclen := int(f.b&0xF) + 4 + // numCodes is 19, so nclen is always valid. + f.b >>= 4 + f.nb -= 5 + 5 + 4 + + // (HCLEN+4)*3 bits: code lengths in the magic codeOrder order. + for i := 0; i < nclen; i++ { + for f.nb < 3 { + if err := f.moreBits(); err != nil { + return err + } + } + f.codebits[codeOrder[i]] = int(f.b & 0x7) + f.b >>= 3 + f.nb -= 3 + } + for i := nclen; i < len(codeOrder); i++ { + f.codebits[codeOrder[i]] = 0 + } + if !f.h1.init(f.codebits[0:]) { + if debugDecode { + fmt.Println("init codebits failed") + } + return CorruptInputError(f.roffset) + } + + // HLIT + 257 code lengths, HDIST + 1 code lengths, + // using the code length Huffman code. + for i, n := 0, nlit+ndist; i < n; { + x, err := f.huffSym(&f.h1) + if err != nil { + return err + } + if x < 16 { + // Actual length. + f.bits[i] = x + i++ + continue + } + // Repeat previous length or zero. 
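The switch that follows decodes the code-length alphabet's run-length codes. As a self-contained sketch of that expansion rule (`expand` and its input shape are invented for illustration; the extra bits are assumed to have been read already, which is what the decoder does with rep):

```go
package main

import "fmt"

// expand applies the code-length alphabet semantics: symbols 0-15 are
// literal code lengths; 16 repeats the previous length 3-6 times; 17 and
// 18 emit runs of zeros (3-10 and 11-138 respectively).
func expand(ops []struct{ code, extra int }) []int {
	var out []int
	for _, op := range ops {
		switch {
		case op.code < 16:
			out = append(out, op.code)
		case op.code == 16:
			// Panics if there is no previous length, mirroring the
			// decoder's i == 0 corruption check.
			last := out[len(out)-1]
			for i := 0; i < 3+op.extra; i++ {
				out = append(out, last)
			}
		case op.code == 17:
			for i := 0; i < 3+op.extra; i++ {
				out = append(out, 0)
			}
		default: // 18
			for i := 0; i < 11+op.extra; i++ {
				out = append(out, 0)
			}
		}
	}
	return out
}

func main() {
	ops := []struct{ code, extra int }{{8, 0}, {16, 1}, {18, 2}, {9, 0}}
	fmt.Println(expand(ops)) // [8 8 8 8 8 0 0 0 0 0 0 0 0 0 0 0 0 0 9]
}
```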
+ var rep int
+ var nb uint
+ var b int
+ switch x {
+ default:
+ return InternalError("unexpected length code")
+ case 16:
+ rep = 3
+ nb = 2
+ if i == 0 {
+ if debugDecode {
+ fmt.Println("i==0")
+ }
+ return CorruptInputError(f.roffset)
+ }
+ b = f.bits[i-1]
+ case 17:
+ rep = 3
+ nb = 3
+ b = 0
+ case 18:
+ rep = 11
+ nb = 7
+ b = 0
+ }
+ for f.nb < nb {
+ if err := f.moreBits(); err != nil {
+ if debugDecode {
+ fmt.Println("morebits:", err)
+ }
+ return err
+ }
+ }
+ rep += int(f.b & uint32(1<<nb-1))
+ f.b >>= nb
+ f.nb -= nb
+ if i+rep > n {
+ if debugDecode {
+ fmt.Println("i+rep > n", i, rep, n)
+ }
+ return CorruptInputError(f.roffset)
+ }
+ for j := 0; j < rep; j++ {
+ f.bits[i] = b
+ i++
+ }
+ }
+
+ if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) {
+ if debugDecode {
+ fmt.Println("init2 failed")
+ }
+ return CorruptInputError(f.roffset)
+ }
+
+ // As an optimization, we can initialize the maxRead bits to read at a time
+ // for the HLIT tree to the length of the EOB marker since we know that
+ // every block must terminate with one. This preserves the property that
+ // we never read any extra bytes after the end of the DEFLATE stream.
+ if f.h1.maxRead < f.bits[endBlockMarker] {
+ f.h1.maxRead = f.bits[endBlockMarker]
+ }
+ if !f.final {
+ // If not the final block, the smallest block possible is
+ // a predefined table, BTYPE=01, with a single EOB marker.
+ // This will take up 3 + 7 bits.
+ f.h1.maxRead += 10
+ }
+
+ return nil
+}
+
+// Decode a single Huffman block from f.
+// hl and hd are the Huffman states for the lit/length values
+// and the distance values, respectively. If hd == nil, the fixed
+// distance encoding associated with fixed Huffman blocks is used.
+func (f *decompressor) huffmanBlockGeneric() {
+ const (
+ stateInit = iota // Zero value must be stateInit
+ stateDict
+ )
+
+ switch f.stepState {
+ case stateInit:
+ goto readLiteral
+ case stateDict:
+ goto copyHistory
+ }
+
+readLiteral:
+ // Read literal and/or (length, distance) according to RFC section 3.2.3.
+ {
+ var v int
+ {
+ // Inlined v, err := f.huffSym(f.hl)
+ // Since a huffmanDecoder can be empty or be composed of a degenerate tree
+ // with a single element, huffSym must error on these two edge cases. In both
+ // cases, the chunks slice will be 0 for the invalid sequence, leading it to
+ // satisfy the n == 0 check below.
+ n := uint(f.hl.maxRead)
+ // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
+ // but is smart enough to keep local variables in registers, so use nb and b,
+ // inline call to moreBits and reassign b,nb back to f on return.
+ nb, b := f.nb, f.b + for { + for nb < n { + c, err := f.r.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & 31) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var n uint // number of bits extra + var length int + var err error + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBlockGeneric + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + } + length += int(f.b & uint32(1<>= n + f.nb -= n + } + + var dist int + if f.hd == nil { + for f.nb < 5 { + if err = f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + } + dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + if dist, err = f.huffSym(f.hd); err != nil { + if debugDecode { + fmt.Println("huffsym:", err) + } + f.err = err + return + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << nb + for f.nb < nb { + if err = f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb + f.nb -= nb + dist = 1<<(nb+1) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > f.dict.histSize() { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, dist + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBlockGeneric // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Copy a single uncompressed data block from input to output. +func (f *decompressor) dataBlock() { + // Uncompressed. + // Discard current half-byte. 
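Stepping back to the length-code switch above: it turns codes 257-285 into a base length plus a count of extra bits. Pulled out into a standalone function for readability (an illustrative restatement of the same arithmetic, not the package's API):

```go
package main

import "fmt"

// lengthBase mirrors the decoder's switch: for DEFLATE length codes
// 257-284 the base length and extra-bit count follow a doubling pattern;
// code 285 is the fixed maximum length 258 with no extra bits.
func lengthBase(v int) (length, extraBits int) {
	switch {
	case v < 265:
		return v - (257 - 3), 0
	case v < 269:
		return v*2 - (265*2 - 11), 1
	case v < 273:
		return v*4 - (269*4 - 19), 2
	case v < 277:
		return v*8 - (273*8 - 35), 3
	case v < 281:
		return v*16 - (277*16 - 67), 4
	case v < 285:
		return v*32 - (281*32 - 131), 5
	default: // v == 285
		return 258, 0
	}
}

func main() {
	for _, v := range []int{257, 264, 265, 284, 285} {
		l, n := lengthBase(v)
		fmt.Printf("code %d -> base length %d, %d extra bits\n", v, l, n)
	}
}
```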
+ left := (f.nb) & 7 + f.nb -= left + f.b >>= left + + offBytes := f.nb >> 3 + // Unfilled values will be overwritten. + f.buf[0] = uint8(f.b) + f.buf[1] = uint8(f.b >> 8) + f.buf[2] = uint8(f.b >> 16) + f.buf[3] = uint8(f.b >> 24) + + f.roffset += int64(offBytes) + f.nb, f.b = 0, 0 + + // Length then ones-complement of length. + nr, err := io.ReadFull(f.r, f.buf[offBytes:4]) + f.roffset += int64(nr) + if err != nil { + f.err = noEOF(err) + return + } + n := uint16(f.buf[0]) | uint16(f.buf[1])<<8 + nn := uint16(f.buf[2]) | uint16(f.buf[3])<<8 + if nn != ^n { + if debugDecode { + ncomp := ^n + fmt.Println("uint16(nn) != uint16(^n)", nn, ncomp) + } + f.err = CorruptInputError(f.roffset) + return + } + + if n == 0 { + f.toRead = f.dict.readFlush() + f.finishBlock() + return + } + + f.copyLen = int(n) + f.copyData() +} + +// copyData copies f.copyLen bytes from the underlying reader into f.hist. +// It pauses for reads when f.hist is full. +func (f *decompressor) copyData() { + buf := f.dict.writeSlice() + if len(buf) > f.copyLen { + buf = buf[:f.copyLen] + } + + cnt, err := io.ReadFull(f.r, buf) + f.roffset += int64(cnt) + f.copyLen -= cnt + f.dict.writeMark(cnt) + if err != nil { + f.err = noEOF(err) + return + } + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).copyData + return + } + f.finishBlock() +} + +func (f *decompressor) finishBlock() { + if f.final { + if f.dict.availRead() > 0 { + f.toRead = f.dict.readFlush() + } + f.err = io.EOF + } + f.step = (*decompressor).nextBlock +} + +// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF. +func noEOF(e error) error { + if e == io.EOF { + return io.ErrUnexpectedEOF + } + return e +} + +func (f *decompressor) moreBits() error { + c, err := f.r.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil +} + +// Read the next Huffman-encoded symbol from f according to h. +func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(h.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := f.r.ReadByte() + if err != nil { + f.b = b + f.nb = nb + return 0, noEOF(err) + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := h.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return 0, f.err + } + f.b = b >> (n & 31) + f.nb = nb - n + return int(chunk >> huffmanValueShift), nil + } + } +} + +func makeReader(r io.Reader) Reader { + if rr, ok := r.(Reader); ok { + return rr + } + return bufio.NewReader(r) +} + +func fixedHuffmanDecoderInit() { + fixedOnce.Do(func() { + // These come from the RFC section 3.2.6. 
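Those fixed code lengths, assigned in the loops that follow (8 bits for symbols 0-143, 9 for 144-255, 7 for 256-279, 8 for 280-287), form a complete prefix code. That completeness can be sanity-checked with Kraft's equality in a few lines; this is a standalone sketch, not part of the package:

```go
package main

import "fmt"

func main() {
	// Code-length histogram of the fixed literal/length table:
	// 24 codes of 7 bits, 144+8 of 8 bits, 112 of 9 bits (288 total).
	counts := map[int]int{7: 24, 8: 144 + 8, 9: 112}
	sum := 0.0
	for length, n := range counts {
		sum += float64(n) / float64(int(1)<<uint(length))
	}
	fmt.Println("Kraft sum:", sum) // prints 1 for a complete prefix code
}
```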
+ var bits [288]int + for i := 0; i < 144; i++ { + bits[i] = 8 + } + for i := 144; i < 256; i++ { + bits[i] = 9 + } + for i := 256; i < 280; i++ { + bits[i] = 7 + } + for i := 280; i < 288; i++ { + bits[i] = 8 + } + fixedHuffmanDecoder.init(bits[:]) + }) +} + +func (f *decompressor) Reset(r io.Reader, dict []byte) error { + *f = decompressor{ + r: makeReader(r), + bits: f.bits, + codebits: f.codebits, + h1: f.h1, + h2: f.h2, + dict: f.dict, + step: (*decompressor).nextBlock, + } + f.dict.init(maxMatchOffset, dict) + return nil +} + +// NewReader returns a new ReadCloser that can be used +// to read the uncompressed version of r. +// If r does not also implement io.ByteReader, +// the decompressor may read more data than necessary from r. +// It is the caller's responsibility to call Close on the ReadCloser +// when finished reading. +// +// The ReadCloser returned by NewReader also implements Resetter. +func NewReader(r io.Reader) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.r = makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, nil) + return &f +} + +// NewReaderDict is like NewReader but initializes the reader +// with a preset dictionary. The returned Reader behaves as if +// the uncompressed data stream started with the given dictionary, +// which has already been read. NewReaderDict is typically used +// to read data compressed by NewWriterDict. +// +// The ReadCloser returned by NewReader also implements Resetter. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.r = makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, dict) + return &f +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/inflate_gen.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/inflate_gen.go new file mode 100644 index 000000000000..397dc1b1a134 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/inflate_gen.go @@ -0,0 +1,922 @@ +// Code generated by go generate gen_inflate.go. DO NOT EDIT. + +package flate + +import ( + "bufio" + "bytes" + "fmt" + "math/bits" + "strings" +) + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBytesBuffer() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*bytes.Buffer) + moreBits := func() error { + c, err := fr.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil + } + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. 
Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & 31) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var n uint // number of bits extra + var length int + var err error + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesBuffer + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + } + length += int(f.b & uint32(1<>= n + f.nb -= n + } + + var dist int + if f.hd == nil { + for f.nb < 5 { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + } + dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + if dist, err = f.huffSym(f.hd); err != nil { + if debugDecode { + fmt.Println("huffsym:", err) + } + f.err = err + return + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << nb + for f.nb < nb { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb + f.nb -= nb + dist = 1<<(nb+1) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > f.dict.histSize() { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, dist + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. 
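The copyHistory step below relies on LZ77 back-references being allowed to overlap their own output (dist smaller than length), which is what makes long runs cheap to encode. A minimal standalone illustration; `backCopy` is an invented name, and the real writeCopy/tryWriteCopy work on a ring buffer rather than a growing slice:

```go
package main

import "fmt"

// backCopy appends length bytes copied from dist bytes back. When
// dist < length the copy re-reads bytes it has just written.
func backCopy(buf []byte, dist, length int) []byte {
	start := len(buf) - dist
	for i := 0; i < length; i++ {
		buf = append(buf, buf[start+i])
	}
	return buf
}

func main() {
	fmt.Printf("%s\n", backCopy([]byte("ab"), 2, 6)) // abababab
}
```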
+ { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesBuffer // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBytesReader() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*bytes.Reader) + moreBits := func() error { + c, err := fr.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil + } + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & 31) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var n uint // number of bits extra + var length int + var err error + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesReader + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + } + length += int(f.b & uint32(1<>= n + f.nb -= n + } + + var dist int + if f.hd == nil { + for f.nb < 5 { + if err = moreBits(); err != 
nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + } + dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + if dist, err = f.huffSym(f.hd); err != nil { + if debugDecode { + fmt.Println("huffsym:", err) + } + f.err = err + return + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << nb + for f.nb < nb { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb + f.nb -= nb + dist = 1<<(nb+1) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > f.dict.histSize() { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, dist + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesReader // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBufioReader() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*bufio.Reader) + moreBits := func() error { + c, err := fr.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil + } + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. 
+ nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & 31) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var n uint // number of bits extra + var length int + var err error + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBufioReader + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + } + length += int(f.b & uint32(1<>= n + f.nb -= n + } + + var dist int + if f.hd == nil { + for f.nb < 5 { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + } + dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + if dist, err = f.huffSym(f.hd); err != nil { + if debugDecode { + fmt.Println("huffsym:", err) + } + f.err = err + return + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << nb + for f.nb < nb { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb + f.nb -= nb + dist = 1<<(nb+1) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > f.dict.histSize() { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, dist + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBufioReader // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. 
If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanStringsReader() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*strings.Reader) + moreBits := func() error { + c, err := fr.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil + } + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & 31) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var n uint // number of bits extra + var length int + var err error + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanStringsReader + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + } + length += int(f.b & uint32(1<>= n + f.nb -= n + } + + var dist int + if f.hd == nil { + for f.nb < 5 { + if err = moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + } + dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + if dist, err = f.huffSym(f.hd); err != nil { + if debugDecode { + fmt.Println("huffsym:", err) + } + f.err = err + return + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. 
+ extra := (dist & 1) << nb
+ for f.nb < nb {
+ if err = moreBits(); err != nil {
+ if debugDecode {
+ fmt.Println("morebits f.nb<nb:", err)
+ }
+ f.err = err
+ return
+ }
+ }
+ extra |= int(f.b & uint32(1<<nb-1))
+ f.b >>= nb
+ f.nb -= nb
+ dist = 1<<(nb+1) + 1 + extra
+ default:
+ if debugDecode {
+ fmt.Println("dist too big:", dist, maxNumDist)
+ }
+ f.err = CorruptInputError(f.roffset)
+ return
+ }
+
+ // No check on length; encoding can be prescient.
+ if dist > f.dict.histSize() {
+ if debugDecode {
+ fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize())
+ }
+ f.err = CorruptInputError(f.roffset)
+ return
+ }
+
+ f.copyLen, f.copyDist = length, dist
+ goto copyHistory
+ }
+
+copyHistory:
+ // Perform a backwards copy according to RFC section 3.2.3.
+ {
+ cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen)
+ if cnt == 0 {
+ cnt = f.dict.writeCopy(f.copyDist, f.copyLen)
+ }
+ f.copyLen -= cnt
+
+ if f.dict.availWrite() == 0 || f.copyLen > 0 {
+ f.toRead = f.dict.readFlush()
+ f.step = (*decompressor).huffmanStringsReader // We need to continue this work
+ f.stepState = stateDict
+ return
+ }
+ goto readLiteral
+ }
+}
+
+func (f *decompressor) huffmanBlockDecoder() func() {
+ switch f.r.(type) {
+ case *bytes.Buffer:
+ return f.huffmanBytesBuffer
+ case *bytes.Reader:
+ return f.huffmanBytesReader
+ case *bufio.Reader:
+ return f.huffmanBufioReader
+ case *strings.Reader:
+ return f.huffmanStringsReader
+ default:
+ return f.huffmanBlockGeneric
+ }
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level1.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level1.go
new file mode 100644
index 000000000000..1e5eea3968aa
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level1.go
@@ -0,0 +1,179 @@
+package flate
+
+import "fmt"
+
+// fastGen maintains the table for matches,
+// and the previous byte block for level 2.
+// This is the generic implementation.
+type fastEncL1 struct {
+ fastGen
+ table [tableSize]tableEntry
+}
+
+// EncodeL1 uses a similar algorithm to level 1
+func (e *fastEncL1) Encode(dst *tokens, src []byte) {
+ const (
+ inputMargin = 12 - 1
+ minNonLiteralBlockSize = 1 + 1 + inputMargin
+ )
+ if debugDeflate && e.cur < 0 {
+ panic(fmt.Sprint("e.cur < 0: ", e.cur))
+ }
+
+ // Protect against e.cur wraparound.
+ for e.cur >= bufferReset {
+ if len(e.hist) == 0 {
+ for i := range e.table[:] {
+ e.table[i] = tableEntry{}
+ }
+ e.cur = maxMatchOffset
+ break
+ }
+ // Shift down everything in the table that isn't already too far away.
+ minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
+ for i := range e.table[:] {
+ v := e.table[i].offset
+ if v <= minOff {
+ v = 0
+ } else {
+ v = v - e.cur + maxMatchOffset
+ }
+ e.table[i].offset = v
+ }
+ e.cur = maxMatchOffset
+ }
+
+ s := e.addBlock(src)
+
+ // This check isn't in the Snappy implementation, but there, the caller
+ // instead of the callee handles this case.
+ if len(src) < minNonLiteralBlockSize {
+ // We do not fill the token table.
+ // This will be picked up by caller.
+ dst.n = uint16(len(src))
+ return
+ }
+
+ // Override src
+ src = e.hist
+ nextEmit := s
+
+ // sLimit is when to stop looking for offset/length copies. The inputMargin
+ // lets us use a fast path for emitLiteral in the main loop, while we are
+ // looking for copies.
+ sLimit := int32(len(src) - inputMargin)
+
+ // nextEmit is where in src the next emitLiteral should start from.
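The huffmanBlockDecoder dispatch above exists so each decode loop can call ReadByte on a concrete type rather than through the io.Reader interface. A minimal sketch of the same pattern, with an illustrative function of my own:

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"strings"
)

// pickPath mirrors the type switch in huffmanBlockDecoder: the concrete
// reader type selects a specialized loop; unknown types fall back to generic.
func pickPath(r io.Reader) string {
	switch r.(type) {
	case *bytes.Buffer:
		return "bytes.Buffer fast path"
	case *bytes.Reader:
		return "bytes.Reader fast path"
	case *bufio.Reader:
		return "bufio.Reader fast path"
	case *strings.Reader:
		return "strings.Reader fast path"
	default:
		return "generic path"
	}
}

func main() {
	fmt.Println(pickPath(strings.NewReader("x")))
	fmt.Println(pickPath(bufio.NewReader(strings.NewReader("x"))))
	fmt.Println(pickPath(io.LimitReader(strings.NewReader("x"), 1))) // generic
}
```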
+ cv := load3232(src, s) + + for { + const skipLog = 5 + const doEvery = 2 + + nextS := s + var candidate tableEntry + for { + nextHash := hash(cv) + candidate = e.table[nextHash] + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + + now := load6432(src, nextS) + e.table[nextHash] = tableEntry{offset: s + e.cur} + nextHash = hash(uint32(now)) + + offset := s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + e.table[nextHash] = tableEntry{offset: nextS + e.cur} + break + } + + // Do one right away... + cv = uint32(now) + s = nextS + nextS++ + candidate = e.table[nextHash] + now >>= 8 + e.table[nextHash] = tableEntry{offset: s + e.cur} + + offset = s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + e.table[nextHash] = tableEntry{offset: nextS + e.cur} + break + } + cv = uint32(now) + s = nextS + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + t := candidate.offset - e.cur + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + // Save the match found + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + if s >= sLimit { + // Index first pair after match end. + if int(s+l+4) < len(src) { + cv := load3232(src, s) + e.table[hash(cv)] = tableEntry{offset: s + e.cur} + } + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-2 and at s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load6432(src, s-2) + o := e.cur + s - 2 + prevHash := hash(uint32(x)) + e.table[prevHash] = tableEntry{offset: o} + x >>= 16 + currHash := hash(uint32(x)) + candidate = e.table[currHash] + e.table[currHash] = tableEntry{offset: o + 2} + + offset := s - (candidate.offset - e.cur) + if offset > maxMatchOffset || uint32(x) != load3232(src, candidate.offset-e.cur) { + cv = uint32(x >> 8) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level2.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level2.go new file mode 100644 index 000000000000..5b986a1944ea --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level2.go @@ -0,0 +1,205 @@ +package flate + +import "fmt" + +// fastGen maintains the table for matches, +// and the previous byte block for level 2. +// This is the generic implementation. 
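The `nextS = s + doEvery + (s-nextEmit)>>skipLog` step in the level-1 loop above is a Snappy-style acceleration: the longer the encoder has gone without finding a match, the larger the step. A minimal sketch with the same constants, on made-up run lengths:

```go
package main

import "fmt"

// Demonstrates how the advance grows with the unmatched run (s - nextEmit),
// using skipLog=5 and doEvery=2 as in level1.go.
func main() {
	const skipLog = 5
	const doEvery = 2
	nextEmit := int32(0)
	for _, s := range []int32{0, 32, 128, 512, 2048} {
		step := doEvery + (s-nextEmit)>>skipLog
		fmt.Printf("unmatched run %4d bytes -> advance %d\n", s-nextEmit, step)
	}
}
```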
+type fastEncL2 struct { + fastGen + table [bTableSize]tableEntry +} + +// EncodeL2 uses a similar algorithm to level 1, but is capable +// of matching across blocks giving better compression at a small slowdown. +func (e *fastEncL2) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3232(src, s) + for { + // When should we start skipping if we haven't found matches in a long while. + const skipLog = 5 + const doEvery = 2 + + nextS := s + var candidate tableEntry + for { + nextHash := hash4u(cv, bTableBits) + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + candidate = e.table[nextHash] + now := load6432(src, nextS) + e.table[nextHash] = tableEntry{offset: s + e.cur} + nextHash = hash4u(uint32(now), bTableBits) + + offset := s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + e.table[nextHash] = tableEntry{offset: nextS + e.cur} + break + } + + // Do one right away... + cv = uint32(now) + s = nextS + nextS++ + candidate = e.table[nextHash] + now >>= 8 + e.table[nextHash] = tableEntry{offset: s + e.cur} + + offset = s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + break + } + cv = uint32(now) + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. 
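The wraparound loop above rebases every table offset when e.cur nears bufferReset. A sketch of the arithmetic with made-up positions; only the rebasing expression is taken from the code above:

```go
package main

import "fmt"

// Offsets too old to be within maxMatchOffset of any future position are
// zeroed; live offsets are shifted so e.cur can be reset to maxMatchOffset
// while all (position - offset) distances stay unchanged.
func main() {
	const maxMatchOffset = 1 << 15
	cur := int32(1 << 30)  // hypothetical current block position
	histLen := int32(4096) // hypothetical history length
	minOff := cur + histLen - maxMatchOffset

	for _, v := range []int32{cur - 40000, cur + 100, cur + 4000} {
		nv := v
		if nv <= minOff {
			nv = 0 // too far back to ever be a valid match again
		} else {
			nv = nv - cur + maxMatchOffset
		}
		fmt.Printf("offset %d -> %d\n", v, nv)
	}
}
```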
+ t := candidate.offset - e.cur + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + // Index first pair after match end. + if int(s+l+4) < len(src) { + cv := load3232(src, s) + e.table[hash4u(cv, bTableBits)] = tableEntry{offset: s + e.cur} + } + goto emitRemainder + } + + // Store every second hash in-between, but offset by 1. + for i := s - l + 2; i < s-5; i += 7 { + x := load6432(src, int32(i)) + nextHash := hash4u(uint32(x), bTableBits) + e.table[nextHash] = tableEntry{offset: e.cur + i} + // Skip one + x >>= 16 + nextHash = hash4u(uint32(x), bTableBits) + e.table[nextHash] = tableEntry{offset: e.cur + i + 2} + // Skip one + x >>= 16 + nextHash = hash4u(uint32(x), bTableBits) + e.table[nextHash] = tableEntry{offset: e.cur + i + 4} + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-2 to s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load6432(src, s-2) + o := e.cur + s - 2 + prevHash := hash4u(uint32(x), bTableBits) + prevHash2 := hash4u(uint32(x>>8), bTableBits) + e.table[prevHash] = tableEntry{offset: o} + e.table[prevHash2] = tableEntry{offset: o + 1} + currHash := hash4u(uint32(x>>16), bTableBits) + candidate = e.table[currHash] + e.table[currHash] = tableEntry{offset: o + 2} + + offset := s - (candidate.offset - e.cur) + if offset > maxMatchOffset || uint32(x>>16) != load3232(src, candidate.offset-e.cur) { + cv = uint32(x >> 24) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level3.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level3.go new file mode 100644 index 000000000000..c22b4244a5c0 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level3.go @@ -0,0 +1,229 @@ +package flate + +import "fmt" + +// fastEncL3 +type fastEncL3 struct { + fastGen + table [tableSize]tableEntryPrev +} + +// Encode uses a similar algorithm to level 2, will check up to two candidates. +func (e *fastEncL3) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 8 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntryPrev{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. 
+ minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i] + if v.Cur.offset <= minOff { + v.Cur.offset = 0 + } else { + v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset + } + if v.Prev.offset <= minOff { + v.Prev.offset = 0 + } else { + v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset + } + e.table[i] = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // Skip if too small. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3232(src, s) + for { + const skipLog = 6 + nextS := s + var candidate tableEntry + for { + nextHash := hash(cv) + s = nextS + nextS = s + 1 + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + candidates := e.table[nextHash] + now := load3232(src, nextS) + + // Safe offset distance until s + 4... + minOffset := e.cur + s - (maxMatchOffset - 4) + e.table[nextHash] = tableEntryPrev{Prev: candidates.Cur, Cur: tableEntry{offset: s + e.cur}} + + // Check both candidates + candidate = candidates.Cur + if candidate.offset < minOffset { + cv = now + // Previous will also be invalid, we have nothing. + continue + } + + if cv == load3232(src, candidate.offset-e.cur) { + if candidates.Prev.offset < minOffset || cv != load3232(src, candidates.Prev.offset-e.cur) { + break + } + // Both match and are valid, pick longest. + offset := s - (candidate.offset - e.cur) + o2 := s - (candidates.Prev.offset - e.cur) + l1, l2 := matchLen(src[s+4:], src[s-offset+4:]), matchLen(src[s+4:], src[s-o2+4:]) + if l2 > l1 { + candidate = candidates.Prev + } + break + } else { + // We only check if value mismatches. + // Offset will always be invalid in other cases. + candidate = candidates.Prev + if candidate.offset > minOffset && cv == load3232(src, candidate.offset-e.cur) { + break + } + } + cv = now + } + + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + // + t := candidate.offset - e.cur + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + t += l + // Index first pair after match end. 
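level3 differs from the lower levels by keeping two candidates per hash bucket. A minimal sketch of the Cur/Prev rotation the updates above perform on every insert:

```go
package main

import "fmt"

type tableEntry struct{ offset int32 }

// Each bucket remembers the two most recent positions that hashed to it;
// inserting shifts Cur into Prev, giving the matcher a second candidate.
type tableEntryPrev struct{ Cur, Prev tableEntry }

func main() {
	var bucket tableEntryPrev
	for _, off := range []int32{10, 42, 99} {
		bucket = tableEntryPrev{Prev: bucket.Cur, Cur: tableEntry{offset: off}}
		fmt.Printf("insert %2d: Cur=%2d Prev=%2d\n", off, bucket.Cur.offset, bucket.Prev.offset)
	}
}
```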
+ if int(t+4) < len(src) && t > 0 { + cv := load3232(src, t) + nextHash := hash(cv) + e.table[nextHash] = tableEntryPrev{ + Prev: e.table[nextHash].Cur, + Cur: tableEntry{offset: e.cur + t}, + } + } + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-3 to s. + x := load6432(src, s-3) + prevHash := hash(uint32(x)) + e.table[prevHash] = tableEntryPrev{ + Prev: e.table[prevHash].Cur, + Cur: tableEntry{offset: e.cur + s - 3}, + } + x >>= 8 + prevHash = hash(uint32(x)) + + e.table[prevHash] = tableEntryPrev{ + Prev: e.table[prevHash].Cur, + Cur: tableEntry{offset: e.cur + s - 2}, + } + x >>= 8 + prevHash = hash(uint32(x)) + + e.table[prevHash] = tableEntryPrev{ + Prev: e.table[prevHash].Cur, + Cur: tableEntry{offset: e.cur + s - 1}, + } + x >>= 8 + currHash := hash(uint32(x)) + candidates := e.table[currHash] + cv = uint32(x) + e.table[currHash] = tableEntryPrev{ + Prev: candidates.Cur, + Cur: tableEntry{offset: s + e.cur}, + } + + // Check both candidates + candidate = candidates.Cur + minOffset := e.cur + s - (maxMatchOffset - 4) + + if candidate.offset > minOffset && cv != load3232(src, candidate.offset-e.cur) { + // We only check if value mismatches. + // Offset will always be invalid in other cases. + candidate = candidates.Prev + if candidate.offset > minOffset && cv == load3232(src, candidate.offset-e.cur) { + offset := s - (candidate.offset - e.cur) + if offset <= maxMatchOffset { + continue + } + } + } + cv = uint32(x >> 8) + s++ + break + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level4.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level4.go new file mode 100644 index 000000000000..e62f0c02b1e7 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level4.go @@ -0,0 +1,212 @@ +package flate + +import "fmt" + +type fastEncL4 struct { + fastGen + table [tableSize]tableEntry + bTable [tableSize]tableEntry +} + +func (e *fastEncL4) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + for i := range e.bTable[:] { + e.bTable[i] = tableEntry{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + for i := range e.bTable[:] { + v := e.bTable[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.bTable[i].offset = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. 
+ dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load6432(src, s) + for { + const skipLog = 6 + const doEvery = 1 + + nextS := s + var t int32 + for { + nextHashS := hash4x64(cv, tableBits) + nextHashL := hash7(cv, tableBits) + + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + // Fetch a short+long candidate + sCandidate := e.table[nextHashS] + lCandidate := e.bTable[nextHashL] + next := load6432(src, nextS) + entry := tableEntry{offset: s + e.cur} + e.table[nextHashS] = entry + e.bTable[nextHashL] = entry + + t = lCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.offset-e.cur) { + // We got a long match. Use that. + break + } + + t = sCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) { + // Found a 4 match... + lCandidate = e.bTable[hash7(next, tableBits)] + + // If the next long is a candidate, check if we should use that instead... + lOff := nextS - (lCandidate.offset - e.cur) + if lOff < maxMatchOffset && load3232(src, lCandidate.offset-e.cur) == uint32(next) { + l1, l2 := matchLen(src[s+4:], src[t+4:]), matchLen(src[nextS+4:], src[nextS-lOff+4:]) + if l2 > l1 { + s = nextS + t = lCandidate.offset - e.cur + } + } + break + } + cv = next + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Extend the 4-byte match as long as possible. + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + if debugDeflate { + if t >= s { + panic("s-t") + } + if (s - t) > maxMatchOffset { + panic(fmt.Sprintln("mmo", t)) + } + if l < baseMatchLength { + panic("bml") + } + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + // Index first pair after match end. + if int(s+8) < len(src) { + cv := load6432(src, s) + e.table[hash4x64(cv, tableBits)] = tableEntry{offset: s + e.cur} + e.bTable[hash7(cv, tableBits)] = tableEntry{offset: s + e.cur} + } + goto emitRemainder + } + + // Store every 3rd hash in-between + if true { + i := nextS + if i < s-1 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + e.bTable[hash7(cv, tableBits)] = t + e.bTable[hash7(cv>>8, tableBits)] = t2 + e.table[hash4u(uint32(cv>>8), tableBits)] = t2 + + i += 3 + for ; i < s-1; i += 3 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + e.bTable[hash7(cv, tableBits)] = t + e.bTable[hash7(cv>>8, tableBits)] = t2 + e.table[hash4u(uint32(cv>>8), tableBits)] = t2 + } + } + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. 
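level4 indexes every position in two tables: a short hash over 4 bytes (finds matches quickly) and a long hash over 7 bytes (finds matches that are more likely to extend). A sketch of the idea with illustrative multiplicative-hash constants, not the library's actual primes or hash4x64/hash7 internals:

```go
package main

import "fmt"

const tableBits = 14 // bucket count 1<<tableBits, as an assumption

// shortHash keys on the low 4 bytes of the input word.
func shortHash(u uint64) uint32 {
	return (uint32(u) * 2654435761) >> (32 - tableBits)
}

// longHash keys on 7 bytes: u<<8 discards the top byte before mixing.
func longHash(u uint64) uint32 {
	return uint32(((u << 8) * 0x9E3779B185EBCA87) >> (64 - tableBits))
}

func main() {
	v := uint64(0x0102030405060708)
	fmt.Printf("short bucket %d, long bucket %d of %d\n",
		shortHash(v), longHash(v), 1<<tableBits)
}
```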
+ x := load6432(src, s-1) + o := e.cur + s - 1 + prevHashS := hash4x64(x, tableBits) + prevHashL := hash7(x, tableBits) + e.table[prevHashS] = tableEntry{offset: o} + e.bTable[prevHashL] = tableEntry{offset: o} + cv = x >> 8 + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level5.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level5.go new file mode 100644 index 000000000000..d513f1ffd37c --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level5.go @@ -0,0 +1,279 @@ +package flate + +import "fmt" + +type fastEncL5 struct { + fastGen + table [tableSize]tableEntry + bTable [tableSize]tableEntryPrev +} + +func (e *fastEncL5) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + for i := range e.bTable[:] { + e.bTable[i] = tableEntryPrev{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + for i := range e.bTable[:] { + v := e.bTable[i] + if v.Cur.offset <= minOff { + v.Cur.offset = 0 + v.Prev.offset = 0 + } else { + v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset + if v.Prev.offset <= minOff { + v.Prev.offset = 0 + } else { + v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset + } + } + e.bTable[i] = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. 
+ cv := load6432(src, s) + for { + const skipLog = 6 + const doEvery = 1 + + nextS := s + var l int32 + var t int32 + for { + nextHashS := hash4x64(cv, tableBits) + nextHashL := hash7(cv, tableBits) + + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + // Fetch a short+long candidate + sCandidate := e.table[nextHashS] + lCandidate := e.bTable[nextHashL] + next := load6432(src, nextS) + entry := tableEntry{offset: s + e.cur} + e.table[nextHashS] = entry + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = entry, eLong.Cur + + nextHashS = hash4x64(next, tableBits) + nextHashL = hash7(next, tableBits) + + t = lCandidate.Cur.offset - e.cur + if s-t < maxMatchOffset { + if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) { + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + t2 := lCandidate.Prev.offset - e.cur + if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + l = e.matchlen(s+4, t+4, src) + 4 + ml1 := e.matchlen(s+4, t2+4, src) + 4 + if ml1 > l { + t = t2 + l = ml1 + break + } + } + break + } + t = lCandidate.Prev.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + break + } + } + + t = sCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) { + // Found a 4 match... + l = e.matchlen(s+4, t+4, src) + 4 + lCandidate = e.bTable[nextHashL] + // Store the next match + + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + // If the next long is a candidate, use that... + t2 := lCandidate.Cur.offset - e.cur + if nextS-t2 < maxMatchOffset { + if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + break + } + } + // If the previous long is a candidate, use that... + t2 = lCandidate.Prev.offset - e.cur + if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + break + } + } + } + break + } + cv = next + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Extend the 4-byte match as long as possible. + if l == 0 { + l = e.matchlenLong(s+4, t+4, src) + 4 + } else if l == maxMatchLength { + l += e.matchlenLong(s+l, t+l, src) + } + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + if debugDeflate { + if t >= s { + panic(fmt.Sprintln("s-t", s, t)) + } + if (s - t) > maxMatchOffset { + panic(fmt.Sprintln("mmo", s-t)) + } + if l < baseMatchLength { + panic("bml") + } + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + goto emitRemainder + } + + // Store every 3rd hash in-between. 
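When both the current and previous long-table candidates validate, the loop above keeps whichever match runs longer. A minimal self-contained sketch of that choice; matchLen here is the naive byte-by-byte version, not the library's optimized one:

```go
package main

import "fmt"

func matchLen(a, b []byte) int {
	n := 0
	for n < len(a) && n < len(b) && a[n] == b[n] {
		n++
	}
	return n
}

func main() {
	src := []byte("abcdefgh--abcdefgh==abcdefghij....abcdefghij")
	s := 34          // current position
	t1, t2 := 10, 20 // two candidate offsets, as Cur and Prev would supply
	l1 := matchLen(src[s:], src[t1:])
	l2 := matchLen(src[s:], src[t2:])
	best := t1
	if l2 > l1 {
		best = t2
	}
	fmt.Printf("candidate %d matches %d bytes, candidate %d matches %d -> keep %d\n",
		t1, l1, t2, l2, best)
}
```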
+ if true { + const hashEvery = 3 + i := s - l + 1 + if i < s-1 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + e.table[hash4x64(cv, tableBits)] = t + eLong := &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = t, eLong.Cur + + // Do an long at i+1 + cv >>= 8 + t = tableEntry{offset: t.offset + 1} + eLong = &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = t, eLong.Cur + + // We only have enough bits for a short entry at i+2 + cv >>= 8 + t = tableEntry{offset: t.offset + 1} + e.table[hash4x64(cv, tableBits)] = t + + // Skip one - otherwise we risk hitting 's' + i += 4 + for ; i < s-1; i += hashEvery { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + eLong := &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = t, eLong.Cur + e.table[hash4u(uint32(cv>>8), tableBits)] = t2 + } + } + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. + x := load6432(src, s-1) + o := e.cur + s - 1 + prevHashS := hash4x64(x, tableBits) + prevHashL := hash7(x, tableBits) + e.table[prevHashS] = tableEntry{offset: o} + eLong := &e.bTable[prevHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: o}, eLong.Cur + cv = x >> 8 + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level6.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level6.go new file mode 100644 index 000000000000..a52c80ea456c --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/level6.go @@ -0,0 +1,282 @@ +package flate + +import "fmt" + +type fastEncL6 struct { + fastGen + table [tableSize]tableEntry + bTable [tableSize]tableEntryPrev +} + +func (e *fastEncL6) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + for i := range e.bTable[:] { + e.bTable[i] = tableEntryPrev{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + for i := range e.bTable[:] { + v := e.bTable[i] + if v.Cur.offset <= minOff { + v.Cur.offset = 0 + v.Prev.offset = 0 + } else { + v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset + if v.Prev.offset <= minOff { + v.Prev.offset = 0 + } else { + v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset + } + } + e.bTable[i] = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. 
The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load6432(src, s) + // Repeat MUST be > 1 and within range + repeat := int32(1) + for { + const skipLog = 7 + const doEvery = 1 + + nextS := s + var l int32 + var t int32 + for { + nextHashS := hash4x64(cv, tableBits) + nextHashL := hash7(cv, tableBits) + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + // Fetch a short+long candidate + sCandidate := e.table[nextHashS] + lCandidate := e.bTable[nextHashL] + next := load6432(src, nextS) + entry := tableEntry{offset: s + e.cur} + e.table[nextHashS] = entry + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = entry, eLong.Cur + + // Calculate hashes of 'next' + nextHashS = hash4x64(next, tableBits) + nextHashL = hash7(next, tableBits) + + t = lCandidate.Cur.offset - e.cur + if s-t < maxMatchOffset { + if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) { + // Long candidate matches at least 4 bytes. + + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + // Check the previous long candidate as well. + t2 := lCandidate.Prev.offset - e.cur + if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + l = e.matchlen(s+4, t+4, src) + 4 + ml1 := e.matchlen(s+4, t2+4, src) + 4 + if ml1 > l { + t = t2 + l = ml1 + break + } + } + break + } + // Current value did not match, but check if previous long value does. + t = lCandidate.Prev.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + break + } + } + + t = sCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) { + // Found a 4 match... + l = e.matchlen(s+4, t+4, src) + 4 + + // Look up next long candidate (at nextS) + lCandidate = e.bTable[nextHashL] + + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + // Check repeat at s + repOff + const repOff = 1 + t2 := s - repeat + repOff + if load3232(src, t2) == uint32(cv>>(8*repOff)) { + ml := e.matchlen(s+4+repOff, t2+4, src) + 4 + if ml > l { + t = t2 + l = ml + s += repOff + // Not worth checking more. + break + } + } + + // If the next long is a candidate, use that... + t2 = lCandidate.Cur.offset - e.cur + if nextS-t2 < maxMatchOffset { + if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + // This is ok, but check previous as well. + } + } + // If the previous long is a candidate, use that... + t2 = lCandidate.Prev.offset - e.cur + if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + break + } + } + } + break + } + cv = next + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. 
But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Extend the 4-byte match as long as possible. + if l == 0 { + l = e.matchlenLong(s+4, t+4, src) + 4 + } else if l == maxMatchLength { + l += e.matchlenLong(s+l, t+l, src) + } + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + if false { + if t >= s { + panic(fmt.Sprintln("s-t", s, t)) + } + if (s - t) > maxMatchOffset { + panic(fmt.Sprintln("mmo", s-t)) + } + if l < baseMatchLength { + panic("bml") + } + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + repeat = s - t + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + // Index after match end. + for i := nextS + 1; i < int32(len(src))-8; i += 2 { + cv := load6432(src, i) + e.table[hash4x64(cv, tableBits)] = tableEntry{offset: i + e.cur} + eLong := &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = tableEntry{offset: i + e.cur}, eLong.Cur + } + goto emitRemainder + } + + // Store every long hash in-between and every second short. + if true { + for i := nextS + 1; i < s-1; i += 2 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + eLong := &e.bTable[hash7(cv, tableBits)] + eLong2 := &e.bTable[hash7(cv>>8, tableBits)] + e.table[hash4x64(cv, tableBits)] = t + eLong.Cur, eLong.Prev = t, eLong.Cur + eLong2.Cur, eLong2.Prev = t2, eLong2.Cur + } + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. + cv = load6432(src, s) + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/stateless.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/stateless.go new file mode 100644 index 000000000000..53e899124639 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/stateless.go @@ -0,0 +1,297 @@ +package flate + +import ( + "io" + "math" + "sync" +) + +const ( + maxStatelessBlock = math.MaxInt16 + // dictionary will be taken from maxStatelessBlock, so limit it. + maxStatelessDict = 8 << 10 + + slTableBits = 13 + slTableSize = 1 << slTableBits + slTableShift = 32 - slTableBits +) + +type statelessWriter struct { + dst io.Writer + closed bool +} + +func (s *statelessWriter) Close() error { + if s.closed { + return nil + } + s.closed = true + // Emit EOF block + return StatelessDeflate(s.dst, nil, true, nil) +} + +func (s *statelessWriter) Write(p []byte) (n int, err error) { + err = StatelessDeflate(s.dst, p, false, nil) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (s *statelessWriter) Reset(w io.Writer) { + s.dst = w + s.closed = false +} + +// NewStatelessWriter will do compression but without maintaining any state +// between Write calls. +// There will be no memory kept between Write calls, +// but compression and speed will be suboptimal. +// Because of this, the size of actual Write calls will affect output size. +func NewStatelessWriter(dst io.Writer) io.WriteCloser { + return &statelessWriter{dst: dst} +} + +// bitWriterPool contains bit writers that can be reused. 
+var bitWriterPool = sync.Pool{ + New: func() interface{} { + return newHuffmanBitWriter(nil) + }, +} + +// StatelessDeflate allows to compress directly to a Writer without retaining state. +// When returning everything will be flushed. +// Up to 8KB of an optional dictionary can be given which is presumed to presumed to precede the block. +// Longer dictionaries will be truncated and will still produce valid output. +// Sending nil dictionary is perfectly fine. +func StatelessDeflate(out io.Writer, in []byte, eof bool, dict []byte) error { + var dst tokens + bw := bitWriterPool.Get().(*huffmanBitWriter) + bw.reset(out) + defer func() { + // don't keep a reference to our output + bw.reset(nil) + bitWriterPool.Put(bw) + }() + if eof && len(in) == 0 { + // Just write an EOF block. + // Could be faster... + bw.writeStoredHeader(0, true) + bw.flush() + return bw.err + } + + // Truncate dict + if len(dict) > maxStatelessDict { + dict = dict[len(dict)-maxStatelessDict:] + } + + for len(in) > 0 { + todo := in + if len(todo) > maxStatelessBlock-len(dict) { + todo = todo[:maxStatelessBlock-len(dict)] + } + in = in[len(todo):] + uncompressed := todo + if len(dict) > 0 { + // combine dict and source + bufLen := len(todo) + len(dict) + combined := make([]byte, bufLen) + copy(combined, dict) + copy(combined[len(dict):], todo) + todo = combined + } + // Compress + statelessEnc(&dst, todo, int16(len(dict))) + isEof := eof && len(in) == 0 + + if dst.n == 0 { + bw.writeStoredHeader(len(uncompressed), isEof) + if bw.err != nil { + return bw.err + } + bw.writeBytes(uncompressed) + } else if int(dst.n) > len(uncompressed)-len(uncompressed)>>4 { + // If we removed less than 1/16th, huffman compress the block. + bw.writeBlockHuff(isEof, uncompressed, len(in) == 0) + } else { + bw.writeBlockDynamic(&dst, isEof, uncompressed, len(in) == 0) + } + if len(in) > 0 { + // Retain a dict if we have more + dict = todo[len(todo)-maxStatelessDict:] + dst.Reset() + } + if bw.err != nil { + return bw.err + } + } + if !eof { + // Align, only a stored block can do that. + bw.writeStoredHeader(0, false) + } + bw.flush() + return bw.err +} + +func hashSL(u uint32) uint32 { + return (u * 0x1e35a7bd) >> slTableShift +} + +func load3216(b []byte, i int16) uint32 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:4] + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load6416(b []byte, i int16) uint64 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:8] + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +func statelessEnc(dst *tokens, src []byte, startAt int16) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + + type tableEntry struct { + offset int16 + } + + var table [slTableSize]tableEntry + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src)-int(startAt) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. 
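The "help the compiler eliminate bounds checks" comment on load3216/load6416 above refers to a standard Go idiom: reslicing to a fixed length up front lets the compiler prove every index is in range, so the byte loads compile without per-access checks. A minimal sketch of the same pattern:

```go
package main

import "fmt"

// load32 mirrors the shape of load3216: after b = b[i:]; b = b[:4] the
// compiler knows b has exactly 4 bytes, so b[0]..b[3] need no bounds checks.
func load32(b []byte, i int) uint32 {
	b = b[i:]
	b = b[:4]
	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}

func main() {
	data := []byte{0x78, 0x56, 0x34, 0x12, 0xff}
	fmt.Printf("%#x\n", load32(data, 0)) // 0x12345678, little-endian
}
```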
+ dst.n = 0 + return + } + // Index until startAt + if startAt > 0 { + cv := load3232(src, 0) + for i := int16(0); i < startAt; i++ { + table[hashSL(cv)] = tableEntry{offset: i} + cv = (cv >> 8) | (uint32(src[i+4]) << 24) + } + } + + s := startAt + 1 + nextEmit := startAt + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int16(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3216(src, s) + + for { + const skipLog = 5 + const doEvery = 2 + + nextS := s + var candidate tableEntry + for { + nextHash := hashSL(cv) + candidate = table[nextHash] + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit || nextS <= 0 { + goto emitRemainder + } + + now := load6416(src, nextS) + table[nextHash] = tableEntry{offset: s} + nextHash = hashSL(uint32(now)) + + if cv == load3216(src, candidate.offset) { + table[nextHash] = tableEntry{offset: nextS} + break + } + + // Do one right away... + cv = uint32(now) + s = nextS + nextS++ + candidate = table[nextHash] + now >>= 8 + table[nextHash] = tableEntry{offset: s} + + if cv == load3216(src, candidate.offset) { + table[nextHash] = tableEntry{offset: nextS} + break + } + cv = uint32(now) + s = nextS + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + t := candidate.offset + l := int16(matchLen(src[s+4:], src[t+4:]) + 4) + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + // Save the match found + dst.AddMatchLong(int32(l), uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + if s >= sLimit { + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-2 and at s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load6416(src, s-2) + o := s - 2 + prevHash := hashSL(uint32(x)) + table[prevHash] = tableEntry{offset: o} + x >>= 16 + currHash := hashSL(uint32(x)) + candidate = table[currHash] + table[currHash] = tableEntry{offset: o + 2} + + if uint32(x) != load3216(src, candidate.offset) { + cv = uint32(x >> 8) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/token.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/token.go new file mode 100644 index 000000000000..f9abf606d67c --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/klauspost/compress/flate/token.go @@ -0,0 +1,375 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
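A usage sketch for the stateless API vendored above, assuming the package is consumed as github.com/klauspost/compress/flate; NewStatelessWriter and StatelessDeflate are the entry points defined in stateless.go, and the output is plain DEFLATE:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress/flate"
)

func main() {
	var buf bytes.Buffer
	w := flate.NewStatelessWriter(&buf) // no state kept between Writes
	_, _ = w.Write([]byte("hello hello hello hello"))
	_ = w.Close() // emits the final EOF block
	fmt.Println("compressed size:", buf.Len())

	// The stream is ordinary DEFLATE, so the stateful reader round-trips it.
	r := flate.NewReader(&buf)
	out := new(bytes.Buffer)
	_, _ = out.ReadFrom(r)
	fmt.Println(out.String())
}
```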
+
+package flate
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "math"
+)
+
+const (
+ // 2 bits: type 0 = literal 1=EOF 2=Match 3=Unused
+ // 8 bits: xlength = length - MIN_MATCH_LENGTH
+ // 22 bits xoffset = offset - MIN_OFFSET_SIZE, or literal
+ lengthShift = 22
+ offsetMask = 1<<lengthShift - 1
+ typeMask = 3 << 30
+ literalType = 0 << 30
+ matchType = 1 << 30
+)
+
+type token uint32
+
+type tokens struct {
+ nLits int
+ extraHist [32]uint16 // codes 256->maxnumlit
+ offHist [32]uint16 // offset codes
+ litHist [256]uint16 // codes 0->255
+ n uint16 // Must be able to contain maxStoreBlockSize
+ tokens [maxStoreBlockSize + 1]token
+}
+
+func (t *tokens) Reset() {
+ if t.n == 0 {
+ return
+ }
+ t.n = 0
+ t.nLits = 0
+ for i := range t.litHist[:] {
+ t.litHist[i] = 0
+ }
+ for i := range t.extraHist[:] {
+ t.extraHist[i] = 0
+ }
+ for i := range t.offHist[:] {
+ t.offHist[i] = 0
+ }
+}
+
+func (t *tokens) Fill() {
+ if t.n == 0 {
+ return
+ }
+ for i, v := range t.litHist[:] {
+ if v == 0 {
+ t.litHist[i] = 1
+ t.nLits++
+ }
+ }
+ for i, v := range t.extraHist[:literalCount-256] {
+ if v == 0 {
+ t.nLits++
+ t.extraHist[i] = 1
+ }
+ }
+ for i, v := range t.offHist[:offsetCodeCount] {
+ if v == 0 {
+ t.offHist[i] = 1
+ }
+ }
+}
+
+func indexTokens(in []token) tokens {
+ var t tokens
+ t.indexTokens(in)
+ return t
+}
+
+func (t *tokens) indexTokens(in []token) {
+ t.Reset()
+ for _, tok := range in {
+ if tok < matchType {
+ t.AddLiteral(tok.literal())
+ continue
+ }
+ t.AddMatch(uint32(tok.length()), tok.offset())
+ }
+}
+
+// emitLiteral writes a literal chunk and returns the number of bytes written.
+func emitLiteral(dst *tokens, lit []byte) {
+ ol := int(dst.n)
+ for i, v := range lit {
+ dst.tokens[(i+ol)&maxStoreBlockSize] = token(v)
+ dst.litHist[v]++
+ }
+ dst.n += uint16(len(lit))
+ dst.nLits += len(lit)
+}
+
+func (t *tokens) AddLiteral(lit byte) {
+ t.tokens[t.n] = token(lit)
+ t.litHist[lit]++
+ t.n++
+ t.nLits++
+}
+
+// from https://stackoverflow.com/a/28730362
+func mFastLog2(val float32) float32 {
+ ux := int32(math.Float32bits(val))
+ log2 := (float32)(((ux >> 23) & 255) - 128)
+ ux &= -0x7f800001
+ ux += 127 << 23
+ uval := math.Float32frombits(uint32(ux))
+ log2 += ((-0.34484843)*uval+2.02466578)*uval - 0.67487759
+ return log2
+}
+
+// EstimatedBits will return an minimum size estimated by an *optimal*
+// compression of the block.
+// The size of the block
+func (t *tokens) EstimatedBits() int {
+ shannon := float32(0)
+ bits := int(0)
+ nMatches := 0
+ if t.nLits > 0 {
+ invTotal := 1.0 / float32(t.nLits)
+ for _, v := range t.litHist[:] {
+ if v > 0 {
+ n := float32(v)
+ shannon += -mFastLog2(n*invTotal) * n
+ }
+ }
+ // Just add 15 for EOB
+ shannon += 15
+ for i, v := range t.extraHist[1 : literalCount-256] {
+ if v > 0 {
+ n := float32(v)
+ shannon += -mFastLog2(n*invTotal) * n
+ bits += int(lengthExtraBits[i&31]) * int(v)
+ nMatches += int(v)
+ }
+ }
+ }
+ if nMatches > 0 {
+ invTotal := 1.0 / float32(nMatches)
+ for i, v := range t.offHist[:offsetCodeCount] {
+ if v > 0 {
+ n := float32(v)
+ shannon += -mFastLog2(n*invTotal) * n
+ bits += int(offsetExtraBits[i&31]) * int(v)
+ }
+ }
+ }
+ return int(shannon) + bits
+}
+
+// AddMatch adds a match to the tokens.
+// This function is very sensitive to inlining and right on the border.
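EstimatedBits above sums -log2(p) * count over the histograms, using mFastLog2 as a cheap log2. A small sketch comparing it against math.Log2 on a toy histogram; mFastLog2 is reproduced verbatim from token.go:

```go
package main

import (
	"fmt"
	"math"
)

func mFastLog2(val float32) float32 {
	ux := int32(math.Float32bits(val))
	log2 := (float32)(((ux >> 23) & 255) - 128)
	ux &= -0x7f800001
	ux += 127 << 23
	uval := math.Float32frombits(uint32(ux))
	log2 += ((-0.34484843)*uval+2.02466578)*uval - 0.67487759
	return log2
}

func main() {
	hist := map[byte]int{'a': 90, 'b': 10} // toy literal histogram
	total := float32(100)
	var exact, approx float64
	for _, n := range hist {
		p := float32(n) / total
		exact += -math.Log2(float64(p)) * float64(n)   // true Shannon bits
		approx += float64(-mFastLog2(p)) * float64(n)  // fast approximation
	}
	fmt.Printf("exact %.1f bits, fast approximation %.1f bits\n", exact, approx)
}
```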
+func (t *tokens) AddMatch(xlength uint32, xoffset uint32) {
+ if debugDeflate {
+ if xlength >= maxMatchLength+baseMatchLength {
+ panic(fmt.Errorf("invalid length: %v", xlength))
+ }
+ if xoffset >= maxMatchOffset+baseMatchOffset {
+ panic(fmt.Errorf("invalid offset: %v", xoffset))
+ }
+ }
+ t.nLits++
+ lengthCode := lengthCodes1[uint8(xlength)] & 31
+ t.tokens[t.n] = token(matchType | xlength<<lengthShift | xoffset)
+ t.extraHist[lengthCode]++
+ t.offHist[offsetCode(xoffset)&31]++
+ t.n++
+}
+
+// AddMatchLong adds a match to the tokens, potentially longer than max match length.
+// Length should NOT have the base subtracted, only offset should.
+func (t *tokens) AddMatchLong(xlength int32, xoffset uint32) {
+ if debugDeflate {
+ if xoffset >= maxMatchOffset+baseMatchOffset {
+ panic(fmt.Errorf("invalid offset: %v", xoffset))
+ }
+ }
+ oc := offsetCode(xoffset) & 31
+ for xlength > 0 {
+ xl := xlength
+ if xl > 258 {
+ // We need to have at least baseMatchLength left over for next loop.
+ xl = 258 - baseMatchLength
+ }
+ xlength -= xl
+ xl -= 3
+ t.nLits++
+ lengthCode := lengthCodes1[uint8(xl)] & 31
+ t.tokens[t.n] = token(matchType | uint32(xl)<<lengthShift | xoffset)
+ t.extraHist[lengthCode]++
+ t.offHist[oc]++
+ t.n++
+ }
+}
+
+func (t *tokens) AddEOB() {
+ t.tokens[t.n] = token(endBlockMarker)
+ t.extraHist[0]++
+ t.n++
+}
+
+func (t *tokens) Slice() []token {
+ return t.tokens[:t.n]
+}
+
+// VarInt returns the tokens as varint encoded bytes.
+func (t *tokens) VarInt() []byte {
+ var b = make([]byte, binary.MaxVarintLen32*int(t.n))
+ var off int
+ for _, v := range t.tokens[:t.n] {
+ off += binary.PutUvarint(b[off:], uint64(v))
+ }
+ return b[:off]
+}
+
+// FromVarInt restores t to the varint encoded tokens provided.
+// Any data in t is removed.
+func (t *tokens) FromVarInt(b []byte) error {
+ var buf = bytes.NewReader(b)
+ var toks []token
+ for {
+ r, err := binary.ReadUvarint(buf)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ toks = append(toks, token(r))
+ }
+ t.indexTokens(toks)
+ return nil
+}
+
+// Returns the type of a token
+func (t token) typ() uint32 { return uint32(t) & typeMask }
+
+// Returns the literal of a literal token
+func (t token) literal() uint8 { return uint8(t) }
+
+// Returns the extra offset of a match token
+func (t token) offset() uint32 { return uint32(t) & offsetMask }
+
+func (t token) length() uint8 { return uint8(t >> lengthShift) }
+
+// The code is never more than 8 bits, but is returned as uint32 for convenience.
+func lengthCode(len uint8) uint32 { return uint32(lengthCodes[len]) }
+
+// Returns the offset code corresponding to a specific offset
+func offsetCode(off uint32) uint32 {
+ if false {
+ if off < uint32(len(offsetCodes)) {
+ return offsetCodes[off&255]
+ } else if off>>7 < uint32(len(offsetCodes)) {
+ return offsetCodes[(off>>7)&255] + 14
+ } else {
+ return offsetCodes[(off>>14)&255] + 28
+ }
+ }
+ if off < uint32(len(offsetCodes)) {
+ return offsetCodes[uint8(off)]
+ }
+ return offsetCodes14[uint8(off>>7)]
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/.gitignore b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/.gitignore
new file mode 100644
index 000000000000..f1c181ec9c5c
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/.gitignore
@@ -0,0 +1,12 @@
+# Binaries for programs and plugins
+*.exe
+*.exe~
+*.dll
+*.so
+*.dylib
+
+# Test binary, build with `go test -c`
+*.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/LICENSE b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/LICENSE
new file mode 100644
index 000000000000..7364c76bad1c
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 sachin shinde
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
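For reference, the token layout declared in token.go packs type, length, and offset into one uint32. A minimal sketch using the same shift constants; pack and the unpack expressions are hypothetical helpers, not part of the package API:

```go
package main

import "fmt"

const (
	lengthShift = 22            // 8 length bits start here
	offsetMask  = 1<<lengthShift - 1 // low 22 bits hold the offset
	matchType   = 1 << 30       // 2 type bits occupy the top of the word
)

// pack stores length-3 (the xlength convention) above the offset bits.
func pack(length, offset uint32) uint32 {
	return matchType | (length-3)<<lengthShift | offset
}

func main() {
	tok := pack(7, 1024)
	length := (tok>>lengthShift)&255 + 3
	fmt.Printf("length %d, offset %d\n", length, tok&offsetMask)
}
```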
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/logging.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/logging.go new file mode 100644 index 000000000000..12d377d8056e --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/logging.go @@ -0,0 +1,82 @@ +package logging + +import ( + "log" + "io" + "io/ioutil" + "os" +) + +type Logger struct { + Name string + Trace *log.Logger + Info *log.Logger + Warning *log.Logger + Error *log.Logger + level LoggingLevel +} + +var loggers = make(map[string]Logger) + +func GetLogger(name string) Logger { + return New(name, os.Stdout, os.Stdout, os.Stdout, os.Stderr) +} + +func (logger Logger) SetLevel(level LoggingLevel) Logger{ + switch level { + case TRACE: + logger.Trace.SetOutput(os.Stdout); + logger.Info.SetOutput(os.Stdout); + logger.Warning.SetOutput(os.Stdout); + logger.Error.SetOutput(os.Stderr); + case INFO: + logger.Trace.SetOutput(ioutil.Discard); + logger.Info.SetOutput(os.Stdout); + logger.Warning.SetOutput(os.Stdout); + logger.Error.SetOutput(os.Stderr); + case WARNING: + logger.Trace.SetOutput(ioutil.Discard); + logger.Info.SetOutput(ioutil.Discard); + logger.Warning.SetOutput(os.Stdout); + logger.Error.SetOutput(os.Stderr); + case ERROR: + logger.Trace.SetOutput(ioutil.Discard); + logger.Info.SetOutput(ioutil.Discard); + logger.Warning.SetOutput(ioutil.Discard); + logger.Error.SetOutput(os.Stderr); + case OFF: + logger.Trace.SetOutput(ioutil.Discard); + logger.Info.SetOutput(ioutil.Discard); + logger.Warning.SetOutput(ioutil.Discard); + logger.Error.SetOutput(ioutil.Discard); + } + return logger; +} + +func (logger Logger) GetLevel() LoggingLevel { + return logger.level; +} + +func New( + name string, + traceHandle io.Writer, + infoHandle io.Writer, + warningHandle io.Writer, + errorHandle io.Writer) Logger { + loggers[name] = Logger{ + Name: name, + Trace: log.New(traceHandle, + "TRACE: ", + log.Ldate|log.Ltime|log.Lshortfile), + Info: log.New(infoHandle, + "INFO: ", + log.Ldate|log.Ltime|log.Lshortfile), + Warning: log.New(warningHandle, + "WARNING: ", + log.Ldate|log.Ltime|log.Lshortfile), + Error: log.New(errorHandle, + "ERROR: ", + log.Ldate|log.Ltime|log.Lshortfile), + } + return loggers[name] +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/loggingL.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/loggingL.go new file mode 100644 index 000000000000..aab5a8567afa --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/loggingL.go @@ -0,0 +1,13 @@ +package logging + +type LoggingLevel int + +//go:generate stringer -type=LoggingLevel + +const ( + TRACE LoggingLevel = iota + INFO + WARNING + ERROR + OFF +) diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/logginglevel_string.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/logginglevel_string.go new file mode 100644 index 000000000000..9f24f0acbfe0 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/go-logger/logginglevel_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type=LoggingLevel"; DO NOT EDIT. 
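A usage sketch for the vendored logger above, assuming it is imported as github.com/sacOO7/go-logger (package logging): GetLogger wires Trace/Info/Warning to stdout and Error to stderr, and SetLevel silences the writers below the chosen level.

```go
package main

import (
	logging "github.com/sacOO7/go-logger"
)

func main() {
	log := logging.GetLogger("gowebsocket").SetLevel(logging.WARNING)
	log.Trace.Println("dropped: below WARNING")
	log.Info.Println("dropped: below WARNING")
	log.Warning.Println("kept, on stdout")
	log.Error.Println("kept, on stderr")
}
```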
+
+package logging
+
+import "strconv"
+
+const _LoggingLevel_name = "TRACEINFOWARNINGERROROFF"
+
+var _LoggingLevel_index = [...]uint8{0, 5, 9, 16, 21, 24}
+
+func (i LoggingLevel) String() string {
+	if i < 0 || i >= LoggingLevel(len(_LoggingLevel_index)-1) {
+		return "LoggingLevel(" + strconv.FormatInt(int64(i), 10) + ")"
+	}
+	return _LoggingLevel_name[_LoggingLevel_index[i]:_LoggingLevel_index[i+1]]
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/.gitignore b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/.gitignore
new file mode 100644
index 000000000000..9a289397844c
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/.gitignore
@@ -0,0 +1,21 @@
+# Binaries for programs and plugins
+*.exe
+*.dll
+*.so
+*.dylib
+
+# Test binary, build with `go test -c`
+*.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
+
+# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
+.glide/
+
+# ignore build under build directory
+build/
+bin/
+
+#ignore any IDE based files
+.idea/**
\ No newline at end of file
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/README.md b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/README.md
new file mode 100644
index 000000000000..2439bc6a7e10
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/README.md
@@ -0,0 +1,157 @@
+# GoWebsocket
+Gorilla WebSocket based simplified client implementation in Go.
+
+Overview
+--------
+This client provides the following easy-to-implement functionality:
+- Support for emitting and receiving text and binary data
+- Data compression
+- Concurrency control
+- Proxy support
+- Setting request headers
+- Subprotocols support
+- SSL verification enable/disable
+
+To install, use
+
+```shell
+    go get github.com/sacOO7/gowebsocket
+```
+
+Description
+-----------
+
+Create an instance of `Websocket` by passing the URL of the websocket server endpoint:
+
+```go
+    //Create a client instance
+    socket := gowebsocket.New("ws://echo.websocket.org/")
+
+```
+
+**Important Note**: the websocket server URL must be specified with either the **ws** or **wss** scheme.
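+
+For TLS endpoints, the same constructor is used with a **wss** URL (a minimal sketch; the host below is only a placeholder):
+
+```go
+    // Secure variant of the client instance (placeholder host)
+    socket := gowebsocket.New("wss://echo.websocket.org/")
+```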
+
+#### Connecting to server
+- To connect to the server:
+
+```go
+    //This will send a websocket handshake request to the server
+    socket.Connect()
+```
+
+#### Registering All Listeners
+```go
+  package main
+
+  import (
+    "log"
+    "github.com/sacOO7/gowebsocket"
+    "os"
+    "os/signal"
+  )
+
+  func main() {
+
+    interrupt := make(chan os.Signal, 1)
+    signal.Notify(interrupt, os.Interrupt)
+
+    socket := gowebsocket.New("ws://echo.websocket.org/");
+
+    socket.OnConnected = func(socket gowebsocket.Socket) {
+      log.Println("Connected to server");
+    };
+
+    socket.OnConnectError = func(err error, socket gowebsocket.Socket) {
+      log.Println("Received connect error ", err)
+    };
+
+    socket.OnTextMessage = func(message string, socket gowebsocket.Socket) {
+      log.Println("Received message " + message)
+    };
+
+    socket.OnBinaryMessage = func(data [] byte, socket gowebsocket.Socket) {
+      log.Println("Received binary data ", data)
+    };
+
+    socket.OnPingReceived = func(data string, socket gowebsocket.Socket) {
+      log.Println("Received ping " + data)
+    };
+
+    socket.OnPongReceived = func(data string, socket gowebsocket.Socket) {
+      log.Println("Received pong " + data)
+    };
+
+    socket.OnDisconnected = func(err error, socket gowebsocket.Socket) {
+      log.Println("Disconnected from server ")
+      return
+    };
+
+    socket.Connect()
+
+    for {
+      select {
+      case <-interrupt:
+        log.Println("interrupt")
+        socket.Close()
+        return
+      }
+    }
+  }
+
+```
+
+#### Sending Text message
+
+```go
+    socket.SendText("Hi there, this is my sample test message")
+```
+
+#### Sending Binary data
+```go
+    token := make([]byte, 4)
+    // rand.Read(token) fills token with random bytes
+    socket.SendBinary(token)
+```
+
+#### Closing the connection with the server
+```go
+    socket.Close()
+```
+
+#### Setting request headers
+```go
+    socket.RequestHeader.Set("Accept-Encoding","gzip, deflate, sdch")
+    socket.RequestHeader.Set("Accept-Language","en-US,en;q=0.8")
+    socket.RequestHeader.Set("Pragma","no-cache")
+    socket.RequestHeader.Set("User-Agent","Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/49.0.2623.87 Safari/537.36")
+
+```
+
+#### Setting proxy server
+- It can be set using ConnectionOptions by providing the URL of the proxy server:
+
+```go
+    socket.ConnectionOptions = gowebsocket.ConnectionOptions {
+        Proxy: gowebsocket.BuildProxy("http://example.com"),
+    }
+```
+
+#### Setting data compression, SSL verification and subprotocols
+
+- These can be set using ConnectionOptions on the socket:
+
+```go
+    socket.ConnectionOptions = gowebsocket.ConnectionOptions {
+        UseSSL:true,
+        UseCompression:true,
+        Subprotocols: [] string{"chat","superchat"},
+    }
+```
+
+- ConnectionOptions needs to be applied before connecting to the server, as in the combined sketch below
+- Please check out the [**examples/gowebsocket**](https://github.com/sacOO7/GoWebsocket/tree/master/examples/gowebsocket) directory for detailed code.
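+
+Putting the options above together, a minimal end-to-end setup could look like the following sketch (the header value and the proxy URL are placeholders):
+
+```go
+    socket := gowebsocket.New("wss://echo.websocket.org/")
+
+    // Headers and connection options must be set before Connect()
+    socket.RequestHeader.Set("Authorization", "Bearer <token>") // placeholder credential
+    socket.ConnectionOptions = gowebsocket.ConnectionOptions {
+        UseSSL:         true,
+        UseCompression: true,
+        Subprotocols:   [] string{"chat"},
+        Proxy:          gowebsocket.BuildProxy("http://proxy.example.com:8080"), // placeholder proxy
+    }
+
+    socket.Connect()
+```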
+ +License +------- +Apache License, Version 2.0 + diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/gowebsocket.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/gowebsocket.go new file mode 100644 index 000000000000..1ea2b0d7a711 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/gowebsocket.go @@ -0,0 +1,186 @@ +package gowebsocket + +import ( + "github.com/gorilla/websocket" + "net/http" + "errors" + "crypto/tls" + "net/url" + "sync" + "github.com/sacOO7/go-logger" + "reflect" +) + +type Empty struct { +} + +var logger = logging.GetLogger(reflect.TypeOf(Empty{}).PkgPath()).SetLevel(logging.OFF) + +func (socket Socket) EnableLogging() { + logger.SetLevel(logging.TRACE) +} + +func (socket Socket) GetLogger() logging.Logger { + return logger; +} + +type Socket struct { + Conn *websocket.Conn + WebsocketDialer *websocket.Dialer + Url string + ConnectionOptions ConnectionOptions + RequestHeader http.Header + OnConnected func(socket Socket) + OnTextMessage func(message string, socket Socket) + OnBinaryMessage func(data [] byte, socket Socket) + OnConnectError func(err error, socket Socket) + OnDisconnected func(err error, socket Socket) + OnPingReceived func(data string, socket Socket) + OnPongReceived func(data string, socket Socket) + IsConnected bool + sendMu *sync.Mutex // Prevent "concurrent write to websocket connection" + receiveMu *sync.Mutex +} + +type ConnectionOptions struct { + UseCompression bool + UseSSL bool + Proxy func(*http.Request) (*url.URL, error) + Subprotocols [] string +} + +// todo Yet to be done +type ReconnectionOptions struct { +} + +func New(url string) Socket { + return Socket{ + Url: url, + RequestHeader: http.Header{}, + ConnectionOptions: ConnectionOptions{ + UseCompression: false, + UseSSL: true, + }, + WebsocketDialer: &websocket.Dialer{}, + sendMu: &sync.Mutex{}, + receiveMu: &sync.Mutex{}, + } +} + +func (socket *Socket) setConnectionOptions() { + socket.WebsocketDialer.EnableCompression = socket.ConnectionOptions.UseCompression + socket.WebsocketDialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: socket.ConnectionOptions.UseSSL} + socket.WebsocketDialer.Proxy = socket.ConnectionOptions.Proxy + socket.WebsocketDialer.Subprotocols = socket.ConnectionOptions.Subprotocols +} + +func (socket *Socket) Connect() { + var err error; + socket.setConnectionOptions() + + socket.Conn, _, err = socket.WebsocketDialer.Dial(socket.Url, socket.RequestHeader) + + if err != nil { + logger.Error.Println("Error while connecting to server ", err) + socket.IsConnected = false + if socket.OnConnectError != nil { + socket.OnConnectError(err, *socket) + } + return + } + + logger.Info.Println("Connected to server") + + if socket.OnConnected != nil { + socket.IsConnected = true + socket.OnConnected(*socket) + } + + defaultPingHandler := socket.Conn.PingHandler() + socket.Conn.SetPingHandler(func(appData string) error { + logger.Trace.Println("Received PING from server") + if socket.OnPingReceived != nil { + socket.OnPingReceived(appData, *socket) + } + return defaultPingHandler(appData) + }) + + defaultPongHandler := socket.Conn.PongHandler() + socket.Conn.SetPongHandler(func(appData string) error { + logger.Trace.Println("Received PONG from server") + if socket.OnPongReceived != nil { + socket.OnPongReceived(appData, *socket) + } + return defaultPongHandler(appData) + }) + + defaultCloseHandler := socket.Conn.CloseHandler() + 
socket.Conn.SetCloseHandler(func(code int, text string) error { + result := defaultCloseHandler(code, text) + logger.Warning.Println("Disconnected from server ", result) + if socket.OnDisconnected != nil { + socket.IsConnected = false + socket.OnDisconnected(errors.New(text), *socket) + } + return result + }) + + go func() { + for { + socket.receiveMu.Lock() + messageType, message, err := socket.Conn.ReadMessage() + socket.receiveMu.Unlock() + if err != nil { + logger.Error.Println("read:", err) + return + } + logger.Info.Println("recv: %s", message) + + switch messageType { + case websocket.TextMessage: + if socket.OnTextMessage != nil { + socket.OnTextMessage(string(message), *socket) + } + case websocket.BinaryMessage: + if socket.OnBinaryMessage != nil { + socket.OnBinaryMessage(message, *socket) + } + } + } + }() +} + +func (socket *Socket) SendText(message string) { + err := socket.send(websocket.TextMessage, [] byte (message)) + if err != nil { + logger.Error.Println("write:", err) + return + } +} + +func (socket *Socket) SendBinary(data [] byte) { + err := socket.send(websocket.BinaryMessage, data) + if err != nil { + logger.Error.Println("write:", err) + return + } +} + +func (socket *Socket) send(messageType int, data [] byte) error { + socket.sendMu.Lock() + err := socket.Conn.WriteMessage(messageType, data) + socket.sendMu.Unlock() + return err +} + +func (socket *Socket) Close() { + err := socket.send(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + logger.Error.Println("write close:", err) + } + socket.Conn.Close() + if socket.OnDisconnected != nil { + socket.IsConnected = false + socket.OnDisconnected(err, *socket) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/stub.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/stub.go deleted file mode 100644 index f1f20c86857e..000000000000 --- a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/stub.go +++ /dev/null @@ -1,58 +0,0 @@ -// Code generated by depstubber. DO NOT EDIT. -// This is a simple stub for github.com/sacOO7/gowebsocket, strictly for use in testing. - -// See the LICENSE file for information about the licensing of the original library. -// Source: github.com/sacOO7/gowebsocket (exports: ; functions: New,BuildProxy) - -// Package gowebsocket is a stub of github.com/sacOO7/gowebsocket, generated by depstubber. 
-package gowebsocket - -import ( - http "net/http" - url "net/url" -) - -func BuildProxy(_ string) func(*http.Request) (*url.URL, error) { - return nil -} - -type ConnectionOptions struct { - UseCompression bool - UseSSL bool - Proxy func(*http.Request) (*url.URL, error) - Subprotocols []string -} - -func New(_ string) Socket { - return Socket{} -} - -type Socket struct { - Conn interface{} - WebsocketDialer interface{} - Url string - ConnectionOptions ConnectionOptions - RequestHeader http.Header - OnConnected func(Socket) - OnTextMessage func(string, Socket) - OnBinaryMessage func([]byte, Socket) - OnConnectError func(error, Socket) - OnDisconnected func(error, Socket) - OnPingReceived func(string, Socket) - OnPongReceived func(string, Socket) - IsConnected bool -} - -func (_ Socket) EnableLogging() {} - -func (_ Socket) GetLogger() interface{} { - return nil -} - -func (_ *Socket) Close() {} - -func (_ *Socket) Connect() {} - -func (_ *Socket) SendBinary(_ []byte) {} - -func (_ *Socket) SendText(_ string) {} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/utils.go b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/utils.go new file mode 100644 index 000000000000..d8702ebb6dfd --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/github.com/sacOO7/gowebsocket/utils.go @@ -0,0 +1,15 @@ +package gowebsocket + +import ( + "net/http" + "net/url" + "log" +) + +func BuildProxy(Url string) func(*http.Request) (*url.URL, error) { + uProxy, err := url.Parse(Url) + if err != nil { + log.Fatal("Error while parsing url ", err) + } + return http.ProxyURL(uProxy) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/LICENSE b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/LICENSE similarity index 100% rename from go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/LICENSE rename to go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/LICENSE diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/PATENTS b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/PATENTS new file mode 100644 index 000000000000..733099041f84 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. 
If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/client.go b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/client.go new file mode 100644 index 000000000000..69a4ac7eefec --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/client.go @@ -0,0 +1,106 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "io" + "net" + "net/http" + "net/url" +) + +// DialError is an error that occurs while dialling a websocket server. +type DialError struct { + *Config + Err error +} + +func (e *DialError) Error() string { + return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error() +} + +// NewConfig creates a new WebSocket config for client connection. +func NewConfig(server, origin string) (config *Config, err error) { + config = new(Config) + config.Version = ProtocolVersionHybi13 + config.Location, err = url.ParseRequestURI(server) + if err != nil { + return + } + config.Origin, err = url.ParseRequestURI(origin) + if err != nil { + return + } + config.Header = http.Header(make(map[string][]string)) + return +} + +// NewClient creates a new WebSocket client connection over rwc. +func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + err = hybiClientHandshake(config, br, bw) + if err != nil { + return + } + buf := bufio.NewReadWriter(br, bw) + ws = newHybiClientConn(config, buf, rwc) + return +} + +// Dial opens a new client connection to a WebSocket. +func Dial(url_, protocol, origin string) (ws *Conn, err error) { + config, err := NewConfig(url_, origin) + if err != nil { + return nil, err + } + if protocol != "" { + config.Protocol = []string{protocol} + } + return DialConfig(config) +} + +var portMap = map[string]string{ + "ws": "80", + "wss": "443", +} + +func parseAuthority(location *url.URL) string { + if _, ok := portMap[location.Scheme]; ok { + if _, _, err := net.SplitHostPort(location.Host); err != nil { + return net.JoinHostPort(location.Host, portMap[location.Scheme]) + } + } + return location.Host +} + +// DialConfig opens a new client connection to a WebSocket with a config. 
+func DialConfig(config *Config) (ws *Conn, err error) { + var client net.Conn + if config.Location == nil { + return nil, &DialError{config, ErrBadWebSocketLocation} + } + if config.Origin == nil { + return nil, &DialError{config, ErrBadWebSocketOrigin} + } + dialer := config.Dialer + if dialer == nil { + dialer = &net.Dialer{} + } + client, err = dialWithDialer(dialer, config) + if err != nil { + goto Error + } + ws, err = NewClient(config, client) + if err != nil { + client.Close() + goto Error + } + return + +Error: + return nil, &DialError{config, err} +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/dial.go b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/dial.go new file mode 100644 index 000000000000..2dab943a489a --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/dial.go @@ -0,0 +1,24 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/tls" + "net" +) + +func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) { + switch config.Location.Scheme { + case "ws": + conn, err = dialer.Dial("tcp", parseAuthority(config.Location)) + + case "wss": + conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig) + + default: + err = ErrBadScheme + } + return +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/hybi.go b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/hybi.go new file mode 100644 index 000000000000..48a069e19039 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/hybi.go @@ -0,0 +1,583 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +// This file implements a protocol of hybi draft. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +const ( + websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + closeStatusNormal = 1000 + closeStatusGoingAway = 1001 + closeStatusProtocolError = 1002 + closeStatusUnsupportedData = 1003 + closeStatusFrameTooLarge = 1004 + closeStatusNoStatusRcvd = 1005 + closeStatusAbnormalClosure = 1006 + closeStatusBadMessageData = 1007 + closeStatusPolicyViolation = 1008 + closeStatusTooBigData = 1009 + closeStatusExtensionMismatch = 1010 + + maxControlFramePayloadLength = 125 +) + +var ( + ErrBadMaskingKey = &ProtocolError{"bad masking key"} + ErrBadPongMessage = &ProtocolError{"bad pong message"} + ErrBadClosingStatus = &ProtocolError{"bad closing status"} + ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} + ErrNotImplemented = &ProtocolError{"not implemented"} + + handshakeHeader = map[string]bool{ + "Host": true, + "Upgrade": true, + "Connection": true, + "Sec-Websocket-Key": true, + "Sec-Websocket-Origin": true, + "Sec-Websocket-Version": true, + "Sec-Websocket-Protocol": true, + "Sec-Websocket-Accept": true, + } +) + +// A hybiFrameHeader is a frame header as defined in hybi draft. 
+type hybiFrameHeader struct { + Fin bool + Rsv [3]bool + OpCode byte + Length int64 + MaskingKey []byte + + data *bytes.Buffer +} + +// A hybiFrameReader is a reader for hybi frame. +type hybiFrameReader struct { + reader io.Reader + + header hybiFrameHeader + pos int64 + length int +} + +func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) { + n, err = frame.reader.Read(msg) + if frame.header.MaskingKey != nil { + for i := 0; i < n; i++ { + msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] + frame.pos++ + } + } + return n, err +} + +func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } + +func (frame *hybiFrameReader) HeaderReader() io.Reader { + if frame.header.data == nil { + return nil + } + if frame.header.data.Len() == 0 { + return nil + } + return frame.header.data +} + +func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } + +func (frame *hybiFrameReader) Len() (n int) { return frame.length } + +// A hybiFrameReaderFactory creates new frame reader based on its frame type. +type hybiFrameReaderFactory struct { + *bufio.Reader +} + +// NewFrameReader reads a frame header from the connection, and creates new reader for the frame. +// See Section 5.2 Base Framing protocol for detail. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 +func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) { + hybiFrame := new(hybiFrameReader) + frame = hybiFrame + var header []byte + var b byte + // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 + for i := 0; i < 3; i++ { + j := uint(6 - i) + hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 + } + hybiFrame.header.OpCode = header[0] & 0x0f + + // Second byte. Mask/Payload len(7bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + mask := (b & 0x80) != 0 + b &= 0x7f + lengthFields := 0 + switch { + case b <= 125: // Payload length 7bits. + hybiFrame.header.Length = int64(b) + case b == 126: // Payload length 7+16bits + lengthFields = 2 + case b == 127: // Payload length 7+64bits + lengthFields = 8 + } + for i := 0; i < lengthFields; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits + b &= 0x7f + } + header = append(header, b) + hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) + } + if mask { + // Masking key. 4 bytes. + for i := 0; i < 4; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) + } + } + hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) + hybiFrame.header.data = bytes.NewBuffer(header) + hybiFrame.length = len(header) + int(hybiFrame.header.Length) + return +} + +// A HybiFrameWriter is a writer for hybi frame. 
+type hybiFrameWriter struct { + writer *bufio.Writer + + header *hybiFrameHeader +} + +func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) { + var header []byte + var b byte + if frame.header.Fin { + b |= 0x80 + } + for i := 0; i < 3; i++ { + if frame.header.Rsv[i] { + j := uint(6 - i) + b |= 1 << j + } + } + b |= frame.header.OpCode + header = append(header, b) + if frame.header.MaskingKey != nil { + b = 0x80 + } else { + b = 0 + } + lengthFields := 0 + length := len(msg) + switch { + case length <= 125: + b |= byte(length) + case length < 65536: + b |= 126 + lengthFields = 2 + default: + b |= 127 + lengthFields = 8 + } + header = append(header, b) + for i := 0; i < lengthFields; i++ { + j := uint((lengthFields - i - 1) * 8) + b = byte((length >> j) & 0xff) + header = append(header, b) + } + if frame.header.MaskingKey != nil { + if len(frame.header.MaskingKey) != 4 { + return 0, ErrBadMaskingKey + } + header = append(header, frame.header.MaskingKey...) + frame.writer.Write(header) + data := make([]byte, length) + for i := range data { + data[i] = msg[i] ^ frame.header.MaskingKey[i%4] + } + frame.writer.Write(data) + err = frame.writer.Flush() + return length, err + } + frame.writer.Write(header) + frame.writer.Write(msg) + err = frame.writer.Flush() + return length, err +} + +func (frame *hybiFrameWriter) Close() error { return nil } + +type hybiFrameWriterFactory struct { + *bufio.Writer + needMaskingKey bool +} + +func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} + if buf.needMaskingKey { + frameHeader.MaskingKey, err = generateMaskingKey() + if err != nil { + return nil, err + } + } + return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil +} + +type hybiFrameHandler struct { + conn *Conn + payloadType byte +} + +func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) { + if handler.conn.IsServerConn() { + // The client MUST mask all frames sent to the server. + if frame.(*hybiFrameReader).header.MaskingKey == nil { + handler.WriteClose(closeStatusProtocolError) + return nil, io.EOF + } + } else { + // The server MUST NOT mask all frames. 
+ if frame.(*hybiFrameReader).header.MaskingKey != nil { + handler.WriteClose(closeStatusProtocolError) + return nil, io.EOF + } + } + if header := frame.HeaderReader(); header != nil { + io.Copy(ioutil.Discard, header) + } + switch frame.PayloadType() { + case ContinuationFrame: + frame.(*hybiFrameReader).header.OpCode = handler.payloadType + case TextFrame, BinaryFrame: + handler.payloadType = frame.PayloadType() + case CloseFrame: + return nil, io.EOF + case PingFrame, PongFrame: + b := make([]byte, maxControlFramePayloadLength) + n, err := io.ReadFull(frame, b) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return nil, err + } + io.Copy(ioutil.Discard, frame) + if frame.PayloadType() == PingFrame { + if _, err := handler.WritePong(b[:n]); err != nil { + return nil, err + } + } + return nil, nil + } + return frame, nil +} + +func (handler *hybiFrameHandler) WriteClose(status int) (err error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) + if err != nil { + return err + } + msg := make([]byte, 2) + binary.BigEndian.PutUint16(msg, uint16(status)) + _, err = w.Write(msg) + w.Close() + return err +} + +func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// newHybiConn creates a new WebSocket connection speaking hybi draft protocol. +func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + if buf == nil { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + buf = bufio.NewReadWriter(br, bw) + } + ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, + frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, + frameWriterFactory: hybiFrameWriterFactory{ + buf.Writer, request == nil}, + PayloadType: TextFrame, + defaultCloseStatus: closeStatusNormal} + ws.frameHandler = &hybiFrameHandler{conn: ws} + return ws +} + +// generateMaskingKey generates a masking key for a frame. +func generateMaskingKey() (maskingKey []byte, err error) { + maskingKey = make([]byte, 4) + if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { + return + } + return +} + +// generateNonce generates a nonce consisting of a randomly selected 16-byte +// value that has been base64-encoded. +func generateNonce() (nonce []byte) { + key := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + panic(err) + } + nonce = make([]byte, 24) + base64.StdEncoding.Encode(nonce, key) + return +} + +// removeZone removes IPv6 zone identifier from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} + +// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of +// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string. 
+func getNonceAccept(nonce []byte) (expected []byte, err error) { + h := sha1.New() + if _, err = h.Write(nonce); err != nil { + return + } + if _, err = h.Write([]byte(websocketGUID)); err != nil { + return + } + expected = make([]byte, 28) + base64.StdEncoding.Encode(expected, h.Sum(nil)) + return +} + +// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 +func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { + bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n") + bw.WriteString("Upgrade: websocket\r\n") + bw.WriteString("Connection: Upgrade\r\n") + nonce := generateNonce() + if config.handshakeData != nil { + nonce = []byte(config.handshakeData["key"]) + } + bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") + bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + + if config.Version != ProtocolVersionHybi13 { + return ErrBadProtocolVersion + } + + bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") + if len(config.Protocol) > 0 { + bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") + } + // TODO(ukai): send Sec-WebSocket-Extensions. + err = config.Header.WriteSubset(bw, handshakeHeader) + if err != nil { + return err + } + + bw.WriteString("\r\n") + if err = bw.Flush(); err != nil { + return err + } + + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + return err + } + if resp.StatusCode != 101 { + return ErrBadStatus + } + if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { + return ErrBadUpgrade + } + expectedAccept, err := getNonceAccept(nonce) + if err != nil { + return err + } + if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { + return ErrChallengeResponse + } + if resp.Header.Get("Sec-WebSocket-Extensions") != "" { + return ErrUnsupportedExtensions + } + offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") + if offeredProtocol != "" { + protocolMatched := false + for i := 0; i < len(config.Protocol); i++ { + if config.Protocol[i] == offeredProtocol { + protocolMatched = true + break + } + } + if !protocolMatched { + return ErrBadWebSocketProtocol + } + config.Protocol = []string{offeredProtocol} + } + + return nil +} + +// newHybiClientConn creates a client WebSocket connection after handshake. +func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { + return newHybiConn(config, buf, rwc, nil) +} + +// A HybiServerHandshaker performs a server handshake using hybi draft protocol. +type hybiServerHandshaker struct { + *Config + accept []byte +} + +func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) { + c.Version = ProtocolVersionHybi13 + if req.Method != "GET" { + return http.StatusMethodNotAllowed, ErrBadRequestMethod + } + // HTTP version can be safely ignored. 
+ + if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + return http.StatusBadRequest, ErrNotWebSocket + } + + key := req.Header.Get("Sec-Websocket-Key") + if key == "" { + return http.StatusBadRequest, ErrChallengeResponse + } + version := req.Header.Get("Sec-Websocket-Version") + switch version { + case "13": + c.Version = ProtocolVersionHybi13 + default: + return http.StatusBadRequest, ErrBadWebSocketVersion + } + var scheme string + if req.TLS != nil { + scheme = "wss" + } else { + scheme = "ws" + } + c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI()) + if err != nil { + return http.StatusBadRequest, err + } + protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) + if protocol != "" { + protocols := strings.Split(protocol, ",") + for i := 0; i < len(protocols); i++ { + c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) + } + } + c.accept, err = getNonceAccept([]byte(key)) + if err != nil { + return http.StatusInternalServerError, err + } + return http.StatusSwitchingProtocols, nil +} + +// Origin parses the Origin header in req. +// If the Origin header is not set, it returns nil and nil. +func Origin(config *Config, req *http.Request) (*url.URL, error) { + var origin string + switch config.Version { + case ProtocolVersionHybi13: + origin = req.Header.Get("Origin") + } + if origin == "" { + return nil, nil + } + return url.ParseRequestURI(origin) +} + +func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { + if len(c.Protocol) > 0 { + if len(c.Protocol) != 1 { + // You need choose a Protocol in Handshake func in Server. + return ErrBadWebSocketProtocol + } + } + buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + buf.WriteString("Upgrade: websocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") + if len(c.Protocol) > 0 { + buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") + } + // TODO(ukai): send Sec-WebSocket-Extensions. + if c.Header != nil { + err := c.Header.WriteSubset(buf, handshakeHeader) + if err != nil { + return err + } + } + buf.WriteString("\r\n") + return buf.Flush() +} + +func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiServerConn(c.Config, buf, rwc, request) +} + +// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol. +func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiConn(config, buf, rwc, request) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/server.go b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/server.go new file mode 100644 index 000000000000..0895dea1905a --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/server.go @@ -0,0 +1,113 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package websocket + +import ( + "bufio" + "fmt" + "io" + "net/http" +) + +func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { + var hs serverHandshaker = &hybiServerHandshaker{Config: config} + code, err := hs.ReadHandshake(buf.Reader, req) + if err == ErrBadWebSocketVersion { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) + buf.WriteString("\r\n") + buf.WriteString(err.Error()) + buf.Flush() + return + } + if err != nil { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.WriteString(err.Error()) + buf.Flush() + return + } + if handshake != nil { + err = handshake(config, req) + if err != nil { + code = http.StatusForbidden + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + } + err = hs.AcceptHandshake(buf.Writer) + if err != nil { + code = http.StatusBadRequest + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + conn = hs.NewServerConn(buf, rwc, req) + return +} + +// Server represents a server of a WebSocket. +type Server struct { + // Config is a WebSocket configuration for new WebSocket connection. + Config + + // Handshake is an optional function in WebSocket handshake. + // For example, you can check, or don't check Origin header. + // Another example, you can select config.Protocol. + Handshake func(*Config, *http.Request) error + + // Handler handles a WebSocket connection. + Handler +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.serveWebSocket(w, req) +} + +func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { + rwc, buf, err := w.(http.Hijacker).Hijack() + if err != nil { + panic("Hijack failed: " + err.Error()) + } + // The server should abort the WebSocket connection if it finds + // the client did not send a handshake that matches with protocol + // specification. + defer rwc.Close() + conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) + if err != nil { + return + } + if conn == nil { + panic("unexpected nil conn") + } + s.Handler(conn) +} + +// Handler is a simple interface to a WebSocket browser client. +// It checks if Origin header is valid URL by default. +// You might want to verify websocket.Conn.Config().Origin in the func. +// If you use Server instead of Handler, you could call websocket.Origin and +// check the origin in your Handshake func. So, if you want to accept +// non-browser clients, which do not send an Origin header, set a +// Server.Handshake that does not check the origin. 
+type Handler func(*Conn) + +func checkOrigin(config *Config, req *http.Request) (err error) { + config.Origin, err = Origin(config, req) + if err == nil && config.Origin == nil { + return fmt.Errorf("null origin") + } + return err +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s := Server{Handler: h, Handshake: checkOrigin} + s.serveWebSocket(w, req) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/stub.go b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/stub.go deleted file mode 100644 index b860854e6e8c..000000000000 --- a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/stub.go +++ /dev/null @@ -1,120 +0,0 @@ -// Code generated by depstubber. DO NOT EDIT. -// This is a simple stub for golang.org/x/net/websocket, strictly for use in testing. - -// See the LICENSE file for information about the licensing of the original library. -// Source: golang.org/x/net/websocket (exports: ; functions: Dial,NewConfig,DialConfig) - -// Package websocket is a stub of golang.org/x/net/websocket, generated by depstubber. -package websocket - -import ( - tls "crypto/tls" - io "io" - net "net" - http "net/http" - url "net/url" - time "time" -) - -type Config struct { - Location *url.URL - Origin *url.URL - Protocol []string - Version int - TlsConfig *tls.Config - Header http.Header - Dialer *net.Dialer -} - -type Conn struct { - PayloadType byte - MaxPayloadBytes int -} - -func (_ Conn) HandleFrame(_ interface{}) (interface{}, error) { - return nil, nil -} - -func (_ Conn) HeaderReader() io.Reader { - return nil -} - -func (_ Conn) Len() int { - return 0 -} - -func (_ Conn) NewFrameReader() (interface{}, error) { - return nil, nil -} - -func (_ Conn) NewFrameWriter(_ byte) (interface{}, error) { - return nil, nil -} - -func (_ Conn) TrailerReader() io.Reader { - return nil -} - -func (_ Conn) WriteClose(_ int) error { - return nil -} - -func (_ *Conn) Close() error { - return nil -} - -func (_ *Conn) Config() *Config { - return nil -} - -func (_ *Conn) IsClientConn() bool { - return false -} - -func (_ *Conn) IsServerConn() bool { - return false -} - -func (_ *Conn) LocalAddr() net.Addr { - return nil -} - -func (_ *Conn) Read(_ []byte) (int, error) { - return 0, nil -} - -func (_ *Conn) RemoteAddr() net.Addr { - return nil -} - -func (_ *Conn) Request() *http.Request { - return nil -} - -func (_ *Conn) SetDeadline(_ time.Time) error { - return nil -} - -func (_ *Conn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (_ *Conn) SetWriteDeadline(_ time.Time) error { - return nil -} - -func (_ *Conn) Write(_ []byte) (int, error) { - return 0, nil -} - -func Dial(_ string, _ string, _ string) (*Conn, error) { - return nil, nil -} - -func DialConfig(_ *Config) (*Conn, error) { - return nil, nil -} - -func NewConfig(_ string, _ string) (*Config, error) { - return nil, nil -} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/websocket.go b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/websocket.go new file mode 100644 index 000000000000..90a2257cd54e --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/golang.org/x/net/websocket/websocket.go @@ -0,0 +1,449 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// Package websocket implements a client and server for the WebSocket protocol +// as specified in RFC 6455. +// +// This package currently lacks some features found in an alternative +// and more actively maintained WebSocket package: +// +// https://pkg.go.dev/nhooyr.io/websocket +package websocket // import "golang.org/x/net/websocket" + +import ( + "bufio" + "crypto/tls" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +const ( + ProtocolVersionHybi13 = 13 + ProtocolVersionHybi = ProtocolVersionHybi13 + SupportedProtocolVersion = "13" + + ContinuationFrame = 0 + TextFrame = 1 + BinaryFrame = 2 + CloseFrame = 8 + PingFrame = 9 + PongFrame = 10 + UnknownFrame = 255 + + DefaultMaxPayloadBytes = 32 << 20 // 32MB +) + +// ProtocolError represents WebSocket protocol errors. +type ProtocolError struct { + ErrorString string +} + +func (err *ProtocolError) Error() string { return err.ErrorString } + +var ( + ErrBadProtocolVersion = &ProtocolError{"bad protocol version"} + ErrBadScheme = &ProtocolError{"bad scheme"} + ErrBadStatus = &ProtocolError{"bad status"} + ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} + ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} + ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} + ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} + ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"} + ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} + ErrBadFrame = &ProtocolError{"bad frame"} + ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"} + ErrNotWebSocket = &ProtocolError{"not websocket protocol"} + ErrBadRequestMethod = &ProtocolError{"bad method"} + ErrNotSupported = &ProtocolError{"not supported"} +) + +// ErrFrameTooLarge is returned by Codec's Receive method if payload size +// exceeds limit set by Conn.MaxPayloadBytes +var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit") + +// Addr is an implementation of net.Addr for WebSocket. +type Addr struct { + *url.URL +} + +// Network returns the network type for a WebSocket, "websocket". +func (addr *Addr) Network() string { return "websocket" } + +// Config is a WebSocket configuration +type Config struct { + // A WebSocket server address. + Location *url.URL + + // A Websocket client origin. + Origin *url.URL + + // WebSocket subprotocols. + Protocol []string + + // WebSocket protocol version. + Version int + + // TLS config for secure WebSocket (wss). + TlsConfig *tls.Config + + // Additional header fields to be sent in WebSocket opening handshake. + Header http.Header + + // Dialer used when opening websocket connections. + Dialer *net.Dialer + + handshakeData map[string]string +} + +// serverHandshaker is an interface to handle WebSocket server side handshake. +type serverHandshaker interface { + // ReadHandshake reads handshake request message from client. + // Returns http response code and error if any. + ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) + + // AcceptHandshake accepts the client handshake request and sends + // handshake response back to client. + AcceptHandshake(buf *bufio.Writer) (err error) + + // NewServerConn creates a new WebSocket connection. + NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) +} + +// frameReader is an interface to read a WebSocket frame. 
+type frameReader interface { + // Reader is to read payload of the frame. + io.Reader + + // PayloadType returns payload type. + PayloadType() byte + + // HeaderReader returns a reader to read header of the frame. + HeaderReader() io.Reader + + // TrailerReader returns a reader to read trailer of the frame. + // If it returns nil, there is no trailer in the frame. + TrailerReader() io.Reader + + // Len returns total length of the frame, including header and trailer. + Len() int +} + +// frameReaderFactory is an interface to creates new frame reader. +type frameReaderFactory interface { + NewFrameReader() (r frameReader, err error) +} + +// frameWriter is an interface to write a WebSocket frame. +type frameWriter interface { + // Writer is to write payload of the frame. + io.WriteCloser +} + +// frameWriterFactory is an interface to create new frame writer. +type frameWriterFactory interface { + NewFrameWriter(payloadType byte) (w frameWriter, err error) +} + +type frameHandler interface { + HandleFrame(frame frameReader) (r frameReader, err error) + WriteClose(status int) (err error) +} + +// Conn represents a WebSocket connection. +// +// Multiple goroutines may invoke methods on a Conn simultaneously. +type Conn struct { + config *Config + request *http.Request + + buf *bufio.ReadWriter + rwc io.ReadWriteCloser + + rio sync.Mutex + frameReaderFactory + frameReader + + wio sync.Mutex + frameWriterFactory + + frameHandler + PayloadType byte + defaultCloseStatus int + + // MaxPayloadBytes limits the size of frame payload received over Conn + // by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used. + MaxPayloadBytes int +} + +// Read implements the io.Reader interface: +// it reads data of a frame from the WebSocket connection. +// if msg is not large enough for the frame data, it fills the msg and next Read +// will read the rest of the frame data. +// it reads Text frame or Binary frame. +func (ws *Conn) Read(msg []byte) (n int, err error) { + ws.rio.Lock() + defer ws.rio.Unlock() +again: + if ws.frameReader == nil { + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return 0, err + } + ws.frameReader, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return 0, err + } + if ws.frameReader == nil { + goto again + } + } + n, err = ws.frameReader.Read(msg) + if err == io.EOF { + if trailer := ws.frameReader.TrailerReader(); trailer != nil { + io.Copy(ioutil.Discard, trailer) + } + ws.frameReader = nil + goto again + } + return n, err +} + +// Write implements the io.Writer interface: +// it writes data as a frame to the WebSocket connection. +func (ws *Conn) Write(msg []byte) (n int, err error) { + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// Close implements the io.Closer interface. +func (ws *Conn) Close() error { + err := ws.frameHandler.WriteClose(ws.defaultCloseStatus) + err1 := ws.rwc.Close() + if err != nil { + return err + } + return err1 +} + +// IsClientConn reports whether ws is a client-side connection. +func (ws *Conn) IsClientConn() bool { return ws.request == nil } + +// IsServerConn reports whether ws is a server-side connection. +func (ws *Conn) IsServerConn() bool { return ws.request != nil } + +// LocalAddr returns the WebSocket Origin for the connection for client, or +// the WebSocket location for server. 
+func (ws *Conn) LocalAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Origin} + } + return &Addr{ws.config.Location} +} + +// RemoteAddr returns the WebSocket location for the connection for client, or +// the Websocket Origin for server. +func (ws *Conn) RemoteAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Location} + } + return &Addr{ws.config.Origin} +} + +var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn") + +// SetDeadline sets the connection's network read & write deadlines. +func (ws *Conn) SetDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetDeadline(t) + } + return errSetDeadline +} + +// SetReadDeadline sets the connection's network read deadline. +func (ws *Conn) SetReadDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetReadDeadline(t) + } + return errSetDeadline +} + +// SetWriteDeadline sets the connection's network write deadline. +func (ws *Conn) SetWriteDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetWriteDeadline(t) + } + return errSetDeadline +} + +// Config returns the WebSocket config. +func (ws *Conn) Config() *Config { return ws.config } + +// Request returns the http request upgraded to the WebSocket. +// It is nil for client side. +func (ws *Conn) Request() *http.Request { return ws.request } + +// Codec represents a symmetric pair of functions that implement a codec. +type Codec struct { + Marshal func(v interface{}) (data []byte, payloadType byte, err error) + Unmarshal func(data []byte, payloadType byte, v interface{}) (err error) +} + +// Send sends v marshaled by cd.Marshal as single frame to ws. +func (cd Codec) Send(ws *Conn, v interface{}) (err error) { + data, payloadType, err := cd.Marshal(v) + if err != nil { + return err + } + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(payloadType) + if err != nil { + return err + } + _, err = w.Write(data) + w.Close() + return err +} + +// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores +// in v. The whole frame payload is read to an in-memory buffer; max size of +// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds +// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire +// completely. The next call to Receive would read and discard leftover data of +// previous oversized frame before processing next frame. 
+func (cd Codec) Receive(ws *Conn, v interface{}) (err error) { + ws.rio.Lock() + defer ws.rio.Unlock() + if ws.frameReader != nil { + _, err = io.Copy(ioutil.Discard, ws.frameReader) + if err != nil { + return err + } + ws.frameReader = nil + } +again: + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return err + } + frame, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return err + } + if frame == nil { + goto again + } + maxPayloadBytes := ws.MaxPayloadBytes + if maxPayloadBytes == 0 { + maxPayloadBytes = DefaultMaxPayloadBytes + } + if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) { + // payload size exceeds limit, no need to call Unmarshal + // + // set frameReader to current oversized frame so that + // the next call to this function can drain leftover + // data before processing the next frame + ws.frameReader = frame + return ErrFrameTooLarge + } + payloadType := frame.PayloadType() + data, err := ioutil.ReadAll(frame) + if err != nil { + return err + } + return cd.Unmarshal(data, payloadType, v) +} + +func marshal(v interface{}) (msg []byte, payloadType byte, err error) { + switch data := v.(type) { + case string: + return []byte(data), TextFrame, nil + case []byte: + return data, BinaryFrame, nil + } + return nil, UnknownFrame, ErrNotSupported +} + +func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) { + switch data := v.(type) { + case *string: + *data = string(msg) + return nil + case *[]byte: + *data = msg + return nil + } + return ErrNotSupported +} + +/* +Message is a codec to send/receive text/binary data in a frame on WebSocket connection. +To send/receive text frame, use string type. +To send/receive binary frame, use []byte type. + +Trivial usage: + + import "websocket" + + // receive text frame + var message string + websocket.Message.Receive(ws, &message) + + // send text frame + message = "hello" + websocket.Message.Send(ws, message) + + // receive binary frame + var data []byte + websocket.Message.Receive(ws, &data) + + // send binary frame + data = []byte{0, 1, 2} + websocket.Message.Send(ws, data) +*/ +var Message = Codec{marshal, unmarshal} + +func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) { + msg, err = json.Marshal(v) + return msg, TextFrame, err +} + +func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) { + return json.Unmarshal(msg, v) +} + +/* +JSON is a codec to send/receive JSON data in a frame from a WebSocket connection. 
+ +Trivial usage: + + import "websocket" + + type T struct { + Msg string + Count int + } + + // receive JSON type T + var data T + websocket.JSON.Receive(ws, &data) + + // send JSON type T + websocket.JSON.Send(ws, data) +*/ +var JSON = Codec{jsonMarshal, jsonUnmarshal} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/modules.txt b/go/ql/test/query-tests/Security/CWE-918/vendor/modules.txt index 319b30b771be..eee42f1b75b0 100644 --- a/go/ql/test/query-tests/Security/CWE-918/vendor/modules.txt +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/modules.txt @@ -1,15 +1,30 @@ +# github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee +github.com/gobwas/httphead +# github.com/gobwas/pool v0.2.0 +github.com/gobwas/pool +github.com/gobwas/pool/internal/pmath +github.com/gobwas/pool/pbufio # github.com/gobwas/ws v1.0.3 ## explicit github.com/gobwas/ws # github.com/gorilla/websocket v1.4.2 ## explicit github.com/gorilla/websocket +# github.com/klauspost/compress v1.10.3 +github.com/klauspost/compress/flate +# github.com/sacOO7/go-logger v0.0.0-20180719173527-9ac9add5a50d +## explicit +github.com/sacOO7/go-logger # github.com/sacOO7/gowebsocket v0.0.0-20180719182212-1436bb906a4e ## explicit github.com/sacOO7/gowebsocket -# golang.org/x/net v0.0.0-20200421231249-e086a090c8fd +# golang.org/x/net v0.7.0 ## explicit -golang.org/x/net +golang.org/x/net/websocket # nhooyr.io/websocket v1.8.5 ## explicit nhooyr.io/websocket +nhooyr.io/websocket/internal/bpool +nhooyr.io/websocket/internal/errd +nhooyr.io/websocket/internal/wsjs +nhooyr.io/websocket/internal/xsync diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/.gitignore b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/.gitignore new file mode 100644 index 000000000000..6961e5c894a8 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/.gitignore @@ -0,0 +1 @@ +websocket.test diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/.travis.yml b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/.travis.yml new file mode 100644 index 000000000000..41d3c201468c --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/.travis.yml @@ -0,0 +1,40 @@ +language: go +go: 1.x +dist: bionic + +env: + global: + - SHFMT_URL=https://github.com/mvdan/sh/releases/download/v3.0.1/shfmt_v3.0.1_linux_amd64 + - GOFLAGS="-mod=readonly" + +jobs: + include: + - name: Format + before_script: + - sudo apt-get install -y npm + - sudo npm install -g prettier + - sudo curl -L "$SHFMT_URL" > /usr/local/bin/shfmt && sudo chmod +x /usr/local/bin/shfmt + - go get golang.org/x/tools/cmd/stringer + - go get golang.org/x/tools/cmd/goimports + script: make -j16 fmt + - name: Lint + before_script: + - sudo apt-get install -y shellcheck + - go get golang.org/x/lint/golint + script: make -j16 lint + - name: Test + before_script: + - sudo apt-get install -y chromium-browser + - go get github.com/agnivade/wasmbrowsertest + - go get github.com/mattn/goveralls + script: make -j16 test + +addons: + apt: + update: true + +cache: + npm: true + directories: + - ~/.cache + - ~/gopath/pkg diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/LICENSE b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/LICENSE.txt similarity index 100% rename from go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/LICENSE rename to 
go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/LICENSE.txt diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/Makefile b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/Makefile new file mode 100644 index 000000000000..f9f31c49f1c4 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/Makefile @@ -0,0 +1,7 @@ +all: fmt lint test + +.SILENT: + +include ci/fmt.mk +include ci/lint.mk +include ci/test.mk diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/README.md b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/README.md new file mode 100644 index 000000000000..14c392935e11 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/README.md @@ -0,0 +1,132 @@ +# websocket + +[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) + +websocket is a minimal and idiomatic WebSocket library for Go. + +## Install + +```bash +go get nhooyr.io/websocket +``` + +## Features + +- Minimal and idiomatic API +- First class [context.Context](https://blog.golang.org/context) support +- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) +- Thorough tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) +- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- Zero alloc reads and writes +- Concurrent writes +- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper +- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API +- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression +- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) + +## Roadmap + +- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) + +## Examples + +For a production quality example that demonstrates the complete API, see the +[echo example](./examples/echo). + +For a full stack example, see the [chat example](./examples/chat). + +### Server + +```go +http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + // ... + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + defer cancel() + + var v interface{} + err = wsjson.Read(ctx, c, &v) + if err != nil { + // ... + } + + log.Printf("received: %v", v) + + c.Close(websocket.StatusNormalClosure, "") +}) +``` + +### Client + +```go +ctx, cancel := context.WithTimeout(context.Background(), time.Minute) +defer cancel() + +c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) +if err != nil { + // ... +} +defer c.Close(websocket.StatusInternalError, "the sky is falling") + +err = wsjson.Write(ctx, c, "hi") +if err != nil { + // ... 
+}
+
+c.Close(websocket.StatusNormalClosure, "")
+```
+
+## Comparison
+
+### gorilla/websocket
+
+Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
+
+- Mature and widely used
+- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
+- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
+
+Advantages of nhooyr.io/websocket:
+
+- Minimal and idiomatic API
+  - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
+- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper
+- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
+- Full [context.Context](https://blog.golang.org/context) support
+- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
+  - Will enable easy HTTP/2 support in the future
+  - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
+- Concurrent writes
+- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
+- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API
+  - Gorilla requires registering a pong callback before sending a Ping
+- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
+- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages
+- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
+  - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
+- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
+  - Gorilla only supports no context takeover mode
+  - We use a vendored [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203))
+- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
+- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
+
+#### golang.org/x/net/websocket
+
+[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
+See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
+
+The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper can help in transitioning
+to nhooyr.io/websocket.
+
+#### gobwas/ws
+
+[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
+in an event-driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
+
+However, when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use.
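As a concrete illustration of the transition path the README mentions above, here is a minimal sketch (editorial, not part of the vendored patch: the URL, message contents, and echoing peer are assumptions) of dialing with nhooyr.io/websocket and then treating the connection as a net.Conn, the shape of API that golang.org/x/net/websocket code expects:

```go
package main

import (
	"context"
	"log"
	"time"

	"nhooyr.io/websocket"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	// Dial the server; "ws://localhost:8080" is a placeholder address.
	c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close(websocket.StatusInternalError, "the sky is falling")

	// NetConn adapts the WebSocket connection to a net.Conn whose
	// Read/Write calls carry binary messages, so code written against
	// golang.org/x/net/websocket's net.Conn-style API can be ported
	// with minimal changes.
	nc := websocket.NetConn(ctx, c, websocket.MessageBinary)

	if _, err := nc.Write([]byte("hello")); err != nil {
		log.Fatal(err)
	}

	// Read blocks until the peer sends a message; this sketch assumes
	// an echo server on the other end.
	buf := make([]byte, 32)
	n, err := nc.Read(buf)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("read %d bytes: %q", n, buf[:n])

	c.Close(websocket.StatusNormalClosure, "")
}
```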
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/accept.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/accept.go
new file mode 100644
index 000000000000..6bed54da0286
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/accept.go
@@ -0,0 +1,365 @@
+// +build !js
+
+package websocket
+
+import (
+	"bytes"
+	"crypto/sha1"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"net/textproto"
+	"net/url"
+	"path/filepath"
+	"strings"
+
+	"nhooyr.io/websocket/internal/errd"
+)
+
+// AcceptOptions represents Accept's options.
+type AcceptOptions struct {
+	// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
+	// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
+	// reject it, close the connection when c.Subprotocol() == "".
+	Subprotocols []string
+
+	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
+	//
+	// Deprecated: Use OriginPatterns with a match all pattern of * instead to control
+	// origin authorization yourself.
+	InsecureSkipVerify bool
+
+	// OriginPatterns lists the host patterns for authorized origins.
+	// The request host is always authorized.
+	// Use this to enable cross origin WebSockets.
+	//
+	// i.e. javascript running on example.com wants to access a WebSocket server at chat.example.com.
+	// In such a case, example.com is the origin and chat.example.com is the request host.
+	// One would set this field to []string{"example.com"} to authorize example.com to connect.
+	//
+	// Each pattern is matched case insensitively against the request origin host
+	// with filepath.Match.
+	// See https://golang.org/pkg/path/filepath/#Match
+	//
+	// Please ensure you understand the ramifications of enabling this.
+	// If used incorrectly your WebSocket server will be open to CSRF attacks.
+	OriginPatterns []string
+
+	// CompressionMode controls the compression mode.
+	// Defaults to CompressionNoContextTakeover.
+	//
+	// See docs on CompressionMode for details.
+	CompressionMode CompressionMode
+
+	// CompressionThreshold controls the minimum size of a message before compression is applied.
+	//
+	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
+	// for CompressionContextTakeover.
+	CompressionThreshold int
+}
+
+// Accept accepts a WebSocket handshake from a client and upgrades the
+// connection to a WebSocket.
+//
+// Accept will not allow cross origin requests by default.
+// See the InsecureSkipVerify option to allow cross origin requests.
+//
+// Accept will write a response to w on all errors.
+func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return accept(w, r, opts) +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { + defer errd.Wrap(&err, "failed to accept WebSocket connection") + + if opts == nil { + opts = &AcceptOptions{} + } + opts = &*opts + + errCode, err := verifyClientRequest(w, r) + if err != nil { + http.Error(w, err.Error(), errCode) + return nil, err + } + + if !opts.InsecureSkipVerify { + err = authenticateOrigin(r, opts.OriginPatterns) + if err != nil { + if errors.Is(err, filepath.ErrBadPattern) { + log.Printf("websocket: %v", err) + err = errors.New(http.StatusText(http.StatusForbidden)) + } + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } + } + + hj, ok := w.(http.Hijacker) + if !ok { + err = errors.New("http.ResponseWriter does not implement http.Hijacker") + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return nil, err + } + + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + + key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + + subproto := selectSubprotocol(r, opts.Subprotocols) + if subproto != "" { + w.Header().Set("Sec-WebSocket-Protocol", subproto) + } + + copts, err := acceptCompression(r, w, opts.CompressionMode) + if err != nil { + return nil, err + } + + w.WriteHeader(http.StatusSwitchingProtocols) + + netConn, brw, err := hj.Hijack() + if err != nil { + err = fmt.Errorf("failed to hijack connection: %w", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return nil, err + } + + // https://github.com/golang/go/issues/32314 + b, _ := brw.Reader.Peek(brw.Reader.Buffered()) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) + + return newConn(connConfig{ + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, + flateThreshold: opts.CompressionThreshold, + + br: brw.Reader, + bw: brw.Writer, + }), nil +} + +func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { + if !r.ProtoAtLeast(1, 1) { + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + } + + if !headerContainsToken(r.Header, "Connection", "Upgrade") { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + } + + if !headerContainsToken(r.Header, "Upgrade", "websocket") { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + } + + if r.Method != "GET" { + return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + } + + if r.Header.Get("Sec-WebSocket-Version") != "13" { + w.Header().Set("Sec-WebSocket-Version", "13") + return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + } + + if r.Header.Get("Sec-WebSocket-Key") == "" { + return http.StatusBadRequest, errors.New("WebSocket protocol 
violation: missing Sec-WebSocket-Key")
+	}
+
+	return 0, nil
+}
+
+func authenticateOrigin(r *http.Request, originHosts []string) error {
+	origin := r.Header.Get("Origin")
+	if origin == "" {
+		return nil
+	}
+
+	u, err := url.Parse(origin)
+	if err != nil {
+		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
+	}
+
+	if strings.EqualFold(r.Host, u.Host) {
+		return nil
+	}
+
+	for _, hostPattern := range originHosts {
+		matched, err := match(hostPattern, u.Host)
+		if err != nil {
+			return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
+		}
+		if matched {
+			return nil
+		}
+	}
+	return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
+}
+
+func match(pattern, s string) (bool, error) {
+	return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
+}
+
+func selectSubprotocol(r *http.Request, subprotocols []string) string {
+	cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
+	for _, sp := range subprotocols {
+		for _, cp := range cps {
+			if strings.EqualFold(sp, cp) {
+				return cp
+			}
+		}
+	}
+	return ""
+}
+
+func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
+	if mode == CompressionDisabled {
+		return nil, nil
+	}
+
+	for _, ext := range websocketExtensions(r.Header) {
+		switch ext.name {
+		case "permessage-deflate":
+			return acceptDeflate(w, ext, mode)
+		// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
+		// case "x-webkit-deflate-frame":
+		// 	return acceptWebkitDeflate(w, ext, mode)
+		}
+	}
+	return nil, nil
+}
+
+func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
+	copts := mode.opts()
+
+	for _, p := range ext.params {
+		switch p {
+		case "client_no_context_takeover":
+			copts.clientNoContextTakeover = true
+			continue
+		case "server_no_context_takeover":
+			copts.serverNoContextTakeover = true
+			continue
+		}
+
+		if strings.HasPrefix(p, "client_max_window_bits") {
+			// We cannot adjust the read sliding window so cannot make use of this.
+			continue
+		}
+
+		err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return nil, err
+	}
+
+	copts.setHeader(w.Header())
+
+	return copts, nil
+}
+
+func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
+	copts := mode.opts()
+	// The peer must explicitly request it.
+	copts.serverNoContextTakeover = false
+
+	for _, p := range ext.params {
+		if p == "no_context_takeover" {
+			copts.serverNoContextTakeover = true
+			continue
+		}
+
+		// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
+		// of ignoring it as the draft spec is unclear. It says the server can ignore it
+		// but the server has no way of signalling to the client it was ignored as the parameters
+		// are set one way.
+		// Thus, ignoring it would make the client think we understood it, which would cause issues.
+		// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
+		//
+		// Either way, we're only implementing this for webkit which never sends the max_window_bits
+		// parameter so we don't need to worry about it.
+ err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + s := "x-webkit-deflate-frame" + if copts.clientNoContextTakeover { + s += "; no_context_takeover" + } + w.Header().Set("Sec-WebSocket-Extensions", s) + + return copts, nil +} + +func headerContainsToken(h http.Header, key, token string) bool { + token = strings.ToLower(token) + + for _, t := range headerTokens(h, key) { + if t == token { + return true + } + } + return false +} + +type websocketExtension struct { + name string + params []string +} + +func websocketExtensions(h http.Header) []websocketExtension { + var exts []websocketExtension + extStrs := headerTokens(h, "Sec-WebSocket-Extensions") + for _, extStr := range extStrs { + if extStr == "" { + continue + } + + vals := strings.Split(extStr, ";") + for i := range vals { + vals[i] = strings.TrimSpace(vals[i]) + } + + e := websocketExtension{ + name: vals[0], + params: vals[1:], + } + + exts = append(exts, e) + } + return exts +} + +func headerTokens(h http.Header, key string) []string { + key = textproto.CanonicalMIMEHeaderKey(key) + var tokens []string + for _, v := range h[key] { + v = strings.TrimSpace(v) + for _, t := range strings.Split(v, ",") { + t = strings.ToLower(t) + t = strings.TrimSpace(t) + tokens = append(tokens, t) + } + } + return tokens +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func secWebSocketAccept(secWebSocketKey string) string { + h := sha1.New() + h.Write([]byte(secWebSocketKey)) + h.Write(keyGUID) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/accept_js.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/accept_js.go new file mode 100644 index 000000000000..daad4b79fec6 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/accept_js.go @@ -0,0 +1,20 @@ +package websocket + +import ( + "errors" + "net/http" +) + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + OriginPatterns []string + CompressionMode CompressionMode + CompressionThreshold int +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, errors.New("unimplemented") +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/close.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/close.go new file mode 100644 index 000000000000..7cbc19e9def6 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/close.go @@ -0,0 +1,76 @@ +package websocket + +import ( + "errors" + "fmt" +) + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so unexported. 
+ statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // a status code. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's errors.As to check for this error. +// Also see the CloseStatus helper. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/close_notjs.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/close_notjs.go new file mode 100644 index 000000000000..4251311d2e69 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/close_notjs.go @@ -0,0 +1,211 @@ +// +build !js + +package websocket + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "log" + "time" + + "nhooyr.io/websocket/internal/errd" +) + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. 
+func (c *Conn) Close(code StatusCode, reason string) error { + return c.closeHandshake(code, reason) +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + writeErr := c.writeClose(code, reason) + closeHandshakeErr := c.waitCloseHandshake() + + if writeErr != nil { + return writeErr + } + + if CloseStatus(closeHandshakeErr) == -1 { + return closeHandshakeErr + } + + return nil +} + +var errAlreadyWroteClose = errors.New("already wrote close") + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + wroteClose := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if wroteClose { + return errAlreadyWroteClose + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + var p []byte + var marshalErr error + if ce.Code != StatusNoStatusRcvd { + p, marshalErr = ce.bytes() + if marshalErr != nil { + log.Printf("websocket: %v", marshalErr) + } + } + + writeErr := c.writeControl(context.Background(), opClose, p) + if CloseStatus(writeErr) != -1 { + // Not a real error if it's due to a close frame being received. + writeErr = nil + } + + // We do this after in case there was an error writing the close frame. + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + if marshalErr != nil { + return marshalErr + } + return writeErr +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.readMu.lock(ctx) + if err != nil { + return err + } + defer c.readMu.unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + p, err := ce.bytesErr() + if err != nil { + err = fmt.Errorf("failed to marshal close frame: %w", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p, err +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + 
binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil { + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/compress.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/compress.go new file mode 100644 index 000000000000..80b46d1c1d39 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/compress.go @@ -0,0 +1,39 @@ +package websocket + +// CompressionMode represents the modes available to the deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// +// A compatibility layer is implemented for the older deflate-frame extension used +// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 +// It will work the same in every way except that we cannot signal to the peer we +// want to use no context takeover on our side, we can only signal that they should. +// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +type CompressionMode int + +const ( + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. 
+	CompressionDisabled
+)
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/compress_notjs.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/compress_notjs.go
new file mode 100644
index 000000000000..809a272c3d1e
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/compress_notjs.go
@@ -0,0 +1,181 @@
+// +build !js
+
+package websocket
+
+import (
+	"io"
+	"net/http"
+	"sync"
+
+	"github.com/klauspost/compress/flate"
+)
+
+func (m CompressionMode) opts() *compressionOptions {
+	return &compressionOptions{
+		clientNoContextTakeover: m == CompressionNoContextTakeover,
+		serverNoContextTakeover: m == CompressionNoContextTakeover,
+	}
+}
+
+type compressionOptions struct {
+	clientNoContextTakeover bool
+	serverNoContextTakeover bool
+}
+
+func (copts *compressionOptions) setHeader(h http.Header) {
+	s := "permessage-deflate"
+	if copts.clientNoContextTakeover {
+		s += "; client_no_context_takeover"
+	}
+	if copts.serverNoContextTakeover {
+		s += "; server_no_context_takeover"
+	}
+	h.Set("Sec-WebSocket-Extensions", s)
+}
+
+// These bytes are required to get flate.Reader to return.
+// They are removed when sending to avoid the overhead as
+// WebSocket framing tells when the message has ended but then
+// we need to add them back otherwise flate.Reader keeps
+// trying to return more bytes.
+const deflateMessageTail = "\x00\x00\xff\xff"
+
+type trimLastFourBytesWriter struct {
+	w    io.Writer
+	tail []byte
+}
+
+func (tw *trimLastFourBytesWriter) reset() {
+	if tw != nil && tw.tail != nil {
+		tw.tail = tw.tail[:0]
+	}
+}
+
+func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
+	if tw.tail == nil {
+		tw.tail = make([]byte, 0, 4)
+	}
+
+	extra := len(tw.tail) + len(p) - 4
+
+	if extra <= 0 {
+		tw.tail = append(tw.tail, p...)
+		return len(p), nil
+	}
+
+	// Now we need to write as many extra bytes as we can from the previous tail.
+	if extra > len(tw.tail) {
+		extra = len(tw.tail)
+	}
+	if extra > 0 {
+		_, err := tw.w.Write(tw.tail[:extra])
+		if err != nil {
+			return 0, err
+		}
+
+		// Shift remaining bytes in tail over.
+		n := copy(tw.tail, tw.tail[extra:])
+		tw.tail = tw.tail[:n]
+	}
+
+	// If p is less than or equal to 4 bytes,
+	// all of it is part of the tail.
+	if len(p) <= 4 {
+		tw.tail = append(tw.tail, p...)
+		return len(p), nil
+	}
+
+	// Otherwise, only the last 4 bytes are.
+	tw.tail = append(tw.tail, p[len(p)-4:]...)
+ + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +type slidingWindow struct { + buf []byte +} + +var swPoolMu sync.RWMutex +var swPool = map[int]*sync.Pool{} + +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p + } + + p = &sync.Pool{} + + swPoolMu.Lock() + swPool[n] = p + swPoolMu.Unlock() + + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + + if n == 0 { + n = 32768 + } + + p := slidingWindowPool(n) + buf, ok := p.Get().([]byte) + if ok { + sw.buf = buf[:0] + } else { + sw.buf = make([]byte, 0, n) + } +} + +func (sw *slidingWindow) close() { + if sw.buf == nil { + return + } + + swPoolMu.Lock() + swPool[cap(sw.buf)].Put(sw.buf) + swPoolMu.Unlock() + sw.buf = nil +} + +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) + return + } + + left := cap(sw.buf) - len(sw.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] + } + + sw.buf = append(sw.buf, p...) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/conn.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/conn.go new file mode 100644 index 000000000000..a41808be3fad --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/conn.go @@ -0,0 +1,13 @@ +package websocket + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary +) diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/conn_notjs.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/conn_notjs.go new file mode 100644 index 000000000000..bb2eb22f7dba --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/conn_notjs.go @@ -0,0 +1,265 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "runtime" + "strconv" + "sync" + "sync/atomic" +) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +type Conn struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. 
+ readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriterState *msgWriterState + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriterState = newMsgWriterState(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriterState.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(errors.New("connection garbage collected")) + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + c.setCloseErrLocked(err) + close(c.closed) + runtime.SetFinalizer(c, nil) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + go func() { + c.msgWriterState.close() + + c.msgReader.close() + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, errors.New("timed out")) + case <-writeCtx.Done(): + c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. 
+func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return fmt.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return m.c.closeErr + default: + } + return nil + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/dial.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/dial.go new file mode 100644 index 000000000000..2b25e3517d66 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/dial.go @@ -0,0 +1,287 @@ +// +build !js + +package websocket + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "nhooyr.io/websocket/internal/errd" +) + +// DialOptions represents Dial's options. +type DialOptions struct { + // HTTPClient is used for the connection. + // Its Transport must return writable bodies for WebSocket handshakes. + // http.Transport does beginning with Go 1.12. + HTTPClient *http.Client + + // HTTPHeader specifies the HTTP headers included in the handshake request. + HTTPHeader http.Header + + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. + Subprotocols []string + + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int +} + +// Dial performs a WebSocket handshake on url. +// +// The response is the WebSocket handshake response from the server. +// You never need to close resp.Body yourself. +// +// If an error occurs, the returned response may be non nil. +// However, you can only read the first 1024 bytes of the body. +// +// This function requires at least Go 1.12 as it uses a new feature +// in net/http to perform WebSocket handshakes. 
+// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 +// +// URLs with http/https schemes will work and are interpreted as ws/wss. +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + return dial(ctx, u, opts, nil) +} + +func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { + defer errd.Wrap(&err, "failed to WebSocket dial") + + if opts == nil { + opts = &DialOptions{} + } + + opts = &*opts + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + + secWebSocketKey, err := secWebSocketKey(rand) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + + var copts *compressionOptions + if opts.CompressionMode != CompressionDisabled { + copts = opts.CompressionMode.opts() + } + + resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) + if err != nil { + return nil, resp, err + } + respBody := resp.Body + resp.Body = nil + defer func() { + if err != nil { + // We read a bit of the body for easier debugging. + r := io.LimitReader(respBody, 1024) + + timer := time.AfterFunc(time.Second*3, func() { + respBody.Close() + }) + defer timer.Stop() + + b, _ := ioutil.ReadAll(r) + respBody.Close() + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + }() + + copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) + if err != nil { + return nil, resp, err + } + + rwc, ok := respBody.(io.ReadWriteCloser) + if !ok { + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) + } + + return newConn(connConfig{ + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + rwc: rwc, + client: true, + copts: copts, + flateThreshold: opts.CompressionThreshold, + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), + }), resp, nil +} + +func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { + if opts.HTTPClient.Timeout > 0 { + return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + } + + u, err := url.Parse(urls) + if err != nil { + return nil, fmt.Errorf("failed to parse url: %w", err) + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + case "http", "https": + default: + return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) + } + + req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req.Header = opts.HTTPHeader.Clone() + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + if copts != nil { + copts.setHeader(req.Header) + } + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send handshake request: %w", err) + } + return resp, nil +} + +func secWebSocketKey(rr io.Reader) (string, error) { + if rr == nil { + rr = rand.Reader + } + b := make([]byte, 16) + _, err := io.ReadFull(rr, b) + if err != nil { + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil 
+}
+
+func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
+	if resp.StatusCode != http.StatusSwitchingProtocols {
+		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
+	}
+
+	if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
+		return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
+	}
+
+	if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
+		return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
+	}
+
+	if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
+		return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
+			resp.Header.Get("Sec-WebSocket-Accept"),
+			secWebSocketKey,
+		)
+	}
+
+	err := verifySubprotocol(opts.Subprotocols, resp)
+	if err != nil {
+		return nil, err
+	}
+
+	return verifyServerExtensions(copts, resp.Header)
+}
+
+func verifySubprotocol(subprotos []string, resp *http.Response) error {
+	proto := resp.Header.Get("Sec-WebSocket-Protocol")
+	if proto == "" {
+		return nil
+	}
+
+	for _, sp2 := range subprotos {
+		if strings.EqualFold(sp2, proto) {
+			return nil
+		}
+	}
+
+	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
+}
+
+func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
+	exts := websocketExtensions(h)
+	if len(exts) == 0 {
+		return nil, nil
+	}
+
+	ext := exts[0]
+	if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
+		return nil, fmt.Errorf("WebSocket protocol violation: unsupported extensions from server: %+v", exts[1:])
+	}
+
+	copts = &*copts
+
+	for _, p := range ext.params {
+		switch p {
+		case "client_no_context_takeover":
+			copts.clientNoContextTakeover = true
+			continue
+		case "server_no_context_takeover":
+			copts.serverNoContextTakeover = true
+			continue
+		}
+
+		return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
+	}
+
+	return copts, nil
+}
+
+var bufioReaderPool sync.Pool
+
+func getBufioReader(r io.Reader) *bufio.Reader {
+	br, ok := bufioReaderPool.Get().(*bufio.Reader)
+	if !ok {
+		return bufio.NewReader(r)
+	}
+	br.Reset(r)
+	return br
+}
+
+func putBufioReader(br *bufio.Reader) {
+	bufioReaderPool.Put(br)
+}
+
+var bufioWriterPool sync.Pool
+
+func getBufioWriter(w io.Writer) *bufio.Writer {
+	bw, ok := bufioWriterPool.Get().(*bufio.Writer)
+	if !ok {
+		return bufio.NewWriter(w)
+	}
+	bw.Reset(w)
+	return bw
+}
+
+func putBufioWriter(bw *bufio.Writer) {
+	bufioWriterPool.Put(bw)
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/doc.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/doc.go
new file mode 100644
index 000000000000..efa920e3b61e
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/doc.go
@@ -0,0 +1,32 @@
+// +build !js
+
+// Package websocket implements the RFC 6455 WebSocket protocol.
+//
+// https://tools.ietf.org/html/rfc6455
+//
+// Use Dial to dial a WebSocket server.
+//
+// Use Accept to accept a WebSocket client.
+//
+// Conn represents the resulting WebSocket connection.
+//
+// The examples are the best way to understand how to correctly use the library.
+// +// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. +// +// More documentation at https://nhooyr.io/websocket. +// +// Wasm +// +// The client side supports compiling to Wasm. +// It wraps the WebSocket browser API. +// +// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket +// +// Some important caveats to be aware of: +// +// - Accept always errors out +// - Conn.Ping is no-op +// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op +// - *http.Response from Dial is &http.Response{} with a 101 status code on success +package websocket // import "nhooyr.io/websocket" diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/frame.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/frame.go new file mode 100644 index 000000000000..2a036f944ac9 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/frame.go @@ -0,0 +1,294 @@ +package websocket + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" + "math/bits" + + "nhooyr.io/websocket/internal/errd" +) + +// opcode represents a WebSocket opcode. +type opcode int + +// https://tools.ietf.org/html/rfc6455#section-11.8. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +type header struct { + fin bool + rsv1 bool + rsv2 bool + rsv3 bool + opcode opcode + + payloadLength int64 + + masked bool + maskKey uint32 +} + +// readFrameHeader reads a header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { + defer errd.Wrap(&err, "failed to read frame header") + + b, err := r.ReadByte() + if err != nil { + return header{}, err + } + + h.fin = b&(1<<7) != 0 + h.rsv1 = b&(1<<6) != 0 + h.rsv2 = b&(1<<5) != 0 + h.rsv3 = b&(1<<4) != 0 + + h.opcode = opcode(b & 0xf) + + b, err = r.ReadByte() + if err != nil { + return header{}, err + } + + h.masked = b&(1<<7) != 0 + + payloadLength := b &^ (1 << 7) + switch { + case payloadLength < 126: + h.payloadLength = int64(payloadLength) + case payloadLength == 126: + _, err = io.ReadFull(r, readBuf[:2]) + h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) + case payloadLength == 127: + _, err = io.ReadFull(r, readBuf) + h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) + } + if err != nil { + return header{}, err + } + + if h.payloadLength < 0 { + return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) + } + + if h.masked { + _, err = io.ReadFull(r, readBuf[:4]) + if err != nil { + return header{}, err + } + h.maskKey = binary.LittleEndian.Uint32(readBuf) + } + + return h, nil +} + +// maxControlPayload is the maximum length of a control frame payload. +// See https://tools.ietf.org/html/rfc6455#section-5.5. +const maxControlPayload = 125 + +// writeFrameHeader writes the bytes of the header to w. 
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
+	defer errd.Wrap(&err, "failed to write frame header")
+
+	var b byte
+	if h.fin {
+		b |= 1 << 7
+	}
+	if h.rsv1 {
+		b |= 1 << 6
+	}
+	if h.rsv2 {
+		b |= 1 << 5
+	}
+	if h.rsv3 {
+		b |= 1 << 4
+	}
+
+	b |= byte(h.opcode)
+
+	err = w.WriteByte(b)
+	if err != nil {
+		return err
+	}
+
+	lengthByte := byte(0)
+	if h.masked {
+		lengthByte |= 1 << 7
+	}
+
+	switch {
+	case h.payloadLength > math.MaxUint16:
+		lengthByte |= 127
+	case h.payloadLength > 125:
+		lengthByte |= 126
+	case h.payloadLength >= 0:
+		lengthByte |= byte(h.payloadLength)
+	}
+	err = w.WriteByte(lengthByte)
+	if err != nil {
+		return err
+	}
+
+	switch {
+	case h.payloadLength > math.MaxUint16:
+		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
+		_, err = w.Write(buf)
+	case h.payloadLength > 125:
+		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
+		_, err = w.Write(buf[:2])
+	}
+	if err != nil {
+		return err
+	}
+
+	if h.masked {
+		binary.LittleEndian.PutUint32(buf, h.maskKey)
+		_, err = w.Write(buf[:4])
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// mask applies the WebSocket masking algorithm to p
+// with the given key.
+// See https://tools.ietf.org/html/rfc6455#section-5.3
+//
+// The returned value is the correctly rotated key to
+// continue to mask/unmask the message.
+//
+// It is optimized for LittleEndian and expects the key
+// to be in little endian.
+//
+// See https://github.com/golang/go/issues/31586
+func mask(key uint32, b []byte) uint32 {
+	if len(b) >= 8 {
+		key64 := uint64(key)<<32 | uint64(key)
+
+		// At some point in the future we can clean these unrolled loops up.
+		// See https://github.com/golang/go/issues/31586#issuecomment-487436401
+
+		// Then we xor until b is less than 128 bytes.
+		for len(b) >= 128 {
+			v := binary.LittleEndian.Uint64(b)
+			binary.LittleEndian.PutUint64(b, v^key64)
+			v = binary.LittleEndian.Uint64(b[8:16])
+			binary.LittleEndian.PutUint64(b[8:16], v^key64)
+			v = binary.LittleEndian.Uint64(b[16:24])
+			binary.LittleEndian.PutUint64(b[16:24], v^key64)
+			v = binary.LittleEndian.Uint64(b[24:32])
+			binary.LittleEndian.PutUint64(b[24:32], v^key64)
+			v = binary.LittleEndian.Uint64(b[32:40])
+			binary.LittleEndian.PutUint64(b[32:40], v^key64)
+			v = binary.LittleEndian.Uint64(b[40:48])
+			binary.LittleEndian.PutUint64(b[40:48], v^key64)
+			v = binary.LittleEndian.Uint64(b[48:56])
+			binary.LittleEndian.PutUint64(b[48:56], v^key64)
+			v = binary.LittleEndian.Uint64(b[56:64])
+			binary.LittleEndian.PutUint64(b[56:64], v^key64)
+			v = binary.LittleEndian.Uint64(b[64:72])
+			binary.LittleEndian.PutUint64(b[64:72], v^key64)
+			v = binary.LittleEndian.Uint64(b[72:80])
+			binary.LittleEndian.PutUint64(b[72:80], v^key64)
+			v = binary.LittleEndian.Uint64(b[80:88])
+			binary.LittleEndian.PutUint64(b[80:88], v^key64)
+			v = binary.LittleEndian.Uint64(b[88:96])
+			binary.LittleEndian.PutUint64(b[88:96], v^key64)
+			v = binary.LittleEndian.Uint64(b[96:104])
+			binary.LittleEndian.PutUint64(b[96:104], v^key64)
+			v = binary.LittleEndian.Uint64(b[104:112])
+			binary.LittleEndian.PutUint64(b[104:112], v^key64)
+			v = binary.LittleEndian.Uint64(b[112:120])
+			binary.LittleEndian.PutUint64(b[112:120], v^key64)
+			v = binary.LittleEndian.Uint64(b[120:128])
+			binary.LittleEndian.PutUint64(b[120:128], v^key64)
+			b = b[128:]
+		}
+
+		// Then we xor until b is less than 64 bytes.
+ for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/go.mod b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/go.mod new file mode 100644 index 000000000000..60377823cba0 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/go.mod @@ -0,0 +1,14 @@ +module nhooyr.io/websocket + +go 1.13 + +require ( + github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect + github.com/gobwas/pool v0.2.0 // indirect + github.com/gobwas/ws v1.0.2 + github.com/golang/protobuf v1.3.5 + github.com/google/go-cmp v0.4.0 + github.com/gorilla/websocket v1.4.1 + github.com/klauspost/compress v1.10.3 + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 +) diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/bpool/bpool.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/bpool/bpool.go new file mode 100644 index 000000000000..aa826fba2b1c --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/bpool/bpool.go @@ -0,0 +1,24 @@ +package bpool + +import ( + "bytes" + "sync" +) + +var bpool sync.Pool + +// Get returns a buffer from the pool or creates a new one if +// the pool is empty. +func Get() *bytes.Buffer { + b := bpool.Get() + if b == nil { + return &bytes.Buffer{} + } + return b.(*bytes.Buffer) +} + +// Put returns a buffer into the pool. 
+func Put(b *bytes.Buffer) { + b.Reset() + bpool.Put(b) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/errd/wrap.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/errd/wrap.go new file mode 100644 index 000000000000..6e779131af8b --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/errd/wrap.go @@ -0,0 +1,14 @@ +package errd + +import ( + "fmt" +) + +// Wrap wraps err with fmt.Errorf if err is non nil. +// Intended for use with defer and a named error return. +// Inspired by https://github.com/golang/go/issues/32676. +func Wrap(err *error, f string, v ...interface{}) { + if *err != nil { + *err = fmt.Errorf(f+": %w", append(v, *err)...) + } +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/wsjs/wsjs_js.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/wsjs/wsjs_js.go new file mode 100644 index 000000000000..26ffb45625b3 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/wsjs/wsjs_js.go @@ -0,0 +1,170 @@ +// +build js + +// Package wsjs implements typed access to the browser javascript WebSocket API. +// +// https://developer.mozilla.org/en-US/docs/Web/API/WebSocket +package wsjs + +import ( + "syscall/js" +) + +func handleJSError(err *error, onErr func()) { + r := recover() + + if jsErr, ok := r.(js.Error); ok { + *err = jsErr + + if onErr != nil { + onErr() + } + return + } + + if r != nil { + panic(r) + } +} + +// New is a wrapper around the javascript WebSocket constructor. +func New(url string, protocols []string) (c WebSocket, err error) { + defer handleJSError(&err, func() { + c = WebSocket{} + }) + + jsProtocols := make([]interface{}, len(protocols)) + for i, p := range protocols { + jsProtocols[i] = p + } + + c = WebSocket{ + v: js.Global().Get("WebSocket").New(url, jsProtocols), + } + + c.setBinaryType("arraybuffer") + + return c, nil +} + +// WebSocket is a wrapper around a javascript WebSocket object. +type WebSocket struct { + v js.Value +} + +func (c WebSocket) setBinaryType(typ string) { + c.v.Set("binaryType", string(typ)) +} + +func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() { + f := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + fn(args[0]) + return nil + }) + c.v.Call("addEventListener", eventType, f) + + return func() { + c.v.Call("removeEventListener", eventType, f) + f.Release() + } +} + +// CloseEvent is the type passed to a WebSocket close handler. +type CloseEvent struct { + Code uint16 + Reason string + WasClean bool +} + +// OnClose registers a function to be called when the WebSocket is closed. +func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) { + return c.addEventListener("close", func(e js.Value) { + ce := CloseEvent{ + Code: uint16(e.Get("code").Int()), + Reason: e.Get("reason").String(), + WasClean: e.Get("wasClean").Bool(), + } + fn(ce) + }) +} + +// OnError registers a function to be called when there is an error +// with the WebSocket. +func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) { + return c.addEventListener("error", fn) +} + +// MessageEvent is the type passed to a message handler. +type MessageEvent struct { + // string or []byte. + Data interface{} + + // There are more fields to the interface but we don't use them. 
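+ // In this wrapper, Data holds a string for text messages and
+ // a []byte (copied out of the ArrayBuffer) for binary messages.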
+ // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent +} + +// OnMessage registers a function to be called when the WebSocket receives a message. +func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { + return c.addEventListener("message", func(e js.Value) { + var data interface{} + + arrayBuffer := e.Get("data") + if arrayBuffer.Type() == js.TypeString { + data = arrayBuffer.String() + } else { + data = extractArrayBuffer(arrayBuffer) + } + + me := MessageEvent{ + Data: data, + } + fn(me) + + return + }) +} + +// Subprotocol returns the WebSocket subprotocol in use. +func (c WebSocket) Subprotocol() string { + return c.v.Get("protocol").String() +} + +// OnOpen registers a function to be called when the WebSocket is opened. +func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { + return c.addEventListener("open", fn) +} + +// Close closes the WebSocket with the given code and reason. +func (c WebSocket) Close(code int, reason string) (err error) { + defer handleJSError(&err, nil) + c.v.Call("close", code, reason) + return err +} + +// SendText sends the given string as a text message +// on the WebSocket. +func (c WebSocket) SendText(v string) (err error) { + defer handleJSError(&err, nil) + c.v.Call("send", v) + return err +} + +// SendBytes sends the given message as a binary message +// on the WebSocket. +func (c WebSocket) SendBytes(v []byte) (err error) { + defer handleJSError(&err, nil) + c.v.Call("send", uint8Array(v)) + return err +} + +func extractArrayBuffer(arrayBuffer js.Value) []byte { + uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer) + dst := make([]byte, uint8Array.Length()) + js.CopyBytesToGo(dst, uint8Array) + return dst +} + +func uint8Array(src []byte) js.Value { + uint8Array := js.Global().Get("Uint8Array").New(len(src)) + js.CopyBytesToJS(uint8Array, src) + return uint8Array +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/xsync/go.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/xsync/go.go new file mode 100644 index 000000000000..7a61f27fa2ae --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/xsync/go.go @@ -0,0 +1,25 @@ +package xsync + +import ( + "fmt" +) + +// Go allows running a function in another goroutine +// and waiting for its error. +func Go(fn func() error) <-chan error { + errs := make(chan error, 1) + go func() { + defer func() { + r := recover() + if r != nil { + select { + case errs <- fmt.Errorf("panic in go fn: %v", r): + default: + } + } + }() + errs <- fn() + }() + + return errs +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/xsync/int64.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/xsync/int64.go new file mode 100644 index 000000000000..a0c402041568 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/internal/xsync/int64.go @@ -0,0 +1,23 @@ +package xsync + +import ( + "sync/atomic" +) + +// Int64 represents an atomic int64. +type Int64 struct { + // We do not use atomic.Load/StoreInt64 since it does not + // work on 32 bit computers but we need 64 bit integers. + i atomic.Value +} + +// Load loads the int64. +func (v *Int64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +// Store stores the int64. 
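+// Load and Store are safe for concurrent use. An illustrative
+// sketch:
+//
+//	var n xsync.Int64
+//	n.Store(42)
+//	_ = n.Load() // 42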
+func (v *Int64) Store(i int64) { + v.i.Store(i) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/netconn.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/netconn.go new file mode 100644 index 000000000000..64aadf0b998e --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/netconn.go @@ -0,0 +1,166 @@ +package websocket + +import ( + "context" + "fmt" + "io" + "math" + "net" + "sync" + "time" +) + +// NetConn converts a *websocket.Conn into a net.Conn. +// +// It's for tunneling arbitrary protocols over WebSockets. +// Few users of the library will need this but it's tricky to implement +// correctly and so provided in the library. +// See https://github.com/nhooyr/websocket/issues/100. +// +// Every Write to the net.Conn will correspond to a message write of +// the given type on *websocket.Conn. +// +// The passed ctx bounds the lifetime of the net.Conn. If cancelled, +// all reads and writes on the net.Conn will be cancelled. +// +// If a message is read that is not of the correct type, the connection +// will be closed with StatusUnsupportedData and an error will be returned. +// +// Close will close the *websocket.Conn with StatusNormalClosure. +// +// When a deadline is hit, the connection will be closed. This is +// different from most net.Conn implementations where only the +// reading/writing goroutines are interrupted but the connection is kept alive. +// +// The Addr methods will return a mock net.Addr that returns "websocket" for Network +// and "websocket/unknown-addr" for String. +// +// A received StatusNormalClosure or StatusGoingAway close frame will be translated to +// io.EOF when reading. +func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { + nc := &netConn{ + c: c, + msgType: msgType, + } + + var cancel context.CancelFunc + nc.writeContext, cancel = context.WithCancel(ctx) + nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + if !nc.writeTimer.Stop() { + <-nc.writeTimer.C + } + + nc.readContext, cancel = context.WithCancel(ctx) + nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + if !nc.readTimer.Stop() { + <-nc.readTimer.C + } + + return nc +} + +type netConn struct { + c *Conn + msgType MessageType + + writeTimer *time.Timer + writeContext context.Context + + readTimer *time.Timer + readContext context.Context + + readMu sync.Mutex + eofed bool + reader io.Reader +} + +var _ net.Conn = &netConn{} + +func (c *netConn) Close() error { + return c.c.Close(StatusNormalClosure, "") +} + +func (c *netConn) Write(p []byte) (int, error) { + err := c.c.Write(c.writeContext, c.msgType, p) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (c *netConn) Read(p []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + if c.eofed { + return 0, io.EOF + } + + if c.reader == nil { + typ, r, err := c.c.Reader(c.readContext) + if err != nil { + switch CloseStatus(err) { + case StatusNormalClosure, StatusGoingAway: + c.eofed = true + return 0, io.EOF + } + return 0, err + } + if typ != c.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) + c.c.Close(StatusUnsupportedData, err.Error()) + return 0, err + } + c.reader = r + } + + n, err := c.reader.Read(p) + if err == io.EOF { + c.reader = nil + err = nil + } + return n, err +} + +type websocketAddr struct { +} + +func (a websocketAddr) Network() string { + return "websocket" +} + +func (a websocketAddr) String() string { + return 
"websocket/unknown-addr" +} + +func (c *netConn) RemoteAddr() net.Addr { + return websocketAddr{} +} + +func (c *netConn) LocalAddr() net.Addr { + return websocketAddr{} +} + +func (c *netConn) SetDeadline(t time.Time) error { + c.SetWriteDeadline(t) + c.SetReadDeadline(t) + return nil +} + +func (c *netConn) SetWriteDeadline(t time.Time) error { + if t.IsZero() { + c.writeTimer.Stop() + } else { + c.writeTimer.Reset(t.Sub(time.Now())) + } + return nil +} + +func (c *netConn) SetReadDeadline(t time.Time) error { + if t.IsZero() { + c.readTimer.Stop() + } else { + c.readTimer.Reset(t.Sub(time.Now())) + } + return nil +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/read.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/read.go new file mode 100644 index 000000000000..afd08cc7cdeb --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/read.go @@ -0,0 +1,471 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "strings" + "time" + + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/xsync" +) + +// Reader reads from the connection until until there is a WebSocket +// data message to be read. It will handle ping, pong and close frames as appropriate. +// +// It returns the type of the message and an io.Reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. +// +// Call CloseRead if you do not expect any data messages from the peer. +// +// Only one Reader may be open at a time. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + return c.reader(ctx) +} + +// Read is a convenience method around Reader to read a single message +// from the connection. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err + } + + b, err := ioutil.ReadAll(r) + return typ, b, err +} + +// CloseRead starts a goroutine to read from the connection until it is closed +// or a data message is received. +// +// Once CloseRead is called you cannot read any messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// If a data message is received, the connection will be closed with StatusPolicyViolation. +// +// Call CloseRead when you do not expect to read any more messages. +// Since it actively reads from the connection, it will ensure that ping, pong and close +// frames are responded to. This means c.Ping and c.Close will still work as expected. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.Reader(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + // We add read one more byte than the limit in case + // there is a fin frame that needs to be read. 
+ c.msgReader.limitReader.limit.Store(n + 1) +} + +const defaultReadLimit = 32768 + +func newMsgReader(c *Conn) *msgReader { + mr := &msgReader{ + c: c, + fin: true, + } + mr.readFunc = mr.read + + mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) + return mr +} + +func (mr *msgReader) resetFlate() { + if mr.flateContextTakeover() { + mr.dict.init(32768) + } + if mr.flateBufio == nil { + mr.flateBufio = getBufioReader(mr.readFunc) + } + + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + mr.limitReader.r = mr.flateReader + mr.flateTail.Reset(deflateMessageTail) +} + +func (mr *msgReader) putFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + +func (mr *msgReader) close() { + mr.c.readMu.forceLock() + mr.putFlateReader() + mr.dict.close() + if mr.flateBufio != nil { + putBufioReader(mr.flateBufio) + } + + if mr.c.client { + putBufioReader(mr.c.br) + mr.c.br = nil + } +} + +func (mr *msgReader) flateContextTakeover() bool { + if mr.c.client { + return !mr.c.copts.serverNoContextTakeover + } + return !mr.c.copts.clientNoContextTakeover +} + +func (c *Conn) readRSV1Illegal(h header) bool { + // If compression is disabled, rsv1 is illegal. + if !c.flate() { + return true + } + // rsv1 is only allowed on data frames beginning messages. + if h.opcode != opText && h.opcode != opBinary { + return true + } + return false +} + +func (c *Conn) readLoop(ctx context.Context) (header, error) { + for { + h, err := c.readFrameHeader(ctx) + if err != nil { + return header{}, err + } + + if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { + err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + c.writeError(StatusProtocolError, err) + return header{}, err + } + + if !c.client && !h.masked { + return header{}, errors.New("received unmasked frame from client") + } + + switch h.opcode { + case opClose, opPing, opPong: + err = c.handleControl(ctx, h) + if err != nil { + // Pass through CloseErrors when receiving a close frame. 
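+ // (CloseStatus reports -1 when err does not wrap a CloseError,
+ // so only a genuine close status takes this branch.)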
+ if h.opcode == opClose && CloseStatus(err) != -1 { + return header{}, err + } + return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) + } + case opContinuation, opText, opBinary: + return h, nil + default: + err := fmt.Errorf("received unknown opcode %v", h.opcode) + c.writeError(StatusProtocolError, err) + return header{}, err + } + } +} + +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { + select { + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- ctx: + } + + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) + if err != nil { + select { + case <-c.closed: + return header{}, c.closeErr + case <-ctx.Done(): + return header{}, ctx.Err() + default: + c.close(err) + return header{}, err + } + } + + select { + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- context.Background(): + } + + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { + select { + case <-c.closed: + return 0, c.closeErr + case c.readTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + case <-ctx.Done(): + return n, ctx.Err() + default: + err = fmt.Errorf("failed to read frame payload: %w", err) + c.close(err) + return n, err + } + } + + select { + case <-c.closed: + return n, c.closeErr + case c.readTimeout <- context.Background(): + } + + return n, err +} + +func (c *Conn) handleControl(ctx context.Context, h header) (err error) { + if h.payloadLength < 0 || h.payloadLength > maxControlPayload { + err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + c.writeError(StatusProtocolError, err) + return err + } + + if !h.fin { + err := errors.New("received fragmented control frame") + c.writeError(StatusProtocolError, err) + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + b := c.readControlBuf[:h.payloadLength] + _, err = c.readFramePayload(ctx, b) + if err != nil { + return err + } + + if h.masked { + mask(h.maskKey, b) + } + + switch h.opcode { + case opPing: + return c.writeControl(ctx, opPong, b) + case opPong: + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() + if ok { + close(pong) + } + return nil + } + + defer func() { + c.readCloseFrameErr = err + }() + + ce, err := parseClosePayload(b) + if err != nil { + err = fmt.Errorf("received invalid close payload: %w", err) + c.writeError(StatusProtocolError, err) + return err + } + + err = fmt.Errorf("received close frame: %w", ce) + c.setCloseErr(err) + c.writeClose(ce.Code, ce.Reason) + c.close(err) + return err +} + +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { + defer errd.Wrap(&err, "failed to get reader") + + err = c.readMu.lock(ctx) + if err != nil { + return 0, nil, err + } + defer c.readMu.unlock() + + if !c.msgReader.fin { + err = errors.New("previous message not read to completion") + c.close(fmt.Errorf("failed to get reader: %w", err)) + return 0, nil, err + } + + h, err := c.readLoop(ctx) + if err != nil { + return 0, nil, err + } + + if h.opcode == opContinuation { + err := errors.New("received continuation frame without text or binary frame") + c.writeError(StatusProtocolError, err) + return 0, nil, err + } + + c.msgReader.reset(ctx, h) + + return MessageType(h.opcode), c.msgReader, nil +} + +type msgReader struct { + c *Conn + + ctx context.Context + flate bool + flateReader 
io.Reader + flateBufio *bufio.Reader + flateTail strings.Reader + limitReader *limitReader + dict slidingWindow + + fin bool + payloadLength int64 + maskKey uint32 + + // readerFunc(mr.Read) to avoid continuous allocations. + readFunc readerFunc +} + +func (mr *msgReader) reset(ctx context.Context, h header) { + mr.ctx = ctx + mr.flate = h.rsv1 + mr.limitReader.reset(mr.readFunc) + + if mr.flate { + mr.resetFlate() + } + + mr.setFrame(h) +} + +func (mr *msgReader) setFrame(h header) { + mr.fin = h.fin + mr.payloadLength = h.payloadLength + mr.maskKey = h.maskKey +} + +func (mr *msgReader) Read(p []byte) (n int, err error) { + err = mr.c.readMu.lock(mr.ctx) + if err != nil { + return 0, fmt.Errorf("failed to read: %w", err) + } + defer mr.c.readMu.unlock() + + n, err = mr.limitReader.Read(p) + if mr.flate && mr.flateContextTakeover() { + p = p[:n] + mr.dict.write(p) + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { + mr.putFlateReader() + return n, io.EOF + } + if err != nil { + err = fmt.Errorf("failed to read: %w", err) + mr.c.close(err) + } + return n, err +} + +func (mr *msgReader) read(p []byte) (int, error) { + for { + if mr.payloadLength == 0 { + if mr.fin { + if mr.flate { + return mr.flateTail.Read(p) + } + return 0, io.EOF + } + + h, err := mr.c.readLoop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := errors.New("received new data message without finishing the previous message") + mr.c.writeError(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + + continue + } + + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } + + n, err := mr.c.readFramePayload(mr.ctx, p) + if err != nil { + return n, err + } + + mr.payloadLength -= int64(n) + + if !mr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } + + return n, nil + } +} + +type limitReader struct { + c *Conn + r io.Reader + limit xsync.Int64 + n int64 +} + +func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { + lr := &limitReader{ + c: c, + } + lr.limit.Store(limit) + lr.reset(r) + return lr +} + +func (lr *limitReader) reset(r io.Reader) { + lr.n = lr.limit.Load() + lr.r = r +} + +func (lr *limitReader) Read(p []byte) (int, error) { + if lr.n <= 0 { + err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) + lr.c.writeError(StatusMessageTooBig, err) + return 0, err + } + + if int64(len(p)) > lr.n { + p = p[:lr.n] + } + n, err := lr.r.Read(p) + lr.n -= int64(n) + return n, err +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/stringer.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/stringer.go new file mode 100644 index 000000000000..5a66ba290762 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/stringer.go @@ -0,0 +1,91 @@ +// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. + +package websocket + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[opContinuation-0] + _ = x[opText-1] + _ = x[opBinary-2] + _ = x[opClose-8] + _ = x[opPing-9] + _ = x[opPong-10] +} + +const ( + _opcode_name_0 = "opContinuationopTextopBinary" + _opcode_name_1 = "opCloseopPingopPong" +) + +var ( + _opcode_index_0 = [...]uint8{0, 14, 20, 28} + _opcode_index_1 = [...]uint8{0, 7, 13, 19} +) + +func (i opcode) String() string { + switch { + case 0 <= i && i <= 2: + return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] + case 8 <= i && i <= 10: + i -= 8 + return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] + default: + return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[MessageText-1] + _ = x[MessageBinary-2] +} + +const _MessageType_name = "MessageTextMessageBinary" + +var _MessageType_index = [...]uint8{0, 11, 24} + +func (i MessageType) String() string { + i -= 1 + if i < 0 || i >= MessageType(len(_MessageType_index)-1) { + return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[StatusNormalClosure-1000] + _ = x[StatusGoingAway-1001] + _ = x[StatusProtocolError-1002] + _ = x[StatusUnsupportedData-1003] + _ = x[statusReserved-1004] + _ = x[StatusNoStatusRcvd-1005] + _ = x[StatusAbnormalClosure-1006] + _ = x[StatusInvalidFramePayloadData-1007] + _ = x[StatusPolicyViolation-1008] + _ = x[StatusMessageTooBig-1009] + _ = x[StatusMandatoryExtension-1010] + _ = x[StatusInternalError-1011] + _ = x[StatusServiceRestart-1012] + _ = x[StatusTryAgainLater-1013] + _ = x[StatusBadGateway-1014] + _ = x[StatusTLSHandshake-1015] +} + +const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" + +var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} + +func (i StatusCode) String() string { + i -= 1000 + if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { + return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" + } + return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] +} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/stub.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/stub.go deleted file mode 100644 index 7bf2a208daca..000000000000 --- a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/stub.go +++ /dev/null @@ -1,76 +0,0 @@ -// Code generated by depstubber. DO NOT EDIT. -// This is a simple stub for nhooyr.io/websocket, strictly for use in testing. - -// See the LICENSE file for information about the licensing of the original library. -// Source: nhooyr.io/websocket (exports: ; functions: Dial) - -// Package websocket is a stub of nhooyr.io/websocket, generated by depstubber. 
-package websocket - -import ( - context "context" - io "io" - http "net/http" -) - -type CompressionMode int - -type Conn struct{} - -func (_ *Conn) Close(_ StatusCode, _ string) error { - return nil -} - -func (_ *Conn) CloseRead(_ context.Context) context.Context { - return nil -} - -func (_ *Conn) Ping(_ context.Context) error { - return nil -} - -func (_ *Conn) Read(_ context.Context) (MessageType, []byte, error) { - return 0, nil, nil -} - -func (_ *Conn) Reader(_ context.Context) (MessageType, io.Reader, error) { - return 0, nil, nil -} - -func (_ *Conn) SetReadLimit(_ int64) {} - -func (_ *Conn) Subprotocol() string { - return "" -} - -func (_ *Conn) Write(_ context.Context, _ MessageType, _ []byte) error { - return nil -} - -func (_ *Conn) Writer(_ context.Context, _ MessageType) (io.WriteCloser, error) { - return nil, nil -} - -func Dial(_ context.Context, _ string, _ *DialOptions) (*Conn, *http.Response, error) { - return nil, nil, nil -} - -type DialOptions struct { - HTTPClient *http.Client - HTTPHeader http.Header - Subprotocols []string - CompressionMode CompressionMode - CompressionThreshold int -} - -type MessageType int - -func (_ MessageType) String() string { - return "" -} - -type StatusCode int - -func (_ StatusCode) String() string { - return "" -} diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/write.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/write.go new file mode 100644 index 000000000000..60a4fba06448 --- /dev/null +++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/write.go @@ -0,0 +1,386 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/klauspost/compress/flate" + + "nhooyr.io/websocket/internal/errd" +) + +// Writer returns a writer bounded by the context that will write +// a WebSocket message of type dataType to the connection. +// +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + w, err := c.writer(ctx, typ) + if err != nil { + return nil, fmt.Errorf("failed to get writer: %w", err) + } + return w, nil +} + +// Write writes a message to the connection. +// +// See the Writer method if you want to stream a message. +// +// If compression is disabled or the threshold is not met, then it +// will write the message in a single frame. 
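+//
+// An illustrative call on an established connection c:
+//
+//	err := c.Write(ctx, websocket.MessageText, []byte("hi"))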
+func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + _, err := c.write(ctx, typ, p) + if err != nil { + return fmt.Errorf("failed to write msg: %w", err) + } + return nil +} + +type msgWriter struct { + mw *msgWriterState + closed bool +} + +func (mw *msgWriter) Write(p []byte) (int, error) { + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + return mw.mw.Write(p) +} + +func (mw *msgWriter) Close() error { + if mw.closed { + return errors.New("cannot use closed writer") + } + mw.closed = true + return mw.mw.Close() +} + +type msgWriterState struct { + c *Conn + + mu *mu + writeMu *mu + + ctx context.Context + opcode opcode + flate bool + + trimWriter *trimLastFourBytesWriter + dict slidingWindow +} + +func newMsgWriterState(c *Conn) *msgWriterState { + mw := &msgWriterState{ + c: c, + mu: newMu(c), + writeMu: newMu(c), + } + return mw +} + +func (mw *msgWriterState) ensureFlate() { + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + } + + mw.dict.init(8192) + mw.flate = true +} + +func (mw *msgWriterState) flateContextTakeover() bool { + if mw.c.client { + return !mw.c.copts.clientNoContextTakeover + } + return !mw.c.copts.serverNoContextTakeover +} + +func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := c.msgWriterState.reset(ctx, typ) + if err != nil { + return nil, err + } + return &msgWriter{ + mw: c.msgWriterState, + closed: false, + }, nil +} + +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + mw, err := c.writer(ctx, typ) + if err != nil { + return 0, err + } + + if !c.flate() { + defer c.msgWriterState.mu.unlock() + return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) + } + + n, err := mw.Write(p) + if err != nil { + return n, err + } + + err = mw.Close() + return n, err +} + +func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { + err := mw.mu.lock(ctx) + if err != nil { + return err + } + + mw.ctx = ctx + mw.opcode = opcode(typ) + mw.flate = false + + mw.trimWriter.reset() + + return nil +} + +// Write writes the given bytes to the WebSocket connection. +func (mw *msgWriterState) Write(p []byte) (_ int, err error) { + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return 0, fmt.Errorf("failed to write: %w", err) + } + defer mw.writeMu.unlock() + + defer func() { + if err != nil { + err = fmt.Errorf("failed to write: %w", err) + mw.c.close(err) + } + }() + + if mw.c.flate() { + // Only enables flate if the length crosses the + // threshold on the first frame + if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { + mw.ensureFlate() + } + } + + if mw.flate { + err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err + } + mw.dict.write(p) + return len(p), nil + } + + return mw.write(p) +} + +func (mw *msgWriterState) write(p []byte) (int, error) { + n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) + if err != nil { + return n, fmt.Errorf("failed to write data frame: %w", err) + } + mw.opcode = opContinuation + return n, nil +} + +// Close flushes the frame to the connection. 
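+// It writes the fin frame, releases the message lock and, when
+// context takeover is disabled, discards the flate dictionary.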
+func (mw *msgWriterState) Close() (err error) { + defer errd.Wrap(&err, "failed to close writer") + + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return err + } + defer mw.writeMu.unlock() + + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) + if err != nil { + return fmt.Errorf("failed to write fin frame: %w", err) + } + + if mw.flate && !mw.flateContextTakeover() { + mw.dict.close() + } + mw.mu.unlock() + return nil +} + +func (mw *msgWriterState) close() { + if mw.c.client { + mw.c.writeFrameMu.forceLock() + putBufioWriter(mw.c.bw) + } + + mw.writeMu.forceLock() + mw.dict.close() +} + +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + _, err := c.writeFrame(ctx, true, false, opcode, p) + if err != nil { + return fmt.Errorf("failed to write control frame %v: %w", opcode, err) + } + return nil +} + +// frame handles all writes to the connection. +func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { + err = c.writeFrameMu.lock(ctx) + if err != nil { + return 0, err + } + defer func() { + // We leave it locked when writing the close frame to avoid + // any other goroutine writing any other frame. + if opcode != opClose { + c.writeFrameMu.unlock() + } + }() + + select { + case <-c.closed: + return 0, c.closeErr + case c.writeTimeout <- ctx: + } + + defer func() { + if err != nil { + select { + case <-c.closed: + err = c.closeErr + case <-ctx.Done(): + err = ctx.Err() + } + c.close(err) + err = fmt.Errorf("failed to write frame: %w", err) + } + }() + + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.payloadLength = int64(len(p)) + + if c.client { + c.writeHeader.masked = true + _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) + if err != nil { + return 0, fmt.Errorf("failed to generate masking key: %w", err) + } + c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) + } + + c.writeHeader.rsv1 = false + if flate && (opcode == opText || opcode == opBinary) { + c.writeHeader.rsv1 = true + } + + err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) + if err != nil { + return 0, err + } + + n, err := c.writeFramePayload(p) + if err != nil { + return n, err + } + + if c.writeHeader.fin { + err = c.bw.Flush() + if err != nil { + return n, fmt.Errorf("failed to flush: %w", err) + } + } + + select { + case <-c.closed: + return n, c.closeErr + case c.writeTimeout <- context.Background(): + } + + return n, nil +} + +func (c *Conn) writeFramePayload(p []byte) (n int, err error) { + defer errd.Wrap(&err, "failed to write frame payload") + + if !c.writeHeader.masked { + return c.bw.Write(p) + } + + maskKey := c.writeHeader.maskKey + for len(p) > 0 { + // If the buffer is full, we need to flush. + if c.bw.Available() == 0 { + err = c.bw.Flush() + if err != nil { + return n, err + } + } + + // Start of next write in the buffer. + i := c.bw.Buffered() + + j := len(p) + if j > c.bw.Available() { + j = c.bw.Available() + } + + _, err := c.bw.Write(p[:j]) + if err != nil { + return n, err + } + + maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + + p = p[j:] + n += j + } + + return n, nil +} + +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer +// and returns it. 
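+//
+// It temporarily resets bw to a writer that records the slice passed
+// to Write, forces a one-byte write and flush so bufio hands over its
+// internal buffer, then resets bw back onto w.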
+func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
+ var writeBuf []byte
+ bw.Reset(writerFunc(func(p2 []byte) (int, error) {
+ writeBuf = p2[:cap(p2)]
+ return len(p2), nil
+ }))
+
+ bw.WriteByte(0)
+ bw.Flush()
+
+ bw.Reset(w)
+
+ return writeBuf
+}
+
+func (c *Conn) writeError(code StatusCode, err error) {
+ c.setCloseErr(err)
+ c.writeClose(code, err.Error())
+ c.close(nil)
+}
diff --git a/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/ws_js.go b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/ws_js.go
new file mode 100644
index 000000000000..b87e32cdafb2
--- /dev/null
+++ b/go/ql/test/query-tests/Security/CWE-918/vendor/nhooyr.io/websocket/ws_js.go
@@ -0,0 +1,379 @@
+package websocket // import "nhooyr.io/websocket"
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall/js"
+
+ "nhooyr.io/websocket/internal/bpool"
+ "nhooyr.io/websocket/internal/wsjs"
+ "nhooyr.io/websocket/internal/xsync"
+)
+
+// Conn provides a wrapper around the browser WebSocket API.
+type Conn struct {
+ ws wsjs.WebSocket
+
+ // read limit for a message in bytes.
+ msgReadLimit xsync.Int64
+
+ closingMu sync.Mutex
+ isReadClosed xsync.Int64
+ closeOnce sync.Once
+ closed chan struct{}
+ closeErrOnce sync.Once
+ closeErr error
+ closeWasClean bool
+
+ releaseOnClose func()
+ releaseOnMessage func()
+
+ readSignal chan struct{}
+ readBufMu sync.Mutex
+ readBuf []wsjs.MessageEvent
+}
+
+func (c *Conn) close(err error, wasClean bool) {
+ c.closeOnce.Do(func() {
+ runtime.SetFinalizer(c, nil)
+
+ if !wasClean {
+ err = fmt.Errorf("unclean connection close: %w", err)
+ }
+ c.setCloseErr(err)
+ c.closeWasClean = wasClean
+ close(c.closed)
+ })
+}
+
+func (c *Conn) init() {
+ c.closed = make(chan struct{})
+ c.readSignal = make(chan struct{}, 1)
+
+ c.msgReadLimit.Store(32768)
+
+ c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
+ err := CloseError{
+ Code: StatusCode(e.Code),
+ Reason: e.Reason,
+ }
+ // We do not know if we sent or received this close as
+ // it's possible the browser triggered it without us
+ // explicitly sending it.
+ c.close(err, e.WasClean)
+
+ c.releaseOnClose()
+ c.releaseOnMessage()
+ })
+
+ c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
+ c.readBufMu.Lock()
+ defer c.readBufMu.Unlock()
+
+ c.readBuf = append(c.readBuf, e)
+
+ // Lets the read goroutine know there is definitely something in readBuf.
+ select {
+ case c.readSignal <- struct{}{}:
+ default:
+ }
+ })
+
+ runtime.SetFinalizer(c, func(c *Conn) {
+ c.setCloseErr(errors.New("connection garbage collected"))
+ c.closeWithInternal()
+ })
+}
+
+func (c *Conn) closeWithInternal() {
+ c.Close(StatusInternalError, "something went wrong")
+}
+
+// Read attempts to read a message from the connection.
+// The maximum time spent waiting is bounded by the context.
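+//
+// Messages larger than the configured read limit close the
+// connection with StatusMessageTooBig and return an error.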
+func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
+ if c.isReadClosed.Load() == 1 {
+ return 0, nil, errors.New("WebSocket connection read closed")
+ }
+
+ typ, p, err := c.read(ctx)
+ if err != nil {
+ return 0, nil, fmt.Errorf("failed to read: %w", err)
+ }
+ if int64(len(p)) > c.msgReadLimit.Load() {
+ err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
+ c.Close(StatusMessageTooBig, err.Error())
+ return 0, nil, err
+ }
+ return typ, p, nil
+}
+
+func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
+ select {
+ case <-ctx.Done():
+ c.Close(StatusPolicyViolation, "read timed out")
+ return 0, nil, ctx.Err()
+ case <-c.readSignal:
+ case <-c.closed:
+ return 0, nil, c.closeErr
+ }
+
+ c.readBufMu.Lock()
+ defer c.readBufMu.Unlock()
+
+ me := c.readBuf[0]
+ // We copy the messages forward and decrease the size
+ // of the slice to avoid reallocating.
+ copy(c.readBuf, c.readBuf[1:])
+ c.readBuf = c.readBuf[:len(c.readBuf)-1]
+
+ if len(c.readBuf) > 0 {
+ // Next time we read, we'll grab the message.
+ select {
+ case c.readSignal <- struct{}{}:
+ default:
+ }
+ }
+
+ switch p := me.Data.(type) {
+ case string:
+ return MessageText, []byte(p), nil
+ case []byte:
+ return MessageBinary, p, nil
+ default:
+ panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
+ }
+}
+
+// Ping is mocked out for Wasm.
+func (c *Conn) Ping(ctx context.Context) error {
+ return nil
+}
+
+// Write writes a message of the given type to the connection.
+// Always non-blocking.
+func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
+ err := c.write(ctx, typ, p)
+ if err != nil {
+ // Have to ensure the WebSocket is closed after a write error
+ // to match the Go API. It can only error if the message type
+ // is unexpected or the passed bytes contain invalid UTF-8 for
+ // MessageText.
+ err := fmt.Errorf("failed to write: %w", err)
+ c.setCloseErr(err)
+ c.closeWithInternal()
+ return err
+ }
+ return nil
+}
+
+func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
+ if c.isClosed() {
+ return c.closeErr
+ }
+ switch typ {
+ case MessageBinary:
+ return c.ws.SendBytes(p)
+ case MessageText:
+ return c.ws.SendText(string(p))
+ default:
+ return fmt.Errorf("unexpected message type: %v", typ)
+ }
+}
+
+// Close closes the WebSocket with the given code and reason.
+// It will wait until the peer responds with a close frame
+// or the connection is closed.
+// It thus performs the full WebSocket close handshake.
+func (c *Conn) Close(code StatusCode, reason string) error {
+ err := c.exportedClose(code, reason)
+ if err != nil {
+ return fmt.Errorf("failed to close WebSocket: %w", err)
+ }
+ return nil
+}
+
+func (c *Conn) exportedClose(code StatusCode, reason string) error {
+ c.closingMu.Lock()
+ defer c.closingMu.Unlock()
+
+ ce := fmt.Errorf("sent close: %w", CloseError{
+ Code: code,
+ Reason: reason,
+ })
+
+ if c.isClosed() {
+ return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
+ }
+
+ c.setCloseErr(ce)
+ err := c.ws.Close(int(code), reason)
+ if err != nil {
+ return err
+ }
+
+ <-c.closed
+ if !c.closeWasClean {
+ return c.closeErr
+ }
+ return nil
+}
+
+// Subprotocol returns the negotiated subprotocol.
+// An empty string means the default protocol.
+func (c *Conn) Subprotocol() string {
+ return c.ws.Subprotocol()
+}
+
+// DialOptions represents the options available to pass to Dial.
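+// This is the js/Wasm variant: only Subprotocols is supported in the
+// browser; the HTTP and compression options of the !js build do not
+// apply here.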
+type DialOptions struct { + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string +} + +// Dial creates a new WebSocket connection to the given url with the given options. +// The passed context bounds the maximum time spent waiting for the connection to open. +// The returned *http.Response is always nil or a mock. It's only in the signature +// to match the core API. +func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + c, resp, err := dial(ctx, url, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) + } + return c, resp, nil +} + +func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + if opts == nil { + opts = &DialOptions{} + } + + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.Replace(url, "https://", "wss://", 1) + + ws, err := wsjs.New(url, opts.Subprotocols) + if err != nil { + return nil, nil, err + } + + c := &Conn{ + ws: ws, + } + c.init() + + opench := make(chan struct{}) + releaseOpen := ws.OnOpen(func(e js.Value) { + close(opench) + }) + defer releaseOpen() + + select { + case <-ctx.Done(): + c.Close(StatusPolicyViolation, "dial timed out") + return nil, nil, ctx.Err() + case <-opench: + return c, &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + }, nil + case <-c.closed: + return nil, nil, c.closeErr + } +} + +// Reader attempts to read a message from the connection. +// The maximum time spent waiting is bounded by the context. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, p, err := c.Read(ctx) + if err != nil { + return 0, nil, err + } + return typ, bytes.NewReader(p), nil +} + +// Writer returns a writer to write a WebSocket data message to the connection. +// It buffers the entire message in memory and then sends it when the writer +// is closed. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + return writer{ + c: c, + ctx: ctx, + typ: typ, + b: bpool.Get(), + }, nil +} + +type writer struct { + closed bool + + c *Conn + ctx context.Context + typ MessageType + + b *bytes.Buffer +} + +func (w writer) Write(p []byte) (int, error) { + if w.closed { + return 0, errors.New("cannot write to closed writer") + } + n, err := w.b.Write(p) + if err != nil { + return n, fmt.Errorf("failed to write message: %w", err) + } + return n, nil +} + +func (w writer) Close() error { + if w.closed { + return errors.New("cannot close closed writer") + } + w.closed = true + defer bpool.Put(w.b) + + err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) + if err != nil { + return fmt.Errorf("failed to close writer: %w", err) + } + return nil +} + +// CloseRead implements *Conn.CloseRead for wasm. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.read(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit implements *Conn.SetReadLimit for wasm. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +}