From abf03b1e10360c1fe7763e32228dc57d65367dbe Mon Sep 17 00:00:00 2001 From: Greg Linton Date: Wed, 11 May 2016 18:16:04 -0600 Subject: [PATCH 1/6] Refactor to allow multiple records per domain and more fully utilize dns library --- .gitignore | 29 +--- .travis.yml | 14 +- Boxfile | 10 -- README.md | 266 +++++++++++++++++------------- api/README.md | 68 ++++++++ api/api.go | 178 +++++++------------- api/api_test.go | 260 +++++++++++++++++++++++++++++ api/records.go | 108 ++++++++++++ build.sh | 31 ++++ cache/cache.go | 117 +++++++++++++ cache/cache_test.go | 103 ++++++++++++ cache/scribble.go | 111 +++++++++++++ cache/scribble_test.go | 91 ++++++++++ caches/caches.go | 246 --------------------------- caches/caches_test.go | 138 ---------------- caches/map_cacher.go | 85 ---------- caches/map_cacher_test.go | 145 ---------------- caches/mock_caches/mock_caches.go | 101 ------------ caches/postgresql_cacher.go | 158 ------------------ caches/postgresql_cacher_test.go | 150 ----------------- caches/redis_cacher.go | 147 ----------------- caches/redis_cacher_test.go | 148 ----------------- caches/scribble_cacher.go | 116 ------------- caches/scribble_cacher_test.go | 149 ----------------- commands/README.md | 121 ++++++++++++++ commands/add.go | 77 ++++----- commands/commands.go | 165 +++++++++--------- commands/commands_test.go | 216 ++++++++++++++++++++++++ commands/delete.go | 38 +++++ commands/get.go | 37 +++++ commands/list.go | 52 ++---- commands/remove.go | 61 ------- commands/reset.go | 55 ++++++ commands/show.go | 61 ------- commands/update.go | 77 ++++----- config/config.go | 117 ++++++++++--- core/common/common.go | 66 ++++++++ core/shaman.go | 192 +++++++++++++++++++++ core/shaman_test.go | 121 ++++++++++++++ main.go | 126 +++++++++++++- server/dns.go | 63 +++++++ server/server.go | 98 ----------- version.go | 3 + 43 files changed, 2381 insertions(+), 2334 deletions(-) delete mode 100644 Boxfile create mode 100644 api/README.md create mode 100644 api/api_test.go create mode 100644 api/records.go create mode 100755 build.sh create mode 100644 cache/cache.go create mode 100644 cache/cache_test.go create mode 100644 cache/scribble.go create mode 100644 cache/scribble_test.go delete mode 100644 caches/caches.go delete mode 100644 caches/caches_test.go delete mode 100644 caches/map_cacher.go delete mode 100644 caches/map_cacher_test.go delete mode 100644 caches/mock_caches/mock_caches.go delete mode 100644 caches/postgresql_cacher.go delete mode 100644 caches/postgresql_cacher_test.go delete mode 100644 caches/redis_cacher.go delete mode 100644 caches/redis_cacher_test.go delete mode 100644 caches/scribble_cacher.go delete mode 100644 caches/scribble_cacher_test.go create mode 100644 commands/README.md create mode 100644 commands/commands_test.go create mode 100644 commands/delete.go create mode 100644 commands/get.go delete mode 100644 commands/remove.go create mode 100644 commands/reset.go delete mode 100644 commands/show.go create mode 100644 core/common/common.go create mode 100644 core/shaman.go create mode 100644 core/shaman_test.go create mode 100644 server/dns.go delete mode 100644 server/server.go create mode 100644 version.go diff --git a/.gitignore b/.gitignore index 53d2029..430c7d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,4 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe -*.test -*.prof - shaman -cli/cli \ No newline at end of file +dns +*.cover +config.json diff --git a/.travis.yml b/.travis.yml index 28fcff0..6e49533 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,5 @@ language: go +go: 1.6 -go: - - 1.5 - -go_import_path: github.com/nanopack/shaman - -services: - - postgres - - redis-server - -before_script: - - psql -c 'create database travis_ci_test;' -U postgres \ No newline at end of file +install: go get -t -v . +script: go test -v ./... diff --git a/Boxfile b/Boxfile deleted file mode 100644 index a0f70e8..0000000 --- a/Boxfile +++ /dev/null @@ -1,10 +0,0 @@ -build: - engine: go -web1: - ports: - - 53:8053 - - 443:8443 -postgresql1: - name: l2 -redis1: - name: l1 \ No newline at end of file diff --git a/README.md b/README.md index 2c76cbc..13eb73b 100644 --- a/README.md +++ b/README.md @@ -1,142 +1,172 @@ -[![shaman logo](http://nano-assets.gopagoda.io/readme-headers/shaman.png)](http://nanobox.io/open-source#shaman) +[![shaman logo](http://nano-assets.gopagoda.io/readme-headers/shaman.png)](http://nanobox.io/open-source#shaman) [![Build Status](https://travis-ci.org/nanopack/shaman.svg)](https://travis-ci.org/nanopack/shaman) +[![GoDoc](https://godoc.org/github.com/nanopack/shaman?status.svg)](https://godoc.org/github.com/nanopack/shaman) -# shaman +# Shaman -Small, lightweight, api-driven dns server. +Small, clusterable, lightweight, api-driven dns server. -## Status +## Quickstart: +```sh +# Start shaman with defaults (requires admin privileges (port 53)) +shaman -s -Working +# register a new domain +shaman add -d nanopack.io -A 127.0.0.1 -## Todo -- Logging -- Tests -- Read configuration from file - -## Server -``` -Usage: - [flags] - [command] +# perform dns lookup +dig @localhost nanopack.io +short +# 127.0.0.1 -Available Commands: - add Add entry into shaman database - remove Remove entry from shaman database - show Show entry in shaman database - update Update entry in shaman database - list List entries in shaman database - -Flags: - -c, --api-crt="": Path to SSL crt for API access - -H, --api-host="127.0.0.1": Listen address for the API - -k, --api-key="": Path to SSL key for API access - -p, --api-key-password="": Password for SSL key - -P, --api-port="8443": Listen address for the API - -t, --api-token="": Token for API Access - -d, --domain=".": Parent domain for requests - -h, --help[=false]: help for - -O, --host="127.0.0.1": Listen address for DNS requests - -i, --insecure[=false]: Disable tls key checking - -1, --l1-connect="map://127.0.0.1/": Connection string for the l1 cache - -e, --l1-expires=120: TTL for the L1 Cache (0 = never expire) - -2, --l2-connect="map://127.0.0.1/": Connection string for the l2 cache - -E, --l2-expires=0: TTL for the L2 Cache (0 = never expire) - -l, --log-file="": Log file (blank = log to console) - -L, --log-level="INFO": Log level to use - -o, --port="8053": Listen port for DNS requests - -s, --server[=false]: Run in server mode - -T, --ttl=60: Default TTL for DNS records - -Use " [command] --help" for more information about a command. +# Congratulations! ``` -### L1 and L2 connection strings - -#### In-Memory Map Cacher -This is the default cacher. If the connection string doesn't match any of the other's, it will use this one. - -#### Postgresql Cacher -The connection string looks like `postgres://user@host/database` and more [docs here](https://godoc.org/github.com/lib/pq). This string gets passed into the sql driver without modification. - -#### Redis Cacher -The connection string looks like `redis://user:password@host:port/`. The user is not really used, but only there if there is a password on the redis-server. - -#### Scribble Cacher -The connection string looks like `scribble://localhost/path/to/data/store`. Scribble only cares about the path part of the URI to determine where it should place the files. - -### Commands - -#### add -`add [Record Type] [Domain] [Value]` - -#### remove -`remove [Record Type] [Domain]` -#### show -`show [Record Type] [Domain]` +## Usage: -#### update -`update [Record Type] [Domain] [Value]` +### As a CLI +Simply run `shaman ` -#### list -`list` +`shaman` or `shaman -h` will show usage and a list of commands: -## API -The API is a web based API. The API uses TLS and a token for security and authentication. +``` +shaman - api driven dns server -### API token -The API requires a token to be passed for authentication. This token is set when the server is started. The token is passed in the header as `X-NANOBOX-TOKEN`. +Usage: + shaman [flags] + shaman [command] -#### Add -POST to `/records/[record type]/[domain]` -A `value` must be posted. Currently it has to be past as a query string rather than part of the post body like `/records/[record type]/[domain]?value=[value]`. This is an issue that should be fixed. +Available Commands: + add Add a domain to shaman + delete Remove a domain from shaman + list List all domains in shaman + get Get records for a domain + update Update records for a domain + reset Reset all domains in shaman -#### Remove -DELETE to `/records/[record type]/[domain]` +Flags: + -C, --api-crt string Path to SSL crt for API access + -k, --api-key string Path to SSL key for API access + -p, --api-key-password string Password for SSL key + -H, --api-listen string Listen address for the API (ip:port) (default "127.0.0.1:1632") + -c, --config-file string Configuration file to load + -O, --dns-listen string Listen address for DNS requests (ip:port) (default "127.0.0.1:53") + -d, --domain string Parent domain for requests (default ".") + -i, --insecure Disable tls key checking (client) and listen on http (api) + -2, --l2-connect string Connection string for the l2 cache (default "scribble:///var/db/shaman") + -l, --log-level string Log level to output [fatal|error|info|debug|trace] (default "INFO") + -s, --server Run in server mode + -t, --token string Token for API Access (default "secret") + -T, --ttl int Default TTL for DNS records (default 60) + -v, --version Print version info and exit + +Use "shaman [command] --help" for more information about a command. +``` -#### Show -GET to `/records/[record type]/[domain]` +For usage examples, see [api](api/README.md) and/or [cli](commands/README.md) readme + +### As a Server +To start shaman as a server run: +`shaman --server` +An optional config file can also be passed on startup: +`shaman -c config.json` + +>config.json +>```json +{ + "api-crt": "", + "api-key": "", + "api-key-password": "", + "api-listen": "127.0.0.1:1632", + "token": "secret", + "insecure": false, + "l2-connect": "scribble:///var/db/shaman", + "ttl": 60, + "domain": ".", + "dns-listen": "127.0.0.1:53", + "log-level": "info", + "server": true +} +``` -#### Update -PUT to `/records/[record type]/[domain]` -A `value` must be put. Currently it has to be past as a query string rather than part of the put body like `/records/[record type]/[domain]?value=[value]`. This is an issue that should be fixed. +## API: + +| Route | Description | Payload | Output | +| --- | --- | --- | --- | +| **POST** /records | Adds the domain and full record | json domain object | json domain object | +| **PUT** /records | Update all domains and records (replaces all) | json array of domain objects | json array of domain objects | +| **GET** /records | Returns a list of domains we have records for | nil | string array of domains | +| **PUT** /records/{domain} | Update domain's records (replaces all) | json domain object | json domain object | +| **GET** /records/{domain} | Returns the records for that domain | nil | json domain object | +| **DELETE** /records/{domain} | Delete a domain | nil | success message | + +For examples, see [the api's readme](api/README.md) + +## Data types: +### Domain (Resource): +json: +```json +{ + "domain": "nanopack.io.", + "records": [ + { + "ttl": 60, + "class": "IN", + "type": "A", + "address": "127.0.0.1" + }, + { + "ttl": 60, + "class": "IN", + "type": "A", + "address": "127.0.0.2" + } + ] +} +``` -#### List -GET to `/records` +Fields: +- **domain**: Domain name to resolve +- **records**: Array of address records + - **ttl**: Seconds a client should cache for + - **class**: Record class + - **type**: Record type + - A - Address record + - CNAME - Canonical name record + - MX - Mail exchange record + - [Many more](https://en.wikipedia.org/wiki/List_of_DNS_record_types) - may or may not work as is + - **address**: Address domain resolves to + - note: Special rules apply in some cases. E.g. MX records require a number "10 mail.google.com" + +### Error: +json: +```json +{ + "err": "exit status 2: unexpected argument" +} +``` -### Notes +Fields: + - **err**: Error message -#### Using nslookup to test -The port can be set with `set port=8053` and the server with `server 127.0.0.1` -``` -$ nslookup -> set port=8053 -> server 127.0.0.1 -Default server: 127.0.0.1 -Address: 127.0.0.1#8053 -> test.com -Server: 127.0.0.1 -Address: 127.0.0.1#8053 - -Non-authoritative answer: -*** Can't find test.com: No answer -> exit +### Message: +json: +```json +{ + "msg": "Success" +} ``` -#### Overview +Fields: + - **msg**: Success message -``` -+------------+ +----------+ +-----------------+ -| +-----> +-----> | -| API Server | | | | Short-Term (L1) | -| <-----+ Caching <-----+ | -+------------+ | And | +-----------------+ - | Database | -+------------+ | Manager | +-----------------+ -| +-----> +-----> | -| DNS Server | | | | Long-Term (L2) | -| <-----+ <-----+ | -+------------+ +----------+ +-----------------+ -``` +## Todo +- tests for server/dns +- start server insecure +- atomic local cache updates +- export in hosts file format + +## Changelog +- v0.0.2 (May 11, 2016) + - Refactor to allow multiple records per domain and more fully utilize dns library -[![shaman logo](http://nano-assets.gopagoda.io/open-src/nanobox-open-src.png)](http://nanobox.io/open-source) +[![oss logo](http://nano-assets.gopagoda.io/open-src/nanobox-open-src.png)](http://nanobox.io/open-source) diff --git a/api/README.md b/api/README.md new file mode 100644 index 0000000..e2ad045 --- /dev/null +++ b/api/README.md @@ -0,0 +1,68 @@ +[![shaman logo](http://nano-assets.gopagoda.io/readme-headers/shaman.png)](http://nanobox.io/open-source#shaman) +[![Build Status](https://travis-ci.org/nanopack/shaman.svg)](https://travis-ci.org/nanopack/shaman) + +# Shaman + +Small, lightweight, api-driven dns server. + +## Routes: + +| Route | Description | Payload | Output | +| --- | --- | --- | --- | +| **POST** /records | Adds the domain and full record | json domain object | json domain object | +| **PUT** /records | Update all domains and records (replaces all) | json array of domain objects | json array of domain objects | +| **GET** /records | Returns a list of domains we have records for | nil | string array of domains | +| **PUT** /records/{id} | Update domain's records (replaces all) | json domain object | json domain object | +| **GET** /records/{id} | Returns the records for that domain | nil | json domain object | +| **DELETE** /records/{id} | Delete a domain | nil | success message | + +## Usage Example: + +#### add domain +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records -d \ + '{"domain":"nanopack.io","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]}' +# {"domain":"nanopack.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]} +``` + +#### list domains +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records +# ["nanopack.io"] +``` +or add `?full=true` for the full records +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records?full=true +# [{"domain":"nanopack.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]}] +``` + +#### update domains +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records -d \ + '[{"domain":"nanobox.io","records":[{"address":"127.0.0.1"}]}]' \ + -X PUT +# [{"domain":"nanobox.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.1"}]}] +``` + +#### update domain +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records/nanobox.io -d \ + '{"domain":"nanobox.io","records":[{"address":"127.0.0.2"}]}' \ + -X PUT +# {"domain":"nanobox.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]} +``` + +#### delete domain +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records/nanobox.io \ + -X DELETE +# {"msg":"success"} +``` + +#### get domain +```sh +$ curl -k -H "X-AUTH-TOKEN: secret" https://localhost:1632/records/nanobox.io +# {"err":"failed to find record for domain - 'nanobox.io'"} +``` + +[![oss logo](http://nano-assets.gopagoda.io/open-src/nanobox-open-src.png)](http://nanobox.io/open-source) diff --git a/api/api.go b/api/api.go index 9e5dc7e..36d2ecc 100644 --- a/api/api.go +++ b/api/api.go @@ -1,29 +1,37 @@ +// Package "api" provides a restful interface to manage entries in the DNS database. package api -// This is a restful interface to manage entries in the DNS database - -// TODO: -// - parse data to build record to add/update -// - add logging -// - test - import ( "crypto/tls" "encoding/json" + "errors" "fmt" + "io/ioutil" "net/http" "github.com/gorilla/pat" - "github.com/miekg/dns" nanoauth "github.com/nanobox-io/golang-nanoauth" - "github.com/nanopack/shaman/caches" "github.com/nanopack/shaman/config" ) -var auth nanoauth.Auth +type ( + apiError struct { + ErrorString string `json:"err"` + } + apiMsg struct { + MsgString string `json:"msg"` + } +) + +var ( + auth nanoauth.Auth + badJson = errors.New("Bad JSON syntax received in body") + bodyReadFail = errors.New("Body Read Failed") +) -func StartApi() error { +// Start starts shaman's http api +func Start() error { var cert *tls.Certificate var err error if config.ApiCrt == "" { @@ -35,139 +43,69 @@ func StartApi() error { return err } auth.Certificate = cert - auth.Header = "X-NANOBOX-TOKEN" + auth.Header = "X-AUTH-TOKEN" + + config.Log.Info("Shaman listening on https://%v", config.ApiListen) - return auth.ListenAndServeTLS(fmt.Sprintf("%s:%s", config.ApiHost, config.ApiPort), config.ApiToken, routes()) + // todo: handle config.Insecure + + return fmt.Errorf("API stopped - %v", auth.ListenAndServeTLS(config.ApiListen, config.ApiToken, routes())) } func routes() *pat.Router { router := pat.New() - router.Get("/records/{rtype}/{domain}", handleRequest(getRecord)) - router.Post("/records/{rtype}/{domain}", handleRequest(addRecord)) - router.Put("/records/{rtype}/{domain}", handleRequest(updateRecord)) - router.Delete("/records/{rtype}/{domain}", handleRequest(deleteRecord)) - router.Get("/records", handleRequest(listRecords)) - return router -} -func handleRequest(fn func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(rw http.ResponseWriter, req *http.Request) { - fn(rw, req) - } + router.Delete("/records/{domain}", deleteRecord) // delete resource + router.Put("/records/{domain}", updateRecord) // reset resource's records + router.Get("/records/{domain}", getRecord) // return resource's records + + router.Post("/records", createRecord) // add a resource + router.Get("/records", listRecords) // return all domains + router.Put("/records", updateAnswers) // reset all resources + + return router } -func writeBody(v interface{}, rw http.ResponseWriter, status int) error { +func writeBody(rw http.ResponseWriter, req *http.Request, v interface{}, status int) error { b, err := json.Marshal(v) if err != nil { return err } + // print the error only if there is one + var msg map[string]string + json.Unmarshal(b, &msg) + + var errMsg string + if msg["error"] != "" { + errMsg = msg["error"] + } + + config.Log.Debug("%s %d %s %s %s", req.RemoteAddr, status, req.Method, req.RequestURI, errMsg) + rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(status) - rw.Write(b) + rw.Write(append(b, byte('\n'))) return nil } -func getRecord(rw http.ResponseWriter, req *http.Request) { - rtype := dns.StringToType[req.URL.Query().Get(":rtype")] - domain := req.URL.Query().Get(":domain") - dns.IsDomainName(domain) - key := caches.Key(domain, rtype) - findReturn := make(chan caches.FindReturn) - findOp := caches.FindOp{Key: key, Resp: findReturn} - caches.FindOps <- findOp - findRet := <-findReturn - err := findRet.Err - record := findRet.Value - if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return - } - if record == "" { - writeBody(nil, rw, http.StatusNotFound) - return - } - rr, err := dns.NewRR(record) - if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return - } - err = writeBody(rr, rw, http.StatusOK) - if err != nil { - // log error - } -} +// parseBody parses the json body into v +func parseBody(req *http.Request, v interface{}) error { -func addRecord(rw http.ResponseWriter, req *http.Request) { - rtype := req.URL.Query().Get(":rtype") - domain := req.URL.Query().Get(":domain") - value := req.FormValue("value") - ttl := config.TTL - key := caches.Key(domain, dns.StringToType[rtype]) - rrString := fmt.Sprintf("%s %d IN %s %s", domain, ttl, rtype, value) - rr, err := dns.NewRR(rrString) + // read the body + b, err := ioutil.ReadAll(req.Body) if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return + config.Log.Error(err.Error()) + return bodyReadFail } - resp := make(chan error) - addOp := caches.AddOp{Key: key, Value: rr.String(), Resp: resp} - caches.AddOps <- addOp - err = <-resp - if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return - } - writeBody(rr, rw, http.StatusOK) -} + defer req.Body.Close() -func updateRecord(rw http.ResponseWriter, req *http.Request) { - rtype := req.URL.Query().Get(":rtype") - domain := req.URL.Query().Get(":domain") - key := caches.Key(domain, dns.StringToType[rtype]) - ttl := config.TTL - value := req.FormValue("value") - rrString := fmt.Sprintf("%s %d IN %s %s", domain, ttl, rtype, value) - rr, err := dns.NewRR(rrString) + // parse body and store in v + err = json.Unmarshal(b, v) if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return + return badJson } - resp := make(chan error) - updateOp := caches.UpdateOp{Key: key, Value: rr.String(), Resp: resp} - caches.UpdateOps <- updateOp - err = <-resp - if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return - } - writeBody(rr, rw, http.StatusOK) -} -func deleteRecord(rw http.ResponseWriter, req *http.Request) { - rtype := req.URL.Query().Get(":rtype") - domain := req.URL.Query().Get(":domain") - key := caches.Key(domain, dns.StringToType[rtype]) - resp := make(chan error) - removeOp := caches.RemoveOp{Key: key, Resp: resp} - caches.RemoveOps <- removeOp - err := <-resp - if err != nil { - writeBody(err, rw, http.StatusInternalServerError) - return - } - writeBody(nil, rw, http.StatusOK) -} - -func listRecords(rw http.ResponseWriter, req *http.Request) { - resp := make(chan caches.ListReturn) - listOp := caches.ListOp{Resp: resp} - caches.ListOps <- listOp - listReturn := <-resp - if listReturn.Err != nil { - writeBody(listReturn.Err, rw, http.StatusInternalServerError) - return - } - writeBody(listReturn.Values, rw, http.StatusOK) + return nil } diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 0000000..3eb45f3 --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,260 @@ +package api_test + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/jcelliott/lumber" + + "github.com/nanopack/shaman/api" + "github.com/nanopack/shaman/config" + shaman "github.com/nanopack/shaman/core/common" +) + +var ( + testResource1 = `{"domain":"google.com","records":[{"type":"A","address":"127.0.0.1"}]}` + testResource2 = `{"domain":"google.com","records":[{"type":"A","address":"127.0.0.2"}]}` + badResource = `{"domain":"google.com","records":[{"type":1,"address":"127.0.0.3"}]}` + testResource3 = `{"domain":"foogle.com","records":[{"type":"A","address":"127.0.0.4"}]}` +) + +func TestMain(m *testing.M) { + // manually configure + initialize() + + // start api + go api.Start() + <-time.After(3 * time.Second) + rtn := m.Run() + + os.Exit(rtn) +} + +// test put records +func TestPutRecords(t *testing.T) { + // good request test + resp, _, err := rest("PUT", "/records", fmt.Sprintf("[%v]", testResource1)) + if err != nil { + t.Error(err) + } + + var resources []shaman.Resource + json.Unmarshal(resp, &resources) + + if len(resources) != 1 { + t.Errorf("%q doesn't match expected out", resources) + } + + if len(resources) == 1 && + len(resources[0].Records) == 1 && + resources[0].Records[0].Address != "127.0.0.1" { + t.Errorf("%q doesn't match expected out", resources) + } + + // bad request test + resp, _, err = rest("PUT", "/records", testResource1) + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "Bad JSON syntax received in body") { + t.Errorf("%q doesn't match expected out", resp) + } + + // clear records + rest("PUT", "/records", "[]") +} + +// todo: "tests should be able to run independent" `go test -v ./api -run TestGet` +// test get records +func TestGetRecords(t *testing.T) { + body, _, err := rest("GET", "/records", "") + if err != nil { + t.Error(err) + } + if string(body) != "[]\n" { + t.Errorf("%q doesn't match expected out", body) + } + body, _, err = rest("GET", "/records?full=true", "") + if err != nil { + t.Error(err) + } + if string(body) != "[]\n" { + t.Errorf("%q doesn't match expected out", body) + } +} + +// test post records +func TestPostRecord(t *testing.T) { + // good request test + resp, _, err := rest("POST", "/records", testResource1) + if err != nil { + t.Error(err) + } + + var resource shaman.Resource + json.Unmarshal(resp, &resource) + + if resource.Domain != "google.com." { + t.Errorf("%q doesn't match expected out", resource) + } + + // bad request test + resp, _, err = rest("POST", "/records", badResource) + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "Bad JSON syntax received in body") { + t.Errorf("%q doesn't match expected out", resp) + } +} + +// test get resource +func TestGetRecord(t *testing.T) { + // good request test + resp, _, err := rest("GET", "/records/google.com", "") + if err != nil { + t.Error(err) + } + + var resource shaman.Resource + json.Unmarshal(resp, &resource) + + if resource.Domain != "google.com." { + t.Errorf("%q doesn't match expected out", resource) + } + + // bad request test + resp, _, err = rest("GET", "/records/not-real.com", "") + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "failed to find record for domain - 'not-real.com'") { + t.Errorf("%q doesn't match expected out", resp) + } +} + +// test put records +func TestPutRecord(t *testing.T) { + // good request test - create(201) + resp, code, err := rest("PUT", "/records/foogle.com", testResource3) + if err != nil { + t.Error(err) + } + if code != 201 { + t.Error("Failed to meet rfc2616 spec, expecting 201") + } + + var resource shaman.Resource + json.Unmarshal(resp, &resource) + + if len(resource.Records) == 1 && + resource.Records[0].Address != "127.0.0.4" { + t.Errorf("%q doesn't match expected out", resource) + } + + // good request test - update + resp, _, err = rest("PUT", "/records/foogle.com", testResource2) + if err != nil { + t.Error(err) + } + + // verify old resource is gone + resp, _, err = rest("GET", "/records/foogle.com", "") + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "failed to find record for domain - 'foogle.com'") { + t.Errorf("%q doesn't match expected out", resp) + } + + // bad request test + resp, _, err = rest("PUT", "/records/not-real.com", badResource) + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "Bad JSON syntax received in body") { + t.Errorf("%q doesn't match expected out", resp) + } +} + +// test delete resource +func TestDeleteRecord(t *testing.T) { + // good request test + resp, _, err := rest("DELETE", "/records/google.com", "") + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "{\"msg\":\"success\"}") { + t.Errorf("%q doesn't match expected out", resp) + } + + // verify gone + resp, code, err := rest("GET", "/records/google.com", "") + if err != nil { + t.Error(err) + } + + if code != 404 { + t.Errorf("%q doesn't match expected out", code) + } + + // bad request test + resp, _, err = rest("DELETE", "/records/not-real.com", "") + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(resp), "{\"msg\":\"success\"}") { + t.Errorf("%q doesn't match expected out", resp) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVS +//////////////////////////////////////////////////////////////////////////////// +// hit api and return response body +func rest(method, route, data string) ([]byte, int, error) { + body := bytes.NewBuffer([]byte(data)) + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + uri := fmt.Sprintf("https://%s%s", config.ApiListen, route) + + req, _ := http.NewRequest(method, uri, body) + req.Header.Add("X-AUTH-TOKEN", config.ApiToken) + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, 500, fmt.Errorf("Unable to %v %v - %v", method, route, err) + } + defer res.Body.Close() + + if res.StatusCode == 401 { + return nil, res.StatusCode, fmt.Errorf("401 Unauthorized. Please specify api token (-t 'token')") + } + + b, err := ioutil.ReadAll(res.Body) + + return b, res.StatusCode, err +} + +// manually configure and start internals +func initialize() { + config.L2Connect = "none://" + config.ApiListen = "127.0.0.1:1633" + config.Log = lumber.NewConsoleLogger(lumber.LvlInt("FATAL")) + config.LogLevel = "FATAL" +} diff --git a/api/records.go b/api/records.go new file mode 100644 index 0000000..e4ffcce --- /dev/null +++ b/api/records.go @@ -0,0 +1,108 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/nanopack/shaman/core" + sham "github.com/nanopack/shaman/core/common" +) + +func createRecord(rw http.ResponseWriter, req *http.Request) { + var resource sham.Resource + err := parseBody(req, &resource) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusBadRequest) + return + } + + err = shaman.AddRecord(&resource) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusInternalServerError) + return + } + + writeBody(rw, req, resource, http.StatusOK) +} + +func listRecords(rw http.ResponseWriter, req *http.Request) { + if req.URL.Query().Get("full") == "true" { + writeBody(rw, req, shaman.ListRecords(), http.StatusOK) + return + } + + writeBody(rw, req, shaman.ListDomains(), http.StatusOK) +} + +func updateAnswers(rw http.ResponseWriter, req *http.Request) { + resources := make([]sham.Resource, 0) + err := parseBody(req, &resources) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusBadRequest) + return + } + + err = shaman.ResetRecords(&resources) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusInternalServerError) + return + } + + writeBody(rw, req, resources, http.StatusOK) +} + +func updateRecord(rw http.ResponseWriter, req *http.Request) { + var resource sham.Resource + err := parseBody(req, &resource) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusBadRequest) + return + } + + domain := req.URL.Query().Get(":domain") + + if !shaman.Exists(domain) { + // create resource if not exist + err = shaman.AddRecord(&resource) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusInternalServerError) + return + } + + // "MUST reply 201"(https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html) + writeBody(rw, req, resource, http.StatusCreated) + return + } + + err = shaman.UpdateRecord(domain, &resource) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusInternalServerError) + return + } + + writeBody(rw, req, resource, http.StatusOK) +} + +func getRecord(rw http.ResponseWriter, req *http.Request) { + domain := req.URL.Query().Get(":domain") + + resource, err := shaman.GetRecord(domain) + if err != nil { + writeBody(rw, req, apiError{fmt.Sprintf("failed to find record for domain - '%v'", domain)}, http.StatusNotFound) + return + } + + writeBody(rw, req, resource, http.StatusOK) +} + +func deleteRecord(rw http.ResponseWriter, req *http.Request) { + domain := req.URL.Query().Get(":domain") + + err := shaman.DeleteRecord(domain) + if err != nil { + writeBody(rw, req, apiError{err.Error()}, http.StatusInternalServerError) + return + } + + writeBody(rw, req, apiMsg{"success"}, http.StatusOK) +} diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..4fb0dc0 --- /dev/null +++ b/build.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +set -e + +# try and use the correct MD5 lib (depending on user OS darwin/linux) +MD5=$(which md5 || which md5sum) + +# build shaman +echo "Building SHAMAN and uploading it to 's3://tools.nanopack.io/shaman'" +gox -osarch "linux/amd64" -output="./build/{{.OS}}/{{.Arch}}/shaman" + +# look through each os/arch/file and generate an md5 for each +echo "Generating md5s..." +for os in $(ls ./build); do + for arch in $(ls ./build/${os}); do + for file in $(ls ./build/${os}/${arch}); do + cat "./build/${os}/${arch}/${file}" | ${MD5} | awk '{print $1}' >> "./build/${os}/${arch}/${file}.md5" + done + done +done + +# upload to AWS S3 +echo "Uploading builds to S3..." +aws s3 sync ./build/ s3://tools.nanopack.io/shaman --grants read=uri=http://acs.amazonaws.com/groups/global/AllUsers --region us-east-1 + +# +echo "Cleaning up..." + +# remove build +[ -e "./build" ] && \ + echo "Removing build files..." && \ + rm -rf "./build" diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..5ec4266 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,117 @@ +// Package "cache" provides a pluggable backend for persistant record storage. +package cache + +import ( + "errors" + "fmt" + "net/url" + + "github.com/nanopack/shaman/config" + shaman "github.com/nanopack/shaman/core/common" +) + +var ( + storage cacher + noRecordError = errors.New("No Record Found") +) + +// The cacher interface is what all the backends [will] implement +type cacher interface { + initialize() error + addRecord(resource *shaman.Resource) error + getRecord(domain string) (*shaman.Resource, error) + updateRecord(domain string, resource *shaman.Resource) error + deleteRecord(domain string) error + resetRecords(resources *[]shaman.Resource) error + listRecords() ([]shaman.Resource, error) +} + +// Set default cacher and initialize it +func Initialize() error { + u, err := url.Parse(config.L2Connect) + if err != nil { + return fmt.Errorf("Failed to parse 'l2-connect' - %v", err) + } + + switch u.Scheme { + case "scribble": + storage = &scribbleDb{} + case "none": + storage = nil + default: + storage = &scribbleDb{} + } + + if storage != nil { + err = storage.initialize() + if err != nil { + storage = nil + err = fmt.Errorf("Failed to initialize cache, turning off - %v", err) + } + } + + return err +} + +// AddRecord adds a record to the persistant cache +func AddRecord(resource *shaman.Resource) error { + if storage == nil { + return nil + } + resource.Validate() + return storage.addRecord(resource) +} + +// GetRecord gets a record to the persistant cache +func GetRecord(domain string) (*shaman.Resource, error) { + if storage == nil { + return nil, nil + } + + shaman.SanitizeDomain(&domain) + return storage.getRecord(domain) +} + +// UpdateRecord updates a record in the persistant cache +func UpdateRecord(domain string, resource *shaman.Resource) error { + if storage == nil { + return nil + } + shaman.SanitizeDomain(&domain) + resource.Validate() + return storage.updateRecord(domain, resource) +} + +// DeleteRecord removes a record from the persistant cache +func DeleteRecord(domain string) error { + if storage == nil { + return nil + } + shaman.SanitizeDomain(&domain) + return storage.deleteRecord(domain) +} + +// ResetRecords replaces all records in the persistant cache +func ResetRecords(resources *[]shaman.Resource) error { + if storage == nil { + return nil + } + for i := range *resources { + (*resources)[i].Validate() + } + + return storage.resetRecords(resources) +} + +// ListRecords lists all records in the persistant cache +func ListRecords() ([]shaman.Resource, error) { + if storage == nil { + return make([]shaman.Resource, 0), nil + } + return storage.listRecords() +} + +// Exists returns whether the default cacher exists +func Exists() bool { + return storage != nil +} diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..1efa882 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,103 @@ +package cache_test + +import ( + "os" + "testing" + + "github.com/jcelliott/lumber" + + "github.com/nanopack/shaman/cache" + "github.com/nanopack/shaman/config" + shaman "github.com/nanopack/shaman/core/common" +) + +var ( + nanopack = shaman.Resource{Domain: "nanopack.io.", Records: []shaman.Record{shaman.Record{Address: "127.0.0.1"}}} + nanobox = shaman.Resource{Domain: "nanobox.io.", Records: []shaman.Record{shaman.Record{Address: "127.0.0.2"}}} + nanoBoth = []shaman.Resource{nanopack, nanobox} +) + +func TestMain(m *testing.M) { + // manually configure + config.Log = lumber.NewConsoleLogger(lumber.LvlInt("FATAL")) + + // run tests + rtn := m.Run() + + os.Exit(rtn) +} + +// test nil cache init +func TestNoneInitialize(t *testing.T) { + config.L2Connect = "none://" + err := cache.Initialize() + if err != nil { + t.Errorf("Failed to initalize none cacher - %v", err) + } +} + +// test nil cache addRecord +func TestNoneAddRecord(t *testing.T) { + noneReset() + err := cache.AddRecord(&shaman.Resource{}) + if err != nil { + t.Errorf("Failed to add record to none cacher - %v", err) + } +} + +// test nil cache getRecord +func TestNoneGetRecord(t *testing.T) { + noneReset() + _, err := cache.GetRecord("nanopack.io") + if err != nil { + t.Errorf("Failed to get record from none cacher - %v", err) + } +} + +// test nil cache updateRecord +func TestNoneUpdateRecord(t *testing.T) { + noneReset() + err := cache.UpdateRecord("nanopack.io", &shaman.Resource{}) + if err != nil { + t.Errorf("Failed to update record in none cacher - %v", err) + } +} + +// test nil cache deleteRecord +func TestNoneDeleteRecord(t *testing.T) { + noneReset() + err := cache.DeleteRecord("nanopack.io") + if err != nil { + t.Errorf("Failed to delete record from none cacher - %v", err) + } +} + +// test nil cache resetRecords +func TestNoneResetRecords(t *testing.T) { + noneReset() + err := cache.ResetRecords(&[]shaman.Resource{}) + if err != nil { + t.Errorf("Failed to reset records in none cacher - %v", err) + } +} + +// test nil cache listRecords +func TestNoneListRecords(t *testing.T) { + noneReset() + _, err := cache.ListRecords() + if err != nil { + t.Errorf("Failed to list records in none cacher - %v", err) + } +} + +func TestNoneExists(t *testing.T) { + noneReset() + if cache.Exists() { + t.Error("Cache exits but shouldn't") + } +} + +func noneReset() { + config.L2Connect = "none://" + cache.Initialize() +} diff --git a/cache/scribble.go b/cache/scribble.go new file mode 100644 index 0000000..12f90d4 --- /dev/null +++ b/cache/scribble.go @@ -0,0 +1,111 @@ +package cache + +import ( + "encoding/json" + "fmt" + "net/url" + "strings" + + "github.com/nanobox-io/golang-scribble" + + "github.com/nanopack/shaman/config" + shaman "github.com/nanopack/shaman/core/common" +) + +type scribbleDb struct { + db *scribble.Driver +} + +func (self *scribbleDb) initialize() error { + u, err := url.Parse(config.L2Connect) + if err != nil { + return fmt.Errorf("Failed to parse 'l2-connect' - %v", err) + } + dir := u.Path + if dir == "" || dir == "/" { + config.Log.Debug("Invalid directory, using default '/var/db/shaman'") + dir = "/var/db/shaman" + } + db, err := scribble.New(dir, nil) + if err != nil { + config.Log.Fatal("Failed to create db") + return fmt.Errorf("Failed to create new db at '%v' - %v", dir, err) + } + + self.db = db + return nil +} + +func (self scribbleDb) addRecord(resource *shaman.Resource) error { + err := self.db.Write("hosts", resource.Domain, *resource) + if err != nil { + err = fmt.Errorf("Failed to save record - %v", err) + } + return err +} + +func (self scribbleDb) getRecord(domain string) (*shaman.Resource, error) { + resource := shaman.Resource{} + err := self.db.Read("hosts", domain, &resource) + if err != nil { + if strings.Contains(err.Error(), "no such file or directory") { + err = noRecordError + } + return nil, err + } + return &resource, nil +} + +func (self scribbleDb) updateRecord(domain string, resource *shaman.Resource) error { + if domain != resource.Domain { + err := self.deleteRecord(domain) + if err != nil { + return fmt.Errorf("Failed to clear current record - %v", err) + } + } + + return self.addRecord(resource) +} + +func (self scribbleDb) deleteRecord(domain string) error { + err := self.db.Delete("hosts", domain) + if err != nil { + if strings.Contains(err.Error(), "Unable to find") { + err = nil + } else { + err = fmt.Errorf("Failed to delete record - %v", err) + } + } + return err +} + +func (self scribbleDb) resetRecords(resources *[]shaman.Resource) (err error) { + self.db.Delete("hosts", "") + for i := range *resources { + err = self.db.Write("hosts", (*resources)[i].Domain, (*resources)[i]) + if err != nil { + err = fmt.Errorf("Failed to save records - %v", err) + } + } + return err +} + +func (self scribbleDb) listRecords() ([]shaman.Resource, error) { + resources := make([]shaman.Resource, 0) + values, err := self.db.ReadAll("hosts") + if err != nil { + if strings.Contains(err.Error(), "no such file or directory") { + // if error is about a missing db, return empty array + return resources, nil + } + return nil, err + } + for i := range values { + var resource shaman.Resource + if err = json.Unmarshal([]byte(values[i]), &resource); err != nil { + return nil, fmt.Errorf("Bad JSON syntax found in stored body") + } + resources = append(resources, resource) + } + return resources, nil +} diff --git a/cache/scribble_test.go b/cache/scribble_test.go new file mode 100644 index 0000000..b901fb7 --- /dev/null +++ b/cache/scribble_test.go @@ -0,0 +1,91 @@ +package cache_test + +import ( + "os" + "testing" + + "github.com/nanopack/shaman/cache" + "github.com/nanopack/shaman/config" +) + +// test scribble cache init +func TestScribbleInitialize(t *testing.T) { + config.L2Connect = "/tmp/shamanCache" // default + err := cache.Initialize() + config.L2Connect = "!@#$%^&*()" // unparse-able + err2 := cache.Initialize() + config.L2Connect = "scribble:///roots/file" // unable to init? (test no sudo) + err3 := cache.Initialize() + config.L2Connect = "scribble:///" // defaulting to "/var/db" + err4 := cache.Initialize() + if err != nil || err2 == nil || err3 == nil || err4 != nil { + t.Errorf("Failed to initalize scribble cacher - %v%v%v%v", err, err2, err3, err4) + } +} + +// test scribble cache addRecord +func TestScribbleAddRecord(t *testing.T) { + scribbleReset() + err := cache.AddRecord(&nanopack) + if err != nil { + t.Errorf("Failed to add record to scribble cacher - %v", err) + } +} + +// test scribble cache getRecord +func TestScribbleGetRecord(t *testing.T) { + scribbleReset() + cache.AddRecord(&nanopack) + _, err := cache.GetRecord("nanobox.io") + _, err2 := cache.GetRecord("nanopack.io") + if err == nil || err2 != nil { + t.Errorf("Failed to get record from scribble cacher - %v%v", err, err2) + } +} + +// test scribble cache updateRecord +func TestScribbleUpdateRecord(t *testing.T) { + scribbleReset() + err := cache.UpdateRecord("nanobox.io", &nanopack) + err2 := cache.UpdateRecord("nanopack.io", &nanopack) + if err != nil || err2 != nil { + t.Errorf("Failed to update record in scribble cacher - %v%v", err, err2) + } +} + +// test scribble cache deleteRecord +func TestScribbleDeleteRecord(t *testing.T) { + scribbleReset() + err := cache.DeleteRecord("nanobox.io") + cache.AddRecord(&nanopack) + err2 := cache.DeleteRecord("nanopack.io") + if err != nil || err2 != nil { + t.Errorf("Failed to delete record from scribble cacher - %v%v", err, err2) + } +} + +// test scribble cache resetRecords +func TestScribbleResetRecords(t *testing.T) { + scribbleReset() + err := cache.ResetRecords(&nanoBoth) + if err != nil { + t.Errorf("Failed to reset records in scribble cacher - %v", err) + } +} + +// test scribble cache listRecords +func TestScribbleListRecords(t *testing.T) { + scribbleReset() + _, err := cache.ListRecords() + cache.ResetRecords(&nanoBoth) + _, err2 := cache.ListRecords() + if err != nil || err2 != nil { + t.Errorf("Failed to list records in scribble cacher - %v%v", err, err2) + } +} + +func scribbleReset() { + os.RemoveAll("/tmp/shamanCache") + config.L2Connect = "scribble:///tmp/shamanCache" + cache.Initialize() +} diff --git a/caches/caches.go b/caches/caches.go deleted file mode 100644 index f30d48e..0000000 --- a/caches/caches.go +++ /dev/null @@ -1,246 +0,0 @@ -package caches - -// General layered caching for the shaman dns server -// L1 is a short-term quick response lookup for entries -// L2 is a long-term storage for entries -// L1 and L2 can be configured to use different backend caches or databases - -// TODO: -// - implement caching backends -// - add logging -// - test - -import ( - "fmt" - "net/url" - - "github.com/nanopack/shaman/config" -) - -type Cacher interface { - InitializeDatabase() error - ClearDatabase() error - GetRecord(string) (string, error) - SetRecord(string, string) error - ReviseRecord(string, string) error - DeleteRecord(string) error - ListRecords() ([]string, error) -} - -type CacheEntry struct { - Expires int64 - Value string -} - -type FindReturn struct { - Err error - Value string -} - -type ListReturn struct { - Err error - Values []string -} - -type AddOp struct { - Key string - Value string - Resp chan error -} - -type UpdateOp struct { - Key string - Value string - Resp chan error -} - -type RemoveOp struct { - Key string - Resp chan error -} - -type FindOp struct { - Key string - Resp chan FindReturn -} - -type ListOp struct { - Resp chan ListReturn -} - -var ( - AddOps = make(chan AddOp) - RemoveOps = make(chan RemoveOp) - UpdateOps = make(chan UpdateOp) - FindOps = make(chan FindOp) - ListOps = make(chan ListOp) - L1 Cacher - L2 Cacher -) - -func StartCache() error { - for { - select { - case addOp := <-AddOps: - addOp.Resp <- addRecord(addOp.Key, addOp.Value) - case removeOp := <-RemoveOps: - removeOp.Resp <- removeRecord(removeOp.Key) - case updateOp := <-UpdateOps: - updateOp.Resp <- updateRecord(updateOp.Key, updateOp.Value) - case findOp := <-FindOps: - value, err := findRecord(findOp.Key) - findOp.Resp <- FindReturn{Err: err, Value: value} - case listOp := <-ListOps: - values, err := listRecords() - listOp.Resp <- ListReturn{Err: err, Values: values} - } - } - return nil -} - -// Determine the backend cache to initialize based off of the connection string -// Pass the connection string and TTL into the backend constructor -func initializeCacher(connection string, expires int) (Cacher, error) { - u, err := url.Parse(connection) - if err != nil { - - } - var cacher Cacher - switch u.Scheme { - case "redis": - cacher, err = NewRedisCacher(connection, expires) - case "postgres": - cacher, err = NewPostgresCacher(connection, expires) - case "scribble": - cacher, err = NewScribbleCacher(connection, expires) - default: - cacher, err = NewMapCacher(connection, expires) - } - if err != nil { - return nil, err - } - err = cacher.InitializeDatabase() - if err != nil { - return nil, err - } - return cacher, nil -} - -// Create L1 and L2 from the config -func InitCache() { - config.Log.Info("Initializing caches") - var err error - L1, err = initializeCacher(config.L1Connect, config.L1Expires) - if err != nil { - config.Log.Error("Error with L1: %s", err) - } - L2, err = initializeCacher(config.L2Connect, config.L2Expires) - if err != nil { - config.Log.Error("Error with L2: %s", err) - } -} - -// Create a lookup key based off of the domain and type of record -func Key(domain string, rtype uint16) string { - return fmt.Sprintf("%d-%s", rtype, domain) -} - -// Add record into the caches. First insert into the long term, -// then try the short term. -func addRecord(key string, value string) error { - config.Log.Info("Adding key: %s, value: %s", key, value) - if L2 != nil { - err := L2.SetRecord(key, value) - if err != nil { - config.Log.Error("Error adding to L2: %s", err) - return err - } - } - if L1 != nil { - err := L1.SetRecord(key, value) - if err != nil { - config.Log.Error("Error adding to L1: %s", err) - return err - } - } - return nil -} - -// Remove record from the long term storage, then remove from short term. -func removeRecord(key string) error { - config.Log.Info("Removing key: %s", key) - if L2 != nil { - err := L2.DeleteRecord(key) - if err != nil { - config.Log.Error("Error removing from L2: %s", err) - return err - } - } - if L1 != nil { - err := L1.DeleteRecord(key) - if err != nil { - config.Log.Error("Error removing from L1: %s", err) - return err - } - } - return nil -} - -// Update the long term storage, then update the short term storage. -func updateRecord(key string, value string) error { - config.Log.Info("Updating key: %s, value: %s", key, value) - if L2 != nil { - err := L2.ReviseRecord(key, value) - if err != nil { - config.Log.Error("Error updating L2: %s", err) - return err - } - } - if L1 != nil { - err := L1.ReviseRecord(key, value) - if err != nil { - config.Log.Error("Error updating L1: %s", err) - return err - } - } - return nil -} - -// Look for the record in the short term, if it isn't there, check the -// long term, and put it in the short term. -func findRecord(key string) (string, error) { - config.Log.Info("Finding key: %s", key) - var record string - var err error - if L1 != nil { - record, err = L1.GetRecord(key) - if err != nil { - config.Log.Error("Error finding L1: %s", err) - return record, err - } - } - if record != "" { - return record, nil - } - if L2 != nil { - record, err = L2.GetRecord(key) - if err != nil { - config.Log.Error("Error finding L2: %s", err) - } - if record != "" { - L1.SetRecord(key, record) - return record, err - } - } - return "", nil -} - -func listRecords() ([]string, error) { - if L2 != nil { - return L2.ListRecords() - } - if L1 != nil { - return L1.ListRecords() - } - return []string{}, nil -} diff --git a/caches/caches_test.go b/caches/caches_test.go deleted file mode 100644 index 725f218..0000000 --- a/caches/caches_test.go +++ /dev/null @@ -1,138 +0,0 @@ -package caches_test - -import ( - "os" - "testing" - - "github.com/golang/mock/gomock" - "github.com/jcelliott/lumber" - - "github.com/nanopack/shaman/caches" - "github.com/nanopack/shaman/caches/mock_caches" - "github.com/nanopack/shaman/config" -) - -func TestMain(m *testing.M) { - config.Log = lumber.NewConsoleLogger(lumber.ERROR) - if testing.Verbose() { - config.Log = lumber.NewConsoleLogger(lumber.DEBUG) - } - os.Exit(m.Run()) -} - -func initializeCaches(t *testing.T) (*mock_caches.MockCacher, *mock_caches.MockCacher) { - caches.InitCache() - ctrl1 := gomock.NewController(t) - defer ctrl1.Finish() - ctrl2 := gomock.NewController(t) - defer ctrl2.Finish() - l1 := mock_caches.NewMockCacher(ctrl1) - l2 := mock_caches.NewMockCacher(ctrl2) - caches.L1 = l1 - caches.L2 = l2 - go caches.StartCache() - return l1, l2 -} - -func TestFindRecordL1(t *testing.T) { - l1, _ := initializeCaches(t) - gomock.InOrder( - l1.EXPECT().GetRecord("1-key").Return("found", nil), - ) - findReturn := make(chan caches.FindReturn) - findOp := caches.FindOp{Key: "1-key", Resp: findReturn} - caches.FindOps <- findOp - findRet := <-findReturn - err := findRet.Err - if err != nil { - t.Errorf("Error: %s", err) - } - record := findRet.Value - if record != "found" { - t.Errorf("bad result from L1: %s", record) - } -} - -func TestFindRecordL2(t *testing.T) { - l1, l2 := initializeCaches(t) - gomock.InOrder( - l1.EXPECT().GetRecord("1-key").Return("", nil), - l2.EXPECT().GetRecord("1-key").Return("found", nil), - l1.EXPECT().SetRecord("1-key", "found").Return(nil), - ) - findReturn := make(chan caches.FindReturn) - findOp := caches.FindOp{Key: "1-key", Resp: findReturn} - caches.FindOps <- findOp - findRet := <-findReturn - err := findRet.Err - if err != nil { - t.Errorf("Error: %s", err) - } - record := findRet.Value - if record != "found" { - t.Errorf("bad result from L1: %s", record) - } -} - -func TestAddRecord(t *testing.T) { - l1, l2 := initializeCaches(t) - gomock.InOrder( - l2.EXPECT().SetRecord("1-key", "found").Return(nil), - l1.EXPECT().SetRecord("1-key", "found").Return(nil), - ) - resp := make(chan error) - addOp := caches.AddOp{Key: "1-key", Value: "found", Resp: resp} - caches.AddOps <- addOp - err := <-resp - if err != nil { - t.Errorf("Error: %s", err) - } -} - -func TestUpdateRecord(t *testing.T) { - l1, l2 := initializeCaches(t) - gomock.InOrder( - l2.EXPECT().ReviseRecord("1-key", "found").Return(nil), - l1.EXPECT().ReviseRecord("1-key", "found").Return(nil), - ) - resp := make(chan error) - updateOp := caches.UpdateOp{Key: "1-key", Value: "found", Resp: resp} - caches.UpdateOps <- updateOp - err := <-resp - if err != nil { - t.Errorf("Error: %s", err) - } -} - -func TestRemoveRecord(t *testing.T) { - l1, l2 := initializeCaches(t) - gomock.InOrder( - l2.EXPECT().DeleteRecord("1-key").Return(nil), - l1.EXPECT().DeleteRecord("1-key").Return(nil), - ) - resp := make(chan error) - removeOp := caches.RemoveOp{Key: "1-key", Resp: resp} - caches.RemoveOps <- removeOp - err := <-resp - if err != nil { - t.Errorf("Error: %s", err) - } -} - -func TestListRecords(t *testing.T) { - _, l2 := initializeCaches(t) - gomock.InOrder( - l2.EXPECT().ListRecords().Return([]string{"found"}, nil), - ) - resp := make(chan caches.ListReturn) - listOp := caches.ListOp{Resp: resp} - caches.ListOps <- listOp - listReturn := <-resp - err := listReturn.Err - if err != nil { - t.Errorf("Error: %s", err) - } - if len(listReturn.Values) != 1 && listReturn.Values[0] != "found" { - t.Errorf("Bad return: %s", listReturn.Values) - } -} diff --git a/caches/map_cacher.go b/caches/map_cacher.go deleted file mode 100644 index a104808..0000000 --- a/caches/map_cacher.go +++ /dev/null @@ -1,85 +0,0 @@ -package caches - -// Simple cache that stores data in a simple go map. -// map doesn't automatically evict expired data, this will need to -// check to ensure data isn't already expired. - -// TODO: -// - add logging -// - test -// - add routine for removing old data - -import ( - "time" - - "github.com/nanopack/shaman/config" -) - -type mapCacher struct { - expires int - db map[string]CacheEntry -} - -// Map cacher initializer -func NewMapCacher(connection string, expires int) (*mapCacher, error) { - config.Log.Info("creating map cacher") - mc := mapCacher{expires: expires, db: make(map[string]CacheEntry)} - return &mc, nil -} - -func (self mapCacher) InitializeDatabase() error { - return nil -} - -func (self mapCacher) ClearDatabase() error { - self.db = make(map[string]CacheEntry) - return nil -} - -// Get record from the map cacher and make sure it hasn't expired yet -func (self mapCacher) GetRecord(key string) (string, error) { - var ce CacheEntry - ce, ok := self.db[key] - if !ok { - config.Log.Debug("No Record: %s", key) - return "", nil - } - if self.expires > 0 { - if time.Now().Unix() > ce.Expires { - // expired - config.Log.Debug("Expired: %s", key) - self.DeleteRecord(key) - return "", nil - } - ce.Expires = time.Now().Unix() + int64(self.expires) - self.db[key] = ce - } - config.Log.Debug("Found: %s = %s", key, ce.Value) - return ce.Value, nil -} - -// Insert/update entry in the map cacher -func (self mapCacher) SetRecord(key, val string) error { - ce := CacheEntry{Expires: time.Now().Unix() + int64(self.expires), Value: val} - self.db[key] = ce - return nil -} - -// Update entry in the map cacher -func (self mapCacher) ReviseRecord(key, val string) error { - return self.SetRecord(key, val) -} - -// remove entry from the map cacher -func (self mapCacher) DeleteRecord(key string) error { - delete(self.db, key) - return nil -} - -func (self mapCacher) ListRecords() ([]string, error) { - entries := make([]string, 0) - for ce := range self.db { - entries = append(entries, self.db[ce].Value) - } - return entries, nil -} diff --git a/caches/map_cacher_test.go b/caches/map_cacher_test.go deleted file mode 100644 index 731bb45..0000000 --- a/caches/map_cacher_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package caches_test - -import ( - "testing" - "time" - - "github.com/nanopack/shaman/caches" -) - -func initializeMapCacher(expires int) caches.Cacher { - cacher, _ := caches.NewMapCacher("", expires) - cacher.ClearDatabase() - return cacher -} - -func mapSet(t *testing.T, mapCacher caches.Cacher, key string, value string) { - err := mapCacher.SetRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in MapCacher: %s", err) - } -} - -func mapGet(t *testing.T, mapCacher caches.Cacher, key string, checkValue string) { - value, err := mapCacher.GetRecord(key) - if err != nil { - t.Errorf("Error from GetRecord in MapCacher: %s", err) - } - if value != checkValue { - t.Errorf("Unexpected result from MapCacher: %s", value) - } -} - -func mapRevise(t *testing.T, mapCacher caches.Cacher, key string, value string) { - err := mapCacher.ReviseRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in MapCacher: %s", err) - } -} - -func mapDelete(t *testing.T, mapCacher caches.Cacher, key string) { - err := mapCacher.DeleteRecord("1-key") - if err != nil { - t.Errorf("Error from DeleteRecord in MapCacher: %s", err) - } -} - -func mapList(t *testing.T, mapCacher caches.Cacher, key string, checkValues []string) { - values, err := mapCacher.ListRecords() - if err != nil { - t.Errorf("Error from ListRecord in MapCacher: %s", err) - } - if len(values) != len(checkValues) { - t.Errorf("Unexpected length from ListRecord in MapCacher: %d", len(values)) - } - for value := range values { - found := false - for checkValue := range checkValues { - if checkValue == value { - found = true - break - } - } - if !found { - t.Errorf("Unexpected values from ListRecord in MapCacher: %s", values) - } - } -} - -func TestMapSet(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") -} - -func TestMapGet(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapGet(t, mapCacher, "1-key", "") -} - -func TestMapGetAfterSet(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") - mapGet(t, mapCacher, "1-key", "found") -} - -func TestMapGetAfterSetWithExpiresNoSleep(t *testing.T) { - mapCacher := initializeMapCacher(1) - mapSet(t, mapCacher, "1-key", "found") - mapGet(t, mapCacher, "1-key", "found") -} - -func TestMapGetAfterSetWithExpires(t *testing.T) { - mapCacher := initializeMapCacher(1) - mapSet(t, mapCacher, "1-key", "found") - time.Sleep(2 * time.Second) - mapGet(t, mapCacher, "1-key", "") -} - -func TestMapRevise(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapRevise(t, mapCacher, "1-key", "found") -} - -func TestMapReviseAfterSet(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") - mapRevise(t, mapCacher, "1-key", "found") -} - -func TestMapGetAfterReviseAfterSet(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") - mapRevise(t, mapCacher, "1-key", "found too") - mapGet(t, mapCacher, "1-key", "found too") -} - -func TestMapGetAfterReviseAfterSetWithExpires(t *testing.T) { - mapCacher := initializeMapCacher(1) - mapSet(t, mapCacher, "1-key", "found") - mapRevise(t, mapCacher, "1-key", "found too") - time.Sleep(2 * time.Second) - mapGet(t, mapCacher, "1-key", "") -} - -func TestMapDelete(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") - mapDelete(t, mapCacher, "1-key") - mapGet(t, mapCacher, "1-key", "") -} - -func TestMapDeleteToo(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") - mapSet(t, mapCacher, "2-key", "found too") - mapDelete(t, mapCacher, "1-key") - mapGet(t, mapCacher, "1-key", "") - mapGet(t, mapCacher, "2-key", "found too") -} - -func TestMapList(t *testing.T) { - mapCacher := initializeMapCacher(0) - mapSet(t, mapCacher, "1-key", "found") - mapSet(t, mapCacher, "2-key", "found too") - mapList(t, mapCacher, "2-key", []string{"found", "found too"}) -} diff --git a/caches/mock_caches/mock_caches.go b/caches/mock_caches/mock_caches.go deleted file mode 100644 index c904c83..0000000 --- a/caches/mock_caches/mock_caches.go +++ /dev/null @@ -1,101 +0,0 @@ -// Automatically generated by MockGen. DO NOT EDIT! -// Source: caches/caches.go - -package mock_caches - -import ( - gomock "github.com/golang/mock/gomock" -) - -// Mock of Cacher interface -type MockCacher struct { - ctrl *gomock.Controller - recorder *_MockCacherRecorder -} - -// Recorder for MockCacher (not exported) -type _MockCacherRecorder struct { - mock *MockCacher -} - -func NewMockCacher(ctrl *gomock.Controller) *MockCacher { - mock := &MockCacher{ctrl: ctrl} - mock.recorder = &_MockCacherRecorder{mock} - return mock -} - -func (_m *MockCacher) EXPECT() *_MockCacherRecorder { - return _m.recorder -} - -func (_m *MockCacher) InitializeDatabase() error { - ret := _m.ctrl.Call(_m, "InitializeDatabase") - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockCacherRecorder) InitializeDatabase() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "InitializeDatabase") -} - -func (_m *MockCacher) ClearDatabase() error { - ret := _m.ctrl.Call(_m, "ClearDatabase") - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockCacherRecorder) ClearDatabase() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "ClearDatabase") -} - -func (_m *MockCacher) GetRecord(_param0 string) (string, error) { - ret := _m.ctrl.Call(_m, "GetRecord", _param0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -func (_mr *_MockCacherRecorder) GetRecord(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetRecord", arg0) -} - -func (_m *MockCacher) SetRecord(_param0 string, _param1 string) error { - ret := _m.ctrl.Call(_m, "SetRecord", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockCacherRecorder) SetRecord(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "SetRecord", arg0, arg1) -} - -func (_m *MockCacher) ReviseRecord(_param0 string, _param1 string) error { - ret := _m.ctrl.Call(_m, "ReviseRecord", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockCacherRecorder) ReviseRecord(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "ReviseRecord", arg0, arg1) -} - -func (_m *MockCacher) DeleteRecord(_param0 string) error { - ret := _m.ctrl.Call(_m, "DeleteRecord", _param0) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockCacherRecorder) DeleteRecord(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "DeleteRecord", arg0) -} - -func (_m *MockCacher) ListRecords() ([]string, error) { - ret := _m.ctrl.Call(_m, "ListRecords") - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -func (_mr *_MockCacherRecorder) ListRecords() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "ListRecords") -} diff --git a/caches/postgresql_cacher.go b/caches/postgresql_cacher.go deleted file mode 100644 index d1c3ab2..0000000 --- a/caches/postgresql_cacher.go +++ /dev/null @@ -1,158 +0,0 @@ -package caches - -// This stores entries in a Postgresql database. -// Postgresql doesn't handle expiring data automatically, this will -// need to verify data hasn't expired yet. - -// TODO: -// - add logging -// - test -// - add routine for removing old data - -import ( - "database/sql" - "time" - - _ "github.com/lib/pq" - - "github.com/nanopack/shaman/config" -) - -type postgresqlCacher struct { - expires int - db *sql.DB -} - -// Initializer for the postgres cacher -func NewPostgresCacher(connection string, expires int) (*postgresqlCacher, error) { - config.Log.Info("creating postgresql cacher") - db, err := sql.Open("postgres", connection) - if err != nil { - config.Log.Error("error: %s", err) - return nil, err - } - pc := postgresqlCacher{ - expires: expires, - db: db, - } - return &pc, nil -} - -func (self postgresqlCacher) InitializeDatabase() error { - rows, err := self.db.Query("CREATE TABLE IF NOT EXISTS dns_entries ( key varchar(128) UNIQUE, value text, expires bigint)") - if err != nil { - config.Log.Error("error: %s", err) - return err - } - defer rows.Close() - return nil -} - -func (self postgresqlCacher) ClearDatabase() error { - rows, err := self.db.Query("DELETE FROM dns_entries") - if err != nil { - config.Log.Error("error: %s", err) - return err - } - defer rows.Close() - return nil -} - -// Retrieve record and check to make sure it isn't expired, update expires if needed. -func (self postgresqlCacher) GetRecord(key string) (string, error) { - var value string - var expires int64 - var err error - var rows *sql.Rows - err = self.db.QueryRow("SELECT value, expires FROM dns_entries WHERE key = $1", key).Scan(&value, &expires) - if err != nil { - if err == sql.ErrNoRows { - return "", nil - } else { - config.Log.Error("error: %s, %s", err, value) - return value, err - } - } - - if self.expires > 0 { - now := time.Now().Unix() - if expires < now { - // expired - self.DeleteRecord(key) - return "", nil - } - newExpires := now + int64(self.expires) - rows, err = self.db.Query("UPDATE dns_entries SET expires=$1 WHERE key = $2", newExpires, key) - if err != nil { - config.Log.Error("error: %s", err) - return value, err - } - defer rows.Close() - } - return value, nil -} - -// Insert new record in the database, update expires if needed. -func (self postgresqlCacher) SetRecord(key string, value string) error { - now := time.Now().Unix() - expires := now + int64(self.expires) - rows, err := self.db.Query("INSERT INTO dns_entries (key, value, expires) VALUES ($1, $2, $3)", key, value, expires) - if err != nil { - config.Log.Error("error: %s", err) - return err - } - defer rows.Close() - return nil -} - -// Update existing record, update expires if needed. -func (self postgresqlCacher) ReviseRecord(key string, value string) error { - now := time.Now().Unix() - expires := now + int64(self.expires) - rows, err := self.db.Query("UPDATE dns_entries SET value=$1, expires=$2 WHERE key = $3", value, expires, key) - if err != nil { - config.Log.Error("error: %s", err) - return err - } - defer rows.Close() - return nil -} - -// Remove record from database. -func (self postgresqlCacher) DeleteRecord(key string) error { - rows, err := self.db.Query("DELETE FROM dns_entries WHERE key = $1", key) - if err != nil { - config.Log.Error("error: %s", err) - return err - } - defer rows.Close() - return nil -} - -func (self postgresqlCacher) ListRecords() ([]string, error) { - entries := make([]string, 0) - now := time.Now().Unix() - var value string - var expires int64 - rows, err := self.db.Query("SELECT value, expires FROM dns_entries") - if err != nil { - config.Log.Error("error: %s", err) - return entries, err - } - defer rows.Close() - for rows.Next() { - err := rows.Scan(&value, &expires) - if err != nil { - config.Log.Error("Error: %s", err) - } - if self.expires > 0 { - if expires > now { - entries = append(entries, value) - } - } else { - entries = append(entries, value) - } - - } - return entries, nil -} diff --git a/caches/postgresql_cacher_test.go b/caches/postgresql_cacher_test.go deleted file mode 100644 index 9d667e5..0000000 --- a/caches/postgresql_cacher_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package caches_test - -import ( - "testing" - "time" - - "github.com/nanopack/shaman/caches" -) - -func initializePostgresqlCacher(t *testing.T, expires int) caches.Cacher { - cacher, err := caches.NewPostgresCacher("postgres://postgres@localhost/travis_ci_test?sslmode=disable", expires) - cacher.InitializeDatabase() - cacher.ClearDatabase() - if err != nil { - t.Errorf("Error from initializePostgresqlCacher in PostgresqlCacher: %s", err) - } - return cacher -} - -func postgresqlSet(t *testing.T, postgresqlCacher caches.Cacher, key string, value string) { - err := postgresqlCacher.SetRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in PostgresqlCacher: %s", err) - } -} - -func postgresqlGet(t *testing.T, postgresqlCacher caches.Cacher, key string, checkValue string) { - value, err := postgresqlCacher.GetRecord(key) - if err != nil { - t.Errorf("Error from GetRecord in PostgresqlCacher: %s", err) - } - if value != checkValue { - t.Errorf("Unexpected result from PostgresqlCacher: %s", value) - } -} - -func postgresqlRevise(t *testing.T, postgresqlCacher caches.Cacher, key string, value string) { - err := postgresqlCacher.ReviseRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in PostgresqlCacher: %s", err) - } -} - -func postgresqlDelete(t *testing.T, postgresqlCacher caches.Cacher, key string) { - err := postgresqlCacher.DeleteRecord("1-key") - if err != nil { - t.Errorf("Error from DeleteRecord in PostgresqlCacher: %s", err) - } -} - -func postgresqlList(t *testing.T, postgresqlCacher caches.Cacher, key string, checkValues []string) { - values, err := postgresqlCacher.ListRecords() - if err != nil { - t.Errorf("Error from ListRecord in PostgresqlCacher: %s", err) - } - if len(values) != len(checkValues) { - t.Errorf("Unexpected length from ListRecord in PostgresqlCacher: %d", len(values)) - } - for value := range values { - found := false - for checkValue := range checkValues { - if checkValue == value { - found = true - break - } - } - if !found { - t.Errorf("Unexpected values from ListRecord in PostgresqlCacher: %s", values) - } - } -} - -func TestPostgresqlSet(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") -} - -func TestPostgresqlGet(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlGet(t, postgresqlCacher, "1-key", "") -} - -func TestPostgresqlGetAfterSet(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - time.Sleep(2 * time.Second) - postgresqlGet(t, postgresqlCacher, "1-key", "found") -} - -func TestPostgresqlGetAfterSetWithExpiresNoSleep(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 1) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlGet(t, postgresqlCacher, "1-key", "found") -} - -func TestPostgresqlGetAfterSetWithExpires(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 1) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - time.Sleep(2 * time.Second) - postgresqlGet(t, postgresqlCacher, "1-key", "") -} - -func TestPostgresqlRevise(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlRevise(t, postgresqlCacher, "1-key", "found") -} - -func TestPostgresqlReviseAfterSet(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlRevise(t, postgresqlCacher, "1-key", "found") -} - -func TestPostgresqlGetAfterReviseAfterSet(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlRevise(t, postgresqlCacher, "1-key", "found too") - postgresqlGet(t, postgresqlCacher, "1-key", "found too") -} - -func TestPostgresqlGetAfterReviseAfterSetWithExpires(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 1) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlRevise(t, postgresqlCacher, "1-key", "found too") - time.Sleep(2 * time.Second) - postgresqlGet(t, postgresqlCacher, "1-key", "") -} - -func TestPostgresqlDelete(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlDelete(t, postgresqlCacher, "1-key") - postgresqlGet(t, postgresqlCacher, "1-key", "") -} - -func TestPostgresqlDeleteToo(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlSet(t, postgresqlCacher, "2-key", "found too") - postgresqlDelete(t, postgresqlCacher, "1-key") - postgresqlGet(t, postgresqlCacher, "1-key", "") - postgresqlGet(t, postgresqlCacher, "2-key", "found too") -} - -func TestPostgresqlList(t *testing.T) { - postgresqlCacher := initializePostgresqlCacher(t, 0) - postgresqlSet(t, postgresqlCacher, "1-key", "found") - postgresqlSet(t, postgresqlCacher, "2-key", "found too") - postgresqlList(t, postgresqlCacher, "2-key", []string{"found", "found too"}) -} diff --git a/caches/redis_cacher.go b/caches/redis_cacher.go deleted file mode 100644 index 2728cd6..0000000 --- a/caches/redis_cacher.go +++ /dev/null @@ -1,147 +0,0 @@ -package caches - -// This stores entries in a Redis cache. -// Redis handles the expires, this only needs to refresh the expire. - -// TODO: -// - add logging -// - test - -import ( - "net/url" - "time" - - "github.com/garyburd/redigo/redis" -) - -type redisCacher struct { - expires int - connection *redis.Pool -} - -// Redis connection pool -func newPool(server, password string) *redis.Pool { - return &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - c, err := redis.Dial("tcp", server) - if err != nil { - return nil, err - } - if password != "" { - if _, err := c.Do("AUTH", password); err != nil { - c.Close() - return nil, err - } - } - return c, err - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } -} - -// Initialize a Redis cacher -func NewRedisCacher(connection string, expires int) (*redisCacher, error) { - u, err := url.Parse(connection) - var password string - user := u.User - if user != nil { - password, _ = user.Password() - } - if err != nil { - return nil, err - } - rc := redisCacher{ - expires: expires, - connection: newPool(u.Host, password), - } - return &rc, nil -} - -func (self redisCacher) InitializeDatabase() error { - return nil -} - -func (self redisCacher) ClearDatabase() error { - conn := self.connection.Get() - defer conn.Close() - _, err := conn.Do("FLUSHALL") - return err -} - -// Retrieve record from Redis, update its expire time -func (self redisCacher) GetRecord(key string) (string, error) { - conn := self.connection.Get() - defer conn.Close() - record, err := redis.String(conn.Do("GET", key)) - if err == redis.ErrNil { - return "", nil - } - if self.expires > 0 { - // refresh the expires - _, err = conn.Do("EXPIRE", key, self.expires) - } - return record, err -} - -// Insert record in the cache, reset the expire time -func (self redisCacher) SetRecord(key string, value string) error { - conn := self.connection.Get() - defer conn.Close() - _, err := conn.Do("SET", key, value) - if err != nil { - return err - } - if self.expires > 0 { - // set the expires - _, err = conn.Do("EXPIRE", key, self.expires) - } - return err -} - -// Update entry -func (self redisCacher) ReviseRecord(key string, value string) error { - return self.SetRecord(key, value) -} - -// Remove entry -func (self redisCacher) DeleteRecord(key string) error { - conn := self.connection.Get() - defer conn.Close() - _, err := conn.Do("DEL", key) - return err -} - -func (self redisCacher) ListRecords() ([]string, error) { - entries := make([]string, 0) - iter := 0 - conn := self.connection.Get() - defer conn.Close() - for { - arr, err := redis.MultiBulk(conn.Do("SCAN", iter)) - if err != nil { - return entries, err - } - iter, _ = redis.Int(arr[0], nil) - keys, _ := redis.Strings(arr[1], nil) - for key := range keys { - record, err := redis.String(conn.Do("GET", keys[key])) - if err != nil { - if err != redis.ErrNil { - return entries, err - } - } else { - entries = append(entries, record) - } - } - - if iter == 0 { - break - } - } - return entries, nil -} diff --git a/caches/redis_cacher_test.go b/caches/redis_cacher_test.go deleted file mode 100644 index fade00b..0000000 --- a/caches/redis_cacher_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package caches_test - -import ( - "testing" - "time" - - "github.com/nanopack/shaman/caches" -) - -func initializeRedisCacher(t *testing.T, expires int) caches.Cacher { - cacher, err := caches.NewRedisCacher("redis://localhost:6379", expires) - cacher.ClearDatabase() - if err != nil { - t.Errorf("Error from initializeRedisCacher in RedisCacher: %s", err) - } - return cacher -} - -func redisSet(t *testing.T, redisCacher caches.Cacher, key string, value string) { - err := redisCacher.SetRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in RedisCacher: %s", err) - } -} - -func redisGet(t *testing.T, redisCacher caches.Cacher, key string, checkValue string) { - value, err := redisCacher.GetRecord(key) - if err != nil { - t.Errorf("Error from GetRecord in RedisCacher: %s", err) - } - if value != checkValue { - t.Errorf("Unexpected result from RedisCacher: %s", value) - } -} - -func redisRevise(t *testing.T, redisCacher caches.Cacher, key string, value string) { - err := redisCacher.ReviseRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in RedisCacher: %s", err) - } -} - -func redisDelete(t *testing.T, redisCacher caches.Cacher, key string) { - err := redisCacher.DeleteRecord("1-key") - if err != nil { - t.Errorf("Error from DeleteRecord in RedisCacher: %s", err) - } -} - -func redisList(t *testing.T, redisCacher caches.Cacher, key string, checkValues []string) { - values, err := redisCacher.ListRecords() - if err != nil { - t.Errorf("Error from ListRecord in RedisCacher: %s", err) - } - if len(values) != len(checkValues) { - t.Errorf("Unexpected length from ListRecord in RedisCacher: %d", len(values)) - } - for value := range values { - found := false - for checkValue := range checkValues { - if checkValue == value { - found = true - break - } - } - if !found { - t.Errorf("Unexpected values from ListRecord in RedisCacher: %s", values) - } - } -} - -func TestRedisSet(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") -} - -func TestRedisGet(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisGet(t, redisCacher, "1-key", "") -} - -func TestRedisGetAfterSet(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") - redisGet(t, redisCacher, "1-key", "found") -} - -func TestRedisGetAfterSetWithExpiresNoSleep(t *testing.T) { - redisCacher := initializeRedisCacher(t, 1) - redisSet(t, redisCacher, "1-key", "found") - redisGet(t, redisCacher, "1-key", "found") -} - -func TestRedisGetAfterSetWithExpires(t *testing.T) { - redisCacher := initializeRedisCacher(t, 1) - redisSet(t, redisCacher, "1-key", "found") - time.Sleep(2 * time.Second) - redisGet(t, redisCacher, "1-key", "") -} - -func TestRedisRevise(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisRevise(t, redisCacher, "1-key", "found") -} - -func TestRedisReviseAfterSet(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") - redisRevise(t, redisCacher, "1-key", "found") -} - -func TestRedisGetAfterReviseAfterSet(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") - redisRevise(t, redisCacher, "1-key", "found too") - redisGet(t, redisCacher, "1-key", "found too") -} - -func TestRedisGetAfterReviseAfterSetWithExpires(t *testing.T) { - redisCacher := initializeRedisCacher(t, 1) - redisSet(t, redisCacher, "1-key", "found") - redisRevise(t, redisCacher, "1-key", "found too") - time.Sleep(2 * time.Second) - redisGet(t, redisCacher, "1-key", "") -} - -func TestRedisDelete(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") - redisDelete(t, redisCacher, "1-key") - redisGet(t, redisCacher, "1-key", "") -} - -func TestRedisDeleteToo(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") - redisSet(t, redisCacher, "2-key", "found too") - redisDelete(t, redisCacher, "1-key") - redisGet(t, redisCacher, "1-key", "") - redisGet(t, redisCacher, "2-key", "found too") -} - -func TestRedisList(t *testing.T) { - redisCacher := initializeRedisCacher(t, 0) - redisSet(t, redisCacher, "1-key", "found") - redisSet(t, redisCacher, "2-key", "found too") - redisList(t, redisCacher, "2-key", []string{"found", "found too"}) -} diff --git a/caches/scribble_cacher.go b/caches/scribble_cacher.go deleted file mode 100644 index 33530fe..0000000 --- a/caches/scribble_cacher.go +++ /dev/null @@ -1,116 +0,0 @@ -package caches - -// Caching implementation using scribble as the backend - -// TODO: -// - add logging -// - test - -import ( - "encoding/json" - "net/url" - "time" - - scribble "github.com/nanobox-io/golang-scribble" - - "github.com/nanopack/shaman/config" -) - -type scribbleCacher struct { - expires int - scribbleDb *scribble.Driver -} - -// Initialize a new scribble cacher -func NewScribbleCacher(connection string, expires int) (*scribbleCacher, error) { - u, err := url.Parse(connection) - if err != nil { - return nil, err - } - dir := u.Path - db, err := scribble.New(dir, nil) - if err != nil { - return nil, err - } - sC := scribbleCacher{ - expires: expires, - scribbleDb: db, - } - return &sC, nil -} - -func (self scribbleCacher) InitializeDatabase() error { - return nil -} - -func (self scribbleCacher) ClearDatabase() error { - self.scribbleDb.Delete("records", "") - return nil -} - -// Retrieve a record from the scribble database, update the expires if -func (self scribbleCacher) GetRecord(key string) (string, error) { - ce := CacheEntry{} - if err := self.scribbleDb.Read("records", key, &ce); err != nil { - config.Log.Error("Error: %s", err) - return "", nil - } - if self.expires > 0 { - now := time.Now().Unix() - if ce.Expires < now { - // expired - self.DeleteRecord(key) - return "", nil - } - newExpires := now + int64(self.expires) - ce.Expires = newExpires - if err := self.scribbleDb.Write("records", key, ce); err != nil { - return ce.Value, nil - } - } - return ce.Value, nil -} - -// Set record in scribble database -func (self scribbleCacher) SetRecord(key string, value string) error { - var expires int64 - if self.expires > 0 { - expires = time.Now().Unix() + int64(self.expires) - } - ce := CacheEntry{ - Expires: expires, - Value: value, - } - return self.scribbleDb.Write("records", key, ce) -} - -// Update record in scribble database -func (self scribbleCacher) ReviseRecord(key string, value string) error { - return self.SetRecord(key, value) -} - -// Remove record from scribble database -func (self scribbleCacher) DeleteRecord(key string) error { - return self.scribbleDb.Delete("records", key) -} - -func (self scribbleCacher) ListRecords() ([]string, error) { - entries := make([]string, 0) - now := time.Now().Unix() - values, err := self.scribbleDb.ReadAll("records") - if err != nil { - return entries, err - } - for i := range values { - var ce CacheEntry - json.Unmarshal([]byte(values[i]), &ce) - if self.expires != 0 { - if ce.Expires > now { - entries = append(entries, ce.Value) - } - } else { - entries = append(entries, ce.Value) - } - } - return entries, nil -} diff --git a/caches/scribble_cacher_test.go b/caches/scribble_cacher_test.go deleted file mode 100644 index c3704e0..0000000 --- a/caches/scribble_cacher_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package caches_test - -import ( - "testing" - "time" - - "github.com/nanopack/shaman/caches" -) - -func initializeScribbleCacher(t *testing.T, expires int) caches.Cacher { - cacher, err := caches.NewScribbleCacher("scribble://localhost/tmp/shaman-test", expires) - cacher.ClearDatabase() - if err != nil { - t.Errorf("Error from initializeScribbleCacher in ScribbleCacher: %s", err) - } - return cacher -} - -func scribbleSet(t *testing.T, scribbleCacher caches.Cacher, key string, value string) { - err := scribbleCacher.SetRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in ScribbleCacher: %s", err) - } -} - -func scribbleGet(t *testing.T, scribbleCacher caches.Cacher, key string, checkValue string) { - value, err := scribbleCacher.GetRecord(key) - if err != nil { - t.Errorf("Error from GetRecord in ScribbleCacher: %s", err) - } - if value != checkValue { - t.Errorf("Unexpected result from ScribbleCacher: %s", value) - } -} - -func scribbleRevise(t *testing.T, scribbleCacher caches.Cacher, key string, value string) { - err := scribbleCacher.ReviseRecord(key, value) - if err != nil { - t.Errorf("Error from SetRecord in ScribbleCacher: %s", err) - } -} - -func scribbleDelete(t *testing.T, scribbleCacher caches.Cacher, key string) { - err := scribbleCacher.DeleteRecord("1-key") - if err != nil { - t.Errorf("Error from DeleteRecord in ScribbleCacher: %s", err) - } -} - -func scribbleList(t *testing.T, scribbleCacher caches.Cacher, key string, checkValues []string) { - values, err := scribbleCacher.ListRecords() - if err != nil { - t.Errorf("Error from ListRecord in ScribbleCacher: %s", err) - } - if len(values) != len(checkValues) { - t.Errorf("Unexpected length from ListRecord in ScribbleCacher: %d", len(values)) - } - for value := range values { - found := false - for checkValue := range checkValues { - if checkValue == value { - found = true - break - } - } - if !found { - t.Errorf("Unexpected values from ListRecord in ScribbleCacher: %s", values) - } - } -} - -func TestScribbleSet(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") -} - -func TestScribbleGet(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleGet(t, scribbleCacher, "1-key", "") -} - -func TestScribbleGetAfterSet(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") - time.Sleep(2 * time.Second) - scribbleGet(t, scribbleCacher, "1-key", "found") -} - -func TestScribbleGetAfterSetWithExpiresNoSleep(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 1) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleGet(t, scribbleCacher, "1-key", "found") -} - -func TestScribbleGetAfterSetWithExpires(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 1) - scribbleSet(t, scribbleCacher, "1-key", "found") - time.Sleep(2 * time.Second) - scribbleGet(t, scribbleCacher, "1-key", "") -} - -func TestScribbleRevise(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleRevise(t, scribbleCacher, "1-key", "found") -} - -func TestScribbleReviseAfterSet(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleRevise(t, scribbleCacher, "1-key", "found") -} - -func TestScribbleGetAfterReviseAfterSet(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleRevise(t, scribbleCacher, "1-key", "found too") - scribbleGet(t, scribbleCacher, "1-key", "found too") -} - -func TestScribbleGetAfterReviseAfterSetWithExpires(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 1) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleRevise(t, scribbleCacher, "1-key", "found too") - time.Sleep(2 * time.Second) - scribbleGet(t, scribbleCacher, "1-key", "") -} - -func TestScribbleDelete(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleDelete(t, scribbleCacher, "1-key") - scribbleGet(t, scribbleCacher, "1-key", "") -} - -func TestScribbleDeleteToo(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleSet(t, scribbleCacher, "2-key", "found too") - scribbleDelete(t, scribbleCacher, "1-key") - scribbleGet(t, scribbleCacher, "1-key", "") - scribbleGet(t, scribbleCacher, "2-key", "found too") -} - -func TestScribbleList(t *testing.T) { - scribbleCacher := initializeScribbleCacher(t, 0) - scribbleSet(t, scribbleCacher, "1-key", "found") - scribbleSet(t, scribbleCacher, "2-key", "found too") - scribbleList(t, scribbleCacher, "2-key", []string{"found", "found too"}) -} diff --git a/commands/README.md b/commands/README.md new file mode 100644 index 0000000..977921b --- /dev/null +++ b/commands/README.md @@ -0,0 +1,121 @@ +[![shaman logo](http://nano-assets.gopagoda.io/readme-headers/shaman.png)](http://nanobox.io/open-source#shaman) +[![Build Status](https://travis-ci.org/nanopack/shaman.svg)](https://travis-ci.org/nanopack/shaman) + +# Shaman + +Small, lightweight, api-driven dns server. + +## CLI Commands: + +``` +shaman - api driven dns server + +Usage: + shaman [flags] + shaman [command] + +Available Commands: + add Add a domain to shaman + delete Remove a domain from shaman + list List all domains in shaman + get Get records for a domain + update Update records for a domain + reset Reset all domains in shaman + +Flags: + -C, --api-crt string Path to SSL crt for API access + -k, --api-key string Path to SSL key for API access + -p, --api-key-password string Password for SSL key + -H, --api-listen string Listen address for the API (ip:port) (default "127.0.0.1:1632") + -c, --config-file string Configuration file to load + -O, --dns-listen string Listen address for DNS requests (ip:port) (default "127.0.0.1:53") + -d, --domain string Parent domain for requests (default ".") + -i, --insecure Disable tls key checking (client) and listen on http (api) + -2, --l2-connect string Connection string for the l2 cache (default "scribble:///var/db/shaman") + -l, --log-level string Log level to output [fatal|error|info|debug|trace] (default "INFO") + -s, --server Run in server mode + -t, --token string Token for API Access (default "secret") + -T, --ttl int Default TTL for DNS records (default 60) + -v, --version Print version info and exit + +Use "shaman [command] --help" for more information about a command. +``` + +## Server Usage Example: +``` +$ shaman --server +``` +or +``` +$ shaman -c config.json +``` + +>config.json +>```json +{ + "api-crt": "", + "api-key": "", + "api-key-password": "", + "api-listen": "127.0.0.1:1632", + "token": "secret", + "insecure": false, + "l2-connect": "scribble:///var/db/shaman", + "ttl": 60, + "domain": ".", + "dns-listen": "127.0.0.1:53", + "log-level": "info", + "server": true +} +``` + +## Client Usage Example: + +#### add records + +```sh +$ shaman -i add -d nanopack.io -A 127.0.0.1 +# {"domain":"nanopack.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.1"}]} + +$ shaman -i add -j '{"domain":"nanopack.io","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]}' +# {"domain":"nanopack.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"},{"ttl":60,"class":"IN","type":"A","address":"127.0.0.1"}]} +``` + +#### delete record + +```sh +$ shaman -i delete -d nanobox.io +# {"msg":"success"} +``` + +#### update record + +```sh +$ shaman -i update -d nanopack.io -A 127.0.0.2 +# {"domain":"nanopack.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]} +``` + +#### get record + +```sh +$ shaman -i get -d nanopack.io +# {"domain":"nanopack.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.2"}]} +``` + +#### reset records + +```sh +$ shaman -i reset -j '[{"domain":"nanobox.io", "records":[{"address":"127.0.0.5"}]}]' +# [{"domain":"nanobox.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.5"}]}] +``` + +#### list records + +```sh +$ shaman -i list +# ["nanobox.io"] + +$ shaman -i list -f +# [{"domain":"nanobox.io.","records":[{"ttl":60,"class":"IN","type":"A","address":"127.0.0.5"}]}] +``` + +[![oss logo](http://nano-assets.gopagoda.io/open-src/nanobox-open-src.png)](http://nanobox.io/open-source) diff --git a/commands/add.go b/commands/add.go index 3c32a29..772bb18 100644 --- a/commands/add.go +++ b/commands/add.go @@ -2,69 +2,56 @@ package commands import ( "bytes" - "crypto/tls" + "encoding/json" "fmt" "io/ioutil" - "net/http" - "net/url" - "os" "github.com/spf13/cobra" - - "github.com/nanopack/shaman/config" ) -var addCmd = &cobra.Command{ - Use: "add", - Short: "Add entry into shaman database", - Long: ``, - - Run: add, -} - -type addBody struct { - value string -} +var ( + // AddDomain adds a domain to shaman + AddDomain = &cobra.Command{ + Use: "add", + Short: "Add a domain to shaman", + Long: ``, -func add(ccmd *cobra.Command, args []string) { - if len(args) != 3 { - fmt.Fprintln(os.Stderr, "Missing arguments: Needs record type, domain, and value") - os.Exit(1) + Run: addRecord, } - var client *http.Client - if config.Insecure { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +) + +func addRecord(ccmd *cobra.Command, args []string) { + if jsonString != "" { + err := json.Unmarshal([]byte(jsonString), &resource) + if err != nil { + fail("Bad JSON syntax") } - client = &http.Client{Transport: tr} } else { - client = http.DefaultClient + if record.Address == "" { + // warn if record.Address is empty - doesn't apply to jsonString + fail("Missing address for record. Try adding `-A`") + } + resource.Records = append(resource.Records, record) + } + + if resource.Domain == "" { + fail("Domain must be specified. Try adding `-d`.") } - rtype := args[0] - domain := args[1] - value := args[2] - fmt.Println("rtype:", rtype, "domain:", domain, "value:", value) - data := url.Values{} - data.Set("value", value) - uri := fmt.Sprintf("https://%s:%s/records/%s/%s?%s", config.ApiHost, config.ApiPort, rtype, domain, data.Encode()) - fmt.Println(uri) - req, err := http.NewRequest("POST", uri, bytes.NewBufferString(data.Encode())) + jsonBytes, err := json.Marshal(resource) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Bad values for resource") } - req.Header.Add("X-NANOBOX-TOKEN", config.ApiToken) - res, err := client.Do(req) + + res, err := rest("POST", "/records", bytes.NewBuffer(jsonBytes)) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Could not contact shaman - %v", err) } b, err := ioutil.ReadAll(res.Body) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Could not read shaman's response - %v", err) } - fmt.Println(string(b)) + + fmt.Print(string(b)) } diff --git a/commands/commands.go b/commands/commands.go index 6ef82b4..43975bf 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -1,106 +1,93 @@ +// Package "commands" provides the cli functionality. +// Runnable commands are: +// add +// get +// update +// delete +// list +// reset package commands import ( - "github.com/jcelliott/lumber" + "crypto/tls" + "fmt" + "io" + "net/http" + "os" + "github.com/spf13/cobra" - "github.com/nanopack/shaman/api" - "github.com/nanopack/shaman/caches" "github.com/nanopack/shaman/config" - "github.com/nanopack/shaman/server" + shaman "github.com/nanopack/shaman/core/common" ) -var ( - runServer bool - Shaman = &cobra.Command{ - Use: "", - Short: "", - Long: ``, +func rest(method string, path string, body io.Reader) (*http.Response, error) { + uri := fmt.Sprintf("https://%s%s", config.ApiListen, path) - Run: func(ccmd *cobra.Command, args []string) { - if runServer { - startServer() - return - } - // Show the help if not starting the server - ccmd.HelpFunc()(ccmd, args) - }, + if config.Insecure { + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } -) -func init() { - // Shaman.PersistentFlags().StringVarP(&config.AuthToken, "auth", "A", "", "Shaman auth token") - // Shaman.PersistentFlags().StringVarP(&config.Host, "host", "H", "127.0.0.1", "Shaman hostname/IP") - // Shaman.PersistentFlags().IntVarP(&config.Port, "port", "p", 8443, "Shaman admin port") - Shaman.PersistentFlags().BoolVarP(&config.Insecure, "insecure", "i", false, "Disable tls key checking") + req, err := http.NewRequest(method, uri, body) + if err != nil { + panic(err) + } + req.Header.Add("X-AUTH-TOKEN", config.ApiToken) + res, err := http.DefaultClient.Do(req) + if err != nil { + // if requesting `https://` failed, server may have been started with `-i`, try `http://` + uri = fmt.Sprintf("http://%s%s", config.ApiListen, path) + req, er := http.NewRequest(method, uri, body) + if er != nil { + panic(er) + } + req.Header.Add("X-AUTH-TOKEN", config.ApiToken) + var err2 error + res, err2 = http.DefaultClient.Do(req) + if err2 != nil { + // return original error to client + return nil, err + } + } + if res.StatusCode == 401 { + return nil, fmt.Errorf("401 Unauthorized. Please specify api token (-t 'token')") + } + return res, nil +} - Shaman.Flags().BoolVarP(&runServer, "server", "s", false, "Run in server mode") +func fail(format string, args ...interface{}) { + fmt.Printf(fmt.Sprintf("%v\n", format), args...) + os.Exit(1) +} - Shaman.Flags().StringVarP(&config.L1Connect, "l1-connect", "1", "map://127.0.0.1/", - "Connection string for the l1 cache") - Shaman.Flags().IntVarP(&config.L1Expires, "l1-expires", "e", 120, - "TTL for the L1 Cache (0 = never expire)") - Shaman.Flags().StringVarP(&config.L2Connect, "l2-connect", "2", "map://127.0.0.1/", - "Connection string for the l2 cache") - Shaman.Flags().IntVarP(&config.L2Expires, "l2-expires", "E", 0, - "TTL for the L2 Cache (0 = never expire)") - Shaman.Flags().IntVarP(&config.TTL, "ttl", "T", 60, - "Default TTL for DNS records") - Shaman.Flags().StringVarP(&config.Domain, "domain", "d", ".", - "Parent domain for requests") - Shaman.Flags().StringVarP(&config.Host, "host", "O", "127.0.0.1", - "Listen address for DNS requests") - Shaman.Flags().StringVarP(&config.Port, "port", "o", "8053", - "Listen port for DNS requests") - Shaman.Flags().StringVarP(&config.ApiKey, "api-key", "k", "", - "Path to SSL key for API access") - Shaman.Flags().StringVarP(&config.ApiKeyPassword, "api-key-password", "p", "", - "Password for SSL key") - Shaman.Flags().StringVarP(&config.ApiCrt, "api-crt", "c", "", - "Path to SSL crt for API access") - Shaman.PersistentFlags().StringVarP(&config.ApiToken, "api-token", "t", "", - "Token for API Access") - Shaman.PersistentFlags().StringVarP(&config.ApiHost, "api-host", "H", "127.0.0.1", - "Listen address for the API") - Shaman.PersistentFlags().StringVarP(&config.ApiPort, "api-port", "P", "8443", - "Listen address for the API") - Shaman.Flags().StringVarP(&config.LogLevel, "log-level", "L", "INFO", - "Log level to use") - Shaman.Flags().StringVarP(&config.LogFile, "log-file", "l", "", - "Log file (blank = log to console)") +func init() { + domainFlags(AddDomain) + DelDomain.Flags().StringVarP(&resource.Domain, "domain", "d", "", "Domain to remove") + GetDomain.Flags().StringVarP(&resource.Domain, "domain", "d", "", "Domain to get") + ListDomains.Flags().BoolVarP(&full, "full", "f", false, "Show complete records") + ResetDomains.Flags().StringVarP(&jsonString, "json", "j", "", "JSON encoded data for domain[s] and record[s]") + domainFlags(UpdateDomain) +} + +var ( + resource shaman.Resource + record shaman.Record + jsonString string + full bool +) - Shaman.AddCommand(addCmd) - Shaman.AddCommand(removeCmd) - Shaman.AddCommand(showCmd) - Shaman.AddCommand(updateCmd) - Shaman.AddCommand(listCmd) +func ResetVars() { + resource = shaman.Resource{} + record = shaman.Record{} + jsonString = "" + full = false } -func startServer() { - if config.LogFile == "" { - config.Log = lumber.NewConsoleLogger(lumber.LvlInt(config.LogLevel)) - } else { - var err error - config.Log, err = lumber.NewFileLogger(config.LogFile, lumber.LvlInt(config.LogLevel), lumber.ROTATE, 5000, 9, 100) - if err != nil { - panic(err) - } - } - // make channel for errors - errors := make(chan error) - // Start cache engine, api server, and dns server - caches.InitCache() - go func() { - errors <- caches.StartCache() - }() - go func() { - errors <- api.StartApi() - }() - go func() { - errors <- server.StartServer() - }() - // break if any of them return an error - if err := <-errors; err != nil { - panic(err) - } +func domainFlags(ccmd *cobra.Command) { + ccmd.Flags().StringVarP(&resource.Domain, "domain", "d", "", "Domain") + ccmd.Flags().IntVarP(&record.TTL, "ttl", "T", 60, "Record time to live") + ccmd.Flags().StringVarP(&record.Class, "class", "C", "IN", "Record class") + ccmd.Flags().StringVarP(&record.RType, "type", "R", "A", "Record type (A, CNAME, MX, etc...)") + ccmd.Flags().StringVarP(&record.Address, "address", "A", "", "Record address") + ccmd.Flags().StringVarP(&jsonString, "json", "j", "", "JSON encoded data for domain[s] and record[s]") } diff --git a/commands/commands_test.go b/commands/commands_test.go new file mode 100644 index 0000000..cf29594 --- /dev/null +++ b/commands/commands_test.go @@ -0,0 +1,216 @@ +package commands_test + +import ( + "io/ioutil" + "os" + "strings" + "testing" + "time" + + "github.com/jcelliott/lumber" + "github.com/spf13/cobra" + + "github.com/nanopack/shaman/api" + "github.com/nanopack/shaman/commands" + "github.com/nanopack/shaman/config" +) + +func init() { + shamanTool.AddCommand(commands.AddDomain) + shamanTool.AddCommand(commands.DelDomain) + shamanTool.AddCommand(commands.ListDomains) + shamanTool.AddCommand(commands.GetDomain) + shamanTool.AddCommand(commands.UpdateDomain) + shamanTool.AddCommand(commands.ResetDomains) + + config.AddFlags(shamanTool) +} + +type ( + execable func() error // cobra.Command.Execute() 'alias' +) + +var shamanTool = &cobra.Command{ + Use: "shaman", + Short: "shaman - api driven dns server", + Long: ``, + + Run: startShaman, +} + +func startShaman(ccmd *cobra.Command, args []string) { + return +} + +func TestMain(m *testing.M) { + // manually configure + initialize() + + // start api + go api.Start() + <-time.After(1 * time.Second) + rtn := m.Run() + + os.Exit(rtn) +} + +func TestAddRecord(t *testing.T) { + commands.ResetVars() + + args := strings.Split("add -d nanobox.io -A 127.0.0.1", " ") + shamanTool.SetArgs(args) + + out, err := capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "{\"domain\":\"nanobox.io.\",\"records\":[{\"ttl\":60,\"class\":\"IN\",\"type\":\"A\",\"address\":\"127.0.0.1\"}]}\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } +} + +func TestListRecords(t *testing.T) { + commands.ResetVars() + + args := strings.Split("list", " ") + shamanTool.SetArgs(args) + + out, err := capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "[\"nanobox.io\"]\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } + + args = strings.Split("list -f", " ") + shamanTool.SetArgs(args) + + out, err = capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "[{\"domain\":\"nanobox.io.\",\"records\":[{\"ttl\":60,\"class\":\"IN\",\"type\":\"A\",\"address\":\"127.0.0.1\"}]}]\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } +} + +func TestResetRecords(t *testing.T) { + commands.ResetVars() + + args := strings.Split("reset -j [{\"domain\":\"nanopack.io\"}]", " ") + shamanTool.SetArgs(args) + + out, err := capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "[{\"domain\":\"nanopack.io.\",\"records\":null}]\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } + + args = strings.Split("list", " ") + shamanTool.SetArgs(args) + + out, err = capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "[\"nanopack.io\"]\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } +} + +func TestUpdateRecord(t *testing.T) { + commands.ResetVars() + + args := strings.Split("update -d nanopack.io -A 127.0.0.5", " ") + shamanTool.SetArgs(args) + + out, err := capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "{\"domain\":\"nanopack.io.\",\"records\":[{\"ttl\":60,\"class\":\"IN\",\"type\":\"A\",\"address\":\"127.0.0.5\"}]}\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } + + args = strings.Split("list", " ") + shamanTool.SetArgs(args) + + out, err = capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "[\"nanopack.io\"]\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } +} + +func TestGetRecord(t *testing.T) { + commands.ResetVars() + + args := strings.Split("get -d nanopack.io", " ") + shamanTool.SetArgs(args) + + out, err := capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "{\"domain\":\"nanopack.io.\",\"records\":[{\"ttl\":60,\"class\":\"IN\",\"type\":\"A\",\"address\":\"127.0.0.5\"}]}\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } +} + +func TestDeleteRecord(t *testing.T) { + commands.ResetVars() + + args := strings.Split("delete -d nanopack.io", " ") + shamanTool.SetArgs(args) + + out, err := capture(shamanTool.Execute) + if err != nil { + t.Errorf("Failed to execute - %v", err.Error()) + } + + if string(out) != "{\"msg\":\"success\"}\n" { + t.Errorf("Unexpected output: %+q", string(out)) + } +} + +/////////////////////////////////////////////////// +// PRIVS +/////////////////////////////////////////////////// + +// function to capture output of cli +func capture(fn execable) ([]byte, error) { + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + err := fn() + os.Stdout = oldStdout + w.Close() // do not defer after os.Pipe() + if err != nil { + return nil, err + } + + return ioutil.ReadAll(r) +} + +// manually configure and start internals +func initialize() { + config.Insecure = true + config.L2Connect = "none://" + config.ApiListen = "127.0.0.1:1634" + config.Log = lumber.NewConsoleLogger(lumber.LvlInt("FATAL")) + config.LogLevel = "FATAL" +} diff --git a/commands/delete.go b/commands/delete.go new file mode 100644 index 0000000..463438d --- /dev/null +++ b/commands/delete.go @@ -0,0 +1,38 @@ +package commands + +import ( + "fmt" + "io/ioutil" + + "github.com/spf13/cobra" +) + +var ( + // DelDomain removes a domain from shaman + DelDomain = &cobra.Command{ + Use: "delete", + Short: "Remove a domain from shaman", + Long: ``, + + Run: delRecord, + } +) + +func delRecord(ccmd *cobra.Command, args []string) { + if resource.Domain == "" { + fail("Domain must be specified. Try adding `-d`.") + } + + res, err := rest("DELETE", fmt.Sprintf("/records/%v", resource.Domain), nil) + if err != nil { + fail("Could not contact shaman - %v", err) + } + + // parse response + b, err := ioutil.ReadAll(res.Body) + if err != nil { + fail("Could not read shaman's response - %v", err) + } + + fmt.Print(string(b)) +} diff --git a/commands/get.go b/commands/get.go new file mode 100644 index 0000000..fd0997d --- /dev/null +++ b/commands/get.go @@ -0,0 +1,37 @@ +package commands + +import ( + "fmt" + "io/ioutil" + + "github.com/spf13/cobra" +) + +var ( + // GetDomain gets records for a domain + GetDomain = &cobra.Command{ + Use: "get", + Short: "Get records for a domain", + Long: ``, + + Run: getResource, + } +) + +func getResource(ccmd *cobra.Command, args []string) { + if resource.Domain == "" { + fail("Domain must be specified. Try adding `-d`.") + } + + res, err := rest("GET", fmt.Sprintf("/records/%v", resource.Domain), nil) + if err != nil { + fail("Could not contact shaman - %v", err) + } + + b, err := ioutil.ReadAll(res.Body) + if err != nil { + fail("Could not read shaman's response - %v", err) + } + + fmt.Print(string(b)) +} diff --git a/commands/list.go b/commands/list.go index 593a46b..f000b41 100644 --- a/commands/list.go +++ b/commands/list.go @@ -1,54 +1,38 @@ package commands import ( - "crypto/tls" "fmt" "io/ioutil" - "net/http" - "os" "github.com/spf13/cobra" - - "github.com/nanopack/shaman/config" ) -var listCmd = &cobra.Command{ - Use: "list", - Short: "List entries in shaman database", - Long: ``, - - Run: list, -} +var ( + // ListDomains lists all domains in shaman + ListDomains = &cobra.Command{ + Use: "list", + Short: "List all domains in shaman", + Long: ``, -func list(ccmd *cobra.Command, args []string) { - var client *http.Client - if config.Insecure { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client = &http.Client{Transport: tr} - } else { - client = http.DefaultClient + Run: listRecords, } +) - uri := fmt.Sprintf("https://%s:%s/records", config.ApiHost, config.ApiPort) - fmt.Println(uri) - req, err := http.NewRequest("GET", uri, nil) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) +func listRecords(ccmd *cobra.Command, args []string) { + var query string + if full { + query = "?full=true" } - req.Header.Add("X-NANOBOX-TOKEN", config.ApiToken) - res, err := client.Do(req) + + res, err := rest("GET", fmt.Sprintf("/records%v", query), nil) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Could not contact shaman - %v", err) } b, err := ioutil.ReadAll(res.Body) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Could not read shaman's response - %v", err) } - fmt.Println(string(b)) + + fmt.Print(string(b)) } diff --git a/commands/remove.go b/commands/remove.go deleted file mode 100644 index a3e6599..0000000 --- a/commands/remove.go +++ /dev/null @@ -1,61 +0,0 @@ -package commands - -import ( - "crypto/tls" - "fmt" - "io/ioutil" - "net/http" - "os" - - "github.com/spf13/cobra" - - "github.com/nanopack/shaman/config" -) - -var removeCmd = &cobra.Command{ - Use: "remove", - Short: "Remove entry from shaman database", - Long: ``, - - Run: remove, -} - -func remove(ccmd *cobra.Command, args []string) { - if len(args) != 2 { - fmt.Fprintln(os.Stderr, "Missing arguments: Needs record type, domain") - os.Exit(1) - } - var client *http.Client - if config.Insecure { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client = &http.Client{Transport: tr} - } else { - client = http.DefaultClient - } - rtype := args[0] - domain := args[1] - fmt.Println("rtype:", rtype, "domain:", domain) - - uri := fmt.Sprintf("https://%s:%s/records/%s/%s", config.ApiHost, config.ApiPort, rtype, domain) - fmt.Println(uri) - req, err := http.NewRequest("DELETE", uri, nil) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) - } - req.Header.Add("X-NANOBOX-TOKEN", config.ApiToken) - res, err := client.Do(req) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) - } - - b, err := ioutil.ReadAll(res.Body) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) - } - fmt.Println(string(b)) -} diff --git a/commands/reset.go b/commands/reset.go new file mode 100644 index 0000000..15b2144 --- /dev/null +++ b/commands/reset.go @@ -0,0 +1,55 @@ +package commands + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + + "github.com/spf13/cobra" + + shaman "github.com/nanopack/shaman/core/common" +) + +var ( + // ResetDomains resets all domains in shaman + ResetDomains = &cobra.Command{ + Use: "reset", + Short: "Reset all domains in shaman", + Long: ``, + + Run: resetRecords, + } +) + +func resetRecords(ccmd *cobra.Command, args []string) { + if jsonString == "" { + fail("Must pass json string. Try adding `-j`.") + } + + resources := make([]shaman.Resource, 0) + + err := json.Unmarshal([]byte(jsonString), &resources) + if err != nil { + fail("Bad JSON syntax") + } + + // validate valid values + jsonBytes, err := json.Marshal(resources) + if err != nil { + fail("Bad values for resource") + } + + res, err := rest("PUT", "/records", bytes.NewBuffer(jsonBytes)) + if err != nil { + fail("Could not contact shaman - %v", err) + } + + // parse response + b, err := ioutil.ReadAll(res.Body) + if err != nil { + fail("Could not read shaman's response - %v", err) + } + + fmt.Print(string(b)) +} diff --git a/commands/show.go b/commands/show.go deleted file mode 100644 index 00cf35e..0000000 --- a/commands/show.go +++ /dev/null @@ -1,61 +0,0 @@ -package commands - -import ( - "crypto/tls" - "fmt" - "io/ioutil" - "net/http" - "os" - - "github.com/spf13/cobra" - - "github.com/nanopack/shaman/config" -) - -var showCmd = &cobra.Command{ - Use: "show", - Short: "Show entry in shaman database", - Long: ``, - - Run: show, -} - -func show(ccmd *cobra.Command, args []string) { - if len(args) != 2 { - fmt.Fprintln(os.Stderr, "Missing arguments: Needs record type and domain") - os.Exit(1) - } - var client *http.Client - if config.Insecure { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client = &http.Client{Transport: tr} - } else { - client = http.DefaultClient - } - rtype := args[0] - domain := args[1] - fmt.Println("rtype:", rtype, "domain:", domain) - - uri := fmt.Sprintf("https://%s:%s/records/%s/%s", config.ApiHost, config.ApiPort, rtype, domain) - fmt.Println(uri) - req, err := http.NewRequest("GET", uri, nil) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) - } - req.Header.Add("X-NANOBOX-TOKEN", config.ApiToken) - res, err := client.Do(req) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) - } - - b, err := ioutil.ReadAll(res.Body) - if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) - } - fmt.Println(string(b)) -} diff --git a/commands/update.go b/commands/update.go index 355a518..fe758b9 100644 --- a/commands/update.go +++ b/commands/update.go @@ -2,69 +2,54 @@ package commands import ( "bytes" - "crypto/tls" + "encoding/json" "fmt" "io/ioutil" - "net/http" - "net/url" - "os" "github.com/spf13/cobra" - - "github.com/nanopack/shaman/config" ) -var updateCmd = &cobra.Command{ - Use: "update", - Short: "Update entry in shaman database", - Long: ``, - - Run: update, -} - -type updateBody struct { - value string -} +var ( + // UpdateDomain updates records for a domain + UpdateDomain = &cobra.Command{ + Use: "update", + Short: "Update records for a domain", + Long: ``, -func update(ccmd *cobra.Command, args []string) { - if len(args) != 3 { - fmt.Fprintln(os.Stderr, "Missing arguments: Needs record type, domain, and value") - os.Exit(1) + Run: updateRecord, } - var client *http.Client - if config.Insecure { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +) + +func updateRecord(ccmd *cobra.Command, args []string) { + if jsonString != "" { + err := json.Unmarshal([]byte(jsonString), &resource) + if err != nil { + fail("Bad JSON syntax") } - client = &http.Client{Transport: tr} - } else { - client = http.DefaultClient } - rtype := args[0] - domain := args[1] - value := args[2] - fmt.Println("rtype:", rtype, "domain:", domain, "value:", value) - data := url.Values{} - data.Set("value", value) - uri := fmt.Sprintf("https://%s:%s/records/%s/%s?%s", config.ApiHost, config.ApiPort, rtype, domain, data.Encode()) - fmt.Println(uri) - req, err := http.NewRequest("PUT", uri, bytes.NewBufferString(data.Encode())) + if resource.Domain == "" { + fail("Domain must be specified. Try adding `-d`.") + } + + resource.Records = append(resource.Records, record) + + // validate valid values + jsonBytes, err := json.Marshal(resource) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Bad values for resource") } - req.Header.Add("X-NANOBOX-TOKEN", config.ApiToken) - res, err := client.Do(req) + + res, err := rest("PUT", fmt.Sprintf("/records/%v", resource.Domain), bytes.NewBuffer(jsonBytes)) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Could not contact shaman - %v", err) } + // parse response b, err := ioutil.ReadAll(res.Body) if err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) - os.Exit(1) + fail("Could not read shaman's response - %v", err) } - fmt.Println(string(b)) + + fmt.Print(string(b)) } diff --git a/config/config.go b/config/config.go index ead5718..8499b28 100644 --- a/config/config.go +++ b/config/config.go @@ -1,31 +1,102 @@ +// Package "config" is a central location for configuration options. It also contains +// config file parsing logic. package config -// TODO: -// - read command line arguments for options -// - read config file for options -// - test - import ( + "fmt" + "path/filepath" + "github.com/jcelliott/lumber" + "github.com/spf13/cobra" + "github.com/spf13/viper" ) var ( - Insecure bool - L1Connect string - L2Connect string - L1Expires int - L2Expires int - Domain string - TTL int - Host string - Port string - ApiKey string - ApiKeyPassword string - ApiCrt string - ApiToken string - ApiHost string - ApiPort string - LogLevel string - LogFile string - Log lumber.Logger + ApiCrt = "" // Path to SSL crt for API access + ApiKey = "" // Path to SSL key for API access + ApiKeyPassword = "" // Password for SSL key + ApiListen = "127.0.0.1:1632" // Listen address for the API (ip:port) + ApiToken = "secret" // Token for API Access + Insecure = false // Disable tls key checking (client) and listen on http (server) + L2Connect = "scribble:///var/db/shaman" // Connection string for the l2 cache + TTL int = 60 // Default TTL for DNS records + Domain = "." // Parent domain for requests + DnsListen = "127.0.0.1:53" // Listen address for DNS requests (ip:port) + + LogLevel = "INFO" // Log level to output [fatal|error|info|debug|trace] + Server = false // Run in server mode + ConfigFile = "" // Configuration file to load + Version = false // Print version info and exit + + Log lumber.Logger // Central logger for shaman ) + +// AddFlags adds the available cli flags +func AddFlags(cmd *cobra.Command) { + // api + cmd.Flags().StringVarP(&ApiCrt, "api-crt", "C", ApiCrt, "Path to SSL crt for API access") + cmd.Flags().StringVarP(&ApiKey, "api-key", "k", ApiKey, "Path to SSL key for API access") + cmd.Flags().StringVarP(&ApiKeyPassword, "api-key-password", "p", ApiKeyPassword, "Password for SSL key") + cmd.PersistentFlags().StringVarP(&ApiListen, "api-listen", "H", ApiListen, "Listen address for the API (ip:port)") + cmd.PersistentFlags().StringVarP(&ApiToken, "token", "t", ApiToken, "Token for API Access") + cmd.PersistentFlags().BoolVarP(&Insecure, "insecure", "i", Insecure, "Disable tls key checking (client) and listen on http (api)") + + // dns + cmd.Flags().StringVarP(&L2Connect, "l2-connect", "2", L2Connect, "Connection string for the l2 cache") + cmd.Flags().IntVarP(&TTL, "ttl", "T", TTL, "Default TTL for DNS records") + cmd.Flags().StringVarP(&Domain, "domain", "d", Domain, "Parent domain for requests") + cmd.Flags().StringVarP(&DnsListen, "dns-listen", "O", DnsListen, "Listen address for DNS requests (ip:port)") + + // core + cmd.Flags().StringVarP(&LogLevel, "log-level", "l", LogLevel, "Log level to output [fatal|error|info|debug|trace]") + cmd.Flags().BoolVarP(&Server, "server", "s", Server, "Run in server mode") + cmd.PersistentFlags().StringVarP(&ConfigFile, "config-file", "c", ConfigFile, "Configuration file to load") + + cmd.Flags().BoolVarP(&Version, "version", "v", Version, "Print version info and exit") +} + +// LoadConfigFile reads the specified config file +func LoadConfigFile() error { + if ConfigFile == "" { + return nil + } + + // Set defaults to whatever might be there already + viper.SetDefault("api-crt", ApiCrt) + viper.SetDefault("api-key", ApiKey) + viper.SetDefault("api-key-password", ApiKeyPassword) + viper.SetDefault("api-listen", ApiListen) + viper.SetDefault("token", ApiToken) + viper.SetDefault("insecure", Insecure) + viper.SetDefault("l2-connect", L2Connect) + viper.SetDefault("ttl", TTL) + viper.SetDefault("domain", Domain) + viper.SetDefault("dns-listen", DnsListen) + viper.SetDefault("log-level", LogLevel) + viper.SetDefault("server", Server) + + filename := filepath.Base(ConfigFile) + viper.SetConfigName(filename[:len(filename)-len(filepath.Ext(filename))]) + viper.AddConfigPath(filepath.Dir(ConfigFile)) + + err := viper.ReadInConfig() + if err != nil { + return fmt.Errorf("Failed to read config file - %v", err) + } + + // Set values. Config file will override commandline + ApiCrt = viper.GetString("api-crt") + ApiKey = viper.GetString("api-key") + ApiKeyPassword = viper.GetString("api-key-password") + ApiListen = viper.GetString("api-listen") + ApiToken = viper.GetString("token") + Insecure = viper.GetBool("insecure") + L2Connect = viper.GetString("l2-connect") + TTL = viper.GetInt("ttl") + Domain = viper.GetString("domain") + DnsListen = viper.GetString("dns-listen") + LogLevel = viper.GetString("log-level") + Server = viper.GetBool("server") + + return nil +} diff --git a/core/common/common.go b/core/common/common.go new file mode 100644 index 0000000..be4634f --- /dev/null +++ b/core/common/common.go @@ -0,0 +1,66 @@ +// Package "common" contains common structs used in shaman +package common + +import ( + "fmt" + + "github.com/nanopack/shaman/config" +) + +// Resource contains the domain name and a slice of its records +type Resource struct { + Domain string `json:"domain"` // google.com + Records []Record `json:"records"` // dns records +} + +// Record contains dns information +type Record struct { + TTL int `json:"ttl"` // seconds record may be cached (300) + Class string `json:"class"` // protocol family (IN) + RType string `json:"type"` // dns record type (A) + Address string `json:"address"` // address domain resolves to (216.58.217.46) +} + +// StringSlice returns a slice of strings with dns info, each ready for dns.NewRR +func (self Resource) StringSlice() []string { + var records []string + for i := range self.Records { + records = append(records, fmt.Sprintf("%s %d %s %s %s\n", self.Domain, + self.Records[i].TTL, self.Records[i].Class, + self.Records[i].RType, self.Records[i].Address)) + } + return records +} + +// SanitizeDomain ensures the domain ends with a `.` +func SanitizeDomain(domain *string) { + t := []byte(*domain) + if t[len(t)-1] != '.' { + *domain = string(append(t, '.')) + } +} + +// UnsanitizeDomain ensures the domain ends with a `.` +func UnsanitizeDomain(domain *string) { + t := []byte(*domain) + if t[len(t)-1] == '.' { + *domain = string(t[:len(t)-1]) + } +} + +// Validate ensures record values are set +func (self *Resource) Validate() { + SanitizeDomain(&self.Domain) + + for i := range self.Records { + if self.Records[i].Class == "" { + self.Records[i].Class = "IN" + } + if self.Records[i].TTL == 0 { + self.Records[i].TTL = config.TTL + } + if self.Records[i].RType == "" { + self.Records[i].RType = "A" + } + } +} diff --git a/core/shaman.go b/core/shaman.go new file mode 100644 index 0000000..2f63c22 --- /dev/null +++ b/core/shaman.go @@ -0,0 +1,192 @@ +// Package "shaman" contains the logic to add/remove DNS entries. +package shaman + +// todo: atomic C.U.D. + +import ( + "fmt" + + "github.com/nanopack/shaman/cache" + "github.com/nanopack/shaman/config" + sham "github.com/nanopack/shaman/core/common" +) + +var Answers map[string]sham.Resource + +func init() { + Answers = make(map[string]sham.Resource, 0) +} + +// GetRecord returns a resource for the specified domain +func GetRecord(domain string) (sham.Resource, error) { + sham.SanitizeDomain(&domain) + + resource, ok := Answers[domain] + // if domain not cached in memory... + if !ok { + // fetch from cache + record, err := cache.GetRecord(domain) + if record == nil { + return resource, fmt.Errorf("Failed to find domain - %v", err) + } + // update local cache + config.Log.Debug("Cache differs from local, updating...") + Answers[domain] = *record + } + + return Answers[domain], nil +} + +// ListDomains returns a list of all known domains +func ListDomains() []string { + domains := make([]string, 0) + + for _, record := range ListRecords() { + sham.UnsanitizeDomain(&record.Domain) + domains = append(domains, record.Domain) + } + + return domains +} + +// ListRecords returns all known domains +func ListRecords() []sham.Resource { + if cache.Exists() { + // get from cache + stored, _ := cache.ListRecords() + if len(Answers) != len(stored) { + config.Log.Debug("Cache differs from local, updating...") + ResetRecords(&stored, true) + } + } + + resources := make([]sham.Resource, 0) + for _, v := range Answers { + resources = append(resources, v) + } + + return resources +} + +// DeleteRecord deletes the resource(domain) +func DeleteRecord(domain string) error { + sham.SanitizeDomain(&domain) + + // update cache + config.Log.Trace("Removing record from persistent cache...") + err := cache.DeleteRecord(domain) + if err != nil { + return err + } + + // todo: atomic + delete(Answers, domain) + + // otherwise, be idempotent and report it was deleted... + return nil +} + +// AddRecord adds a record to a resource(domain) +func AddRecord(resource *sham.Resource) error { + resource.Validate() + domain := resource.Domain + + // todo: atomic + _, ok := Answers[domain] + if ok { + config.Log.Trace("Domain is in local cache") + // if we have the domain registered... + for k := range Answers[domain].Records { + for j := range resource.Records { + // check if the record exists... + if resource.Records[j].RType == Answers[domain].Records[k].RType && + resource.Records[j].Address == Answers[domain].Records[k].Address && + resource.Records[j].Class == Answers[domain].Records[k].Class { + // if so, skip... + config.Log.Trace("Record exists in local cache, skipping") + goto next + } + } + // otherwise, add the record + config.Log.Trace("Record not in local cache, adding") + resource.Records = append(resource.Records, Answers[domain].Records[k]) + next: + } + } + + // store in cache + config.Log.Trace("Saving record to persistent cache...") + err := cache.AddRecord(resource) + if err != nil { + return err + } + + // add the resource to the list of knowns + Answers[domain] = *resource + + return nil +} + +// returns whether or not that domain exists +func Exists(domain string) bool { + sham.SanitizeDomain(&domain) + _, ok := Answers[domain] + return ok +} + +// UpdateRecord updates a record to a resource(domain) +func UpdateRecord(domain string, resource *sham.Resource) error { + resource.Validate() + sham.SanitizeDomain(&domain) + + // in case of some update to domain name... + if domain != resource.Domain { + // delete old domain + err := DeleteRecord(domain) + if err != nil { + return fmt.Errorf("Failed to clean up old domain - %v", err) + } + } + + // store in cache + config.Log.Trace("Updating record in persistent cache...") + err := cache.UpdateRecord(domain, resource) + if err != nil { + return err + } + + // set new resource to domain + // todo: atomic + Answers[resource.Domain] = *resource + + return nil +} + +// ResetRecords resets all answers. If any nocache has any values, caching is skipped +func ResetRecords(resources *[]sham.Resource, nocache ...bool) error { + for i := range *resources { + (*resources)[i].Validate() + } + + // new map to clear current answers + answers := make(map[string]sham.Resource) + + for i := range *resources { + answers[(*resources)[i].Domain] = (*resources)[i] + } + + if len(nocache) == 0 { + // store in cache + config.Log.Trace("Resetting records in persistent cache...") + err := cache.ResetRecords(resources) + if err != nil { + return err + } + } + + // reset the answers + // todo: atomic + Answers = answers + + return nil +} diff --git a/core/shaman_test.go b/core/shaman_test.go new file mode 100644 index 0000000..b1e5c17 --- /dev/null +++ b/core/shaman_test.go @@ -0,0 +1,121 @@ +package shaman_test + +import ( + "fmt" + "os" + "testing" + + "github.com/jcelliott/lumber" + + "github.com/nanopack/shaman/config" + "github.com/nanopack/shaman/core" + sham "github.com/nanopack/shaman/core/common" +) + +var ( + nanopack = sham.Resource{Domain: "nanopack.io.", Records: []sham.Record{sham.Record{Address: "127.0.0.1"}}} + nanopack2 = sham.Resource{Domain: "nanopack.io.", Records: []sham.Record{sham.Record{Address: "127.0.0.3"}}} + nanobox = sham.Resource{Domain: "nanobox.io.", Records: []sham.Record{sham.Record{Address: "127.0.0.2"}}} + nanoBoth = []sham.Resource{nanopack, nanobox} +) + +func TestMain(m *testing.M) { + shamanClear() + // manually configure + config.Log = lumber.NewConsoleLogger(lumber.LvlInt("FATAL")) + + // run tests + rtn := m.Run() + + os.Exit(rtn) +} + +func TestAddRecord(t *testing.T) { + shamanClear() + err := shaman.AddRecord(&nanopack) + err = shaman.AddRecord(&nanopack) + err2 := shaman.AddRecord(&nanopack2) + if err != nil || err2 != nil { + t.Errorf("Failed to add record - %v%v", err, err2) + } +} + +func TestGetRecord(t *testing.T) { + shamanClear() + _, err := shaman.GetRecord("nanopack.io") + shaman.AddRecord(&nanopack) + _, err2 := shaman.GetRecord("nanopack.io") + if err == nil || err2 != nil { + // t.Errorf("Failed to get record - %v%v", err, "hi") + t.Errorf("Failed to get record - %v%v", err, err2) + } +} + +func TestUpdateRecord(t *testing.T) { + shamanClear() + err := shaman.UpdateRecord("nanopack.io", &nanopack) + err2 := shaman.UpdateRecord("nanobox.io", &nanopack) + if err != nil || err2 != nil { + t.Errorf("Failed to update record - %v%v", err, err2) + } +} + +func TestDeleteRecord(t *testing.T) { + shamanClear() + err := shaman.DeleteRecord("nanobox.io") + shaman.AddRecord(&nanopack) + err2 := shaman.DeleteRecord("nanopack.io") + if err != nil || err2 != nil { + t.Errorf("Failed to delete record - %v%v", err, err2) + } +} + +func TestResetRecords(t *testing.T) { + shamanClear() + err := shaman.ResetRecords(&nanoBoth) + err2 := shaman.ResetRecords(&nanoBoth, true) + if err != nil || err2 != nil { + t.Errorf("Failed to reset records - %v%v", err, err2) + } +} + +func TestListDomains(t *testing.T) { + shamanClear() + domains := shaman.ListDomains() + if fmt.Sprint(domains) != "[]" { + t.Errorf("Failed to list domains - %+q", domains) + } + shaman.ResetRecords(&nanoBoth) + domains = shaman.ListDomains() + if len(domains) != 2 { + t.Errorf("Failed to list domains - %+q", domains) + } +} + +func TestListRecords(t *testing.T) { + shamanClear() + resources := shaman.ListRecords() + if fmt.Sprint(resources) != "[]" { + t.Errorf("Failed to list records - %+q", resources) + } + shaman.ResetRecords(&nanoBoth) + resources = shaman.ListRecords() + if len(resources) == 2 && (resources[0].Domain != "nanopack.io." && resources[0].Domain != "nanobox.io.") { + t.Errorf("Failed to list records - %+q", resources) + } +} + +func TestExists(t *testing.T) { + shamanClear() + if shaman.Exists("nanopack.io") { + t.Errorf("Failed to list records") + } + shaman.AddRecord(&nanopack) + if !shaman.Exists("nanopack.io") { + t.Errorf("Failed to list records") + } +} + +func shamanClear() { + shaman.Answers = make(map[string]sham.Resource, 0) +} diff --git a/main.go b/main.go index e03c975..140d340 100644 --- a/main.go +++ b/main.go @@ -1,18 +1,126 @@ +// Shaman is a small, clusterable, lightweight, api-driven dns server. +// +// Usage +// +// To start shaman as a server, simply run (with administrator privileges): +// +// shaman -s +// +// For more specific usage information, refer to the help doc `shaman -h`: +// Usage: +// shaman [flags] +// shaman [command] +// +// Available Commands: +// add Add a domain to shaman +// delete Remove a domain from shaman +// list List all domains in shaman +// get Get records for a domain +// update Update records for a domain +// reset Reset all domains in shaman +// +// Flags: +// -C, --api-crt string Path to SSL crt for API access +// -k, --api-key string Path to SSL key for API access +// -p, --api-key-password string Password for SSL key +// -H, --api-listen string Listen address for the API (ip:port) (default "127.0.0.1:1632") +// -c, --config-file string Configuration file to load +// -O, --dns-listen string Listen address for DNS requests (ip:port) (default "127.0.0.1:53") +// -d, --domain string Parent domain for requests (default ".") +// -i, --insecure Disable tls key checking (client) and listen on http (api) +// -2, --l2-connect string Connection string for the l2 cache (default "scribble:///var/db/shaman") +// -l, --log-level string Log level to output [fatal|error|info|debug|trace] (default "INFO") +// -s, --server Run in server mode +// -t, --token string Token for API Access (default "secret") +// -T, --ttl int Default TTL for DNS records (default 60) +// -v, --version Print version info and exit +// package main -// Main entry point into the shaman program. This starts up the API, caching, -// and DNS servers in their own routines. +import ( + "fmt" + "os" -// TODO: -// - handle signals -// - add logging -// - test + "github.com/jcelliott/lumber" + "github.com/spf13/cobra" -import ( + "github.com/nanopack/shaman/api" + "github.com/nanopack/shaman/cache" "github.com/nanopack/shaman/commands" + "github.com/nanopack/shaman/config" + "github.com/nanopack/shaman/server" ) -// main entry point +var ( + // shaman provides the shaman cli/server functionality + shamanTool = &cobra.Command{ + Use: "shaman", + Short: "shaman - api driven dns server", + Long: ``, + PersistentPreRun: readConfig, + PreRun: preFlight, + Run: startShaman, + } +) + +// add supported cli commands/flags +func init() { + shamanTool.AddCommand(commands.AddDomain) + shamanTool.AddCommand(commands.DelDomain) + shamanTool.AddCommand(commands.ListDomains) + shamanTool.AddCommand(commands.GetDomain) + shamanTool.AddCommand(commands.UpdateDomain) + shamanTool.AddCommand(commands.ResetDomains) + + config.AddFlags(shamanTool) +} + func main() { - commands.Shaman.Execute() + shamanTool.Execute() +} + +func readConfig(ccmd *cobra.Command, args []string) { + if err := config.LoadConfigFile(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func preFlight(ccmd *cobra.Command, args []string) { + if config.Version { + fmt.Printf("shaman %s\n", VERSION) + os.Exit(0) + } + + if !config.Server { + ccmd.HelpFunc()(ccmd, args) + os.Exit(0) + } +} + +func startShaman(ccmd *cobra.Command, args []string) { + config.Log = lumber.NewConsoleLogger(lumber.LvlInt(config.LogLevel)) + + // initialize cache + err := cache.Initialize() + if err != nil { + config.Log.Fatal(err.Error()) + os.Exit(1) + } + + // make channel for errors + errors := make(chan error) + + go func() { + errors <- api.Start() + }() + go func() { + errors <- server.Start() + }() + + // break if any of them return an error (blocks exit) + if err := <-errors; err != nil { + config.Log.Fatal(err.Error()) + os.Exit(1) + } } diff --git a/server/dns.go b/server/dns.go new file mode 100644 index 0000000..b0019f0 --- /dev/null +++ b/server/dns.go @@ -0,0 +1,63 @@ +// Package "server" contains logic to handle DNS requests. +package server + +import ( + "fmt" + + "github.com/miekg/dns" + + "github.com/nanopack/shaman/config" + "github.com/nanopack/shaman/core" +) + +// Start starts the DNS listener +func Start() error { + dns.HandleFunc(".", handlerFunc) + udpListener := &dns.Server{Addr: config.DnsListen, Net: "udp"} + return fmt.Errorf("DNS listener stopped - %v", udpListener.ListenAndServe()) +} + +// handlerFunc receives requests, looks up the result and returns what is found. +func handlerFunc(res dns.ResponseWriter, req *dns.Msg) { + message := new(dns.Msg) + switch req.Opcode { + case dns.OpcodeQuery: + message.SetReply(req) + message.Compress = false + message.Answer = make([]dns.RR, 0) + + for _, question := range message.Question { + answers := answerQuestion(question) + for i := range answers { + message.Answer = append(message.Answer, answers[i]) + } + } + if len(message.Answer) == 0 { + message.Rcode = dns.RcodeNameError + } + default: + message = message.SetRcode(req, dns.RcodeNotImplemented) + } + res.WriteMsg(message) +} + +// answerQuestion returns resource record answers for the domain in question +func answerQuestion(question dns.Question) []dns.RR { + answers := make([]dns.RR, 0) + r, _ := shaman.GetRecord(question.Name) + records := r.StringSlice() + // fmt.Printf("Records received - %+q\n", records) + for _, record := range records { + entry, err := dns.NewRR(record) + if err != nil { + config.Log.Trace("Failed to create RR from record - %v", err) + continue + } + entry.Header().Name = question.Name + if entry.Header().Rrtype == question.Qtype || question.Qtype == dns.TypeANY { + answers = append(answers, entry) + } + } + + return answers +} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index 492984f..0000000 --- a/server/server.go +++ /dev/null @@ -1,98 +0,0 @@ -package server - -// This has the handler for the DNS server. - -// TODO: -// - add logging -// - test - -import ( - "fmt" - "strings" - - "github.com/miekg/dns" - - "github.com/nanopack/shaman/caches" - "github.com/nanopack/shaman/config" -) - -func stripSubdomain(name string) string { - names := strings.SplitN(name, ".", 2) - if len(names) == 2 { - return names[1] - } else { - return "" - } -} - -func answerQuestion(question dns.Question) []dns.RR { - answers := make([]dns.RR, 0) - name := question.Name - for { - findReturn := make(chan caches.FindReturn) - var findOp caches.FindOp - var key string - if name != question.Name { - key = caches.Key("*."+name, question.Qtype) - } else { - key = caches.Key(name, question.Qtype) - } - findOp = caches.FindOp{Key: key, Resp: findReturn} - caches.FindOps <- findOp - findRet := <-findReturn - err := findRet.Err - record := findRet.Value - if err != nil { - config.Log.Error("error: %s", err) - continue - } - if record != "" { - entry, err := dns.NewRR(record) - if err != nil { - config.Log.Error("error: %s", err) - continue - } - entry.Header().Name = question.Name - answers = append(answers, entry) - } - if len(answers) > 0 { - break - } - name = stripSubdomain(name) - if len(name) == 0 { - break - } - } - return answers -} - -// This receives requests, looks up the result and returns what is found. -func handlerFunc(res dns.ResponseWriter, req *dns.Msg) { - message := new(dns.Msg) - switch req.Opcode { - case dns.OpcodeQuery: - message.SetReply(req) - message.Compress = false - message.Answer = make([]dns.RR, 0) - - for _, question := range message.Question { - answers := answerQuestion(question) - for i := range answers { - message.Answer = append(message.Answer, answers[i]) - } - } - if len(message.Answer) == 0 { - message.Rcode = dns.RcodeNameError - } - default: - message = message.SetRcode(req, dns.RcodeNotImplemented) - } - res.WriteMsg(message) -} - -// This starts the DNS listener -func StartServer() error { - dns.HandleFunc(config.Domain, handlerFunc) - udpListener := &dns.Server{Addr: fmt.Sprintf("%s:%s", config.Host, config.Port), Net: "udp"} - return udpListener.ListenAndServe() -} diff --git a/version.go b/version.go new file mode 100644 index 0000000..b8776c5 --- /dev/null +++ b/version.go @@ -0,0 +1,3 @@ +package main + +const VERSION = "0.0.2" From 97d816990708f72c547a447f122ad57d8568b200 Mon Sep 17 00:00:00 2001 From: Greg Linton Date: Thu, 12 May 2016 11:20:40 -0600 Subject: [PATCH 2/6] Add tests for dns server implementation Previous commit was a squashed commit of the following: commit b7260d959f9c85b2cde8bdfc755b5522df3bebdb Author: Greg Linton Date: Wed May 11 18:01:10 2016 -0600 Add README and travis.yml commit 01804599d720b777e962c20e9c9e317b13599f56 Author: Greg Linton Date: Wed May 11 17:21:02 2016 -0600 Add cli README and fix bug commit 5b98f7ff4f2d4c2bf01c0333606c238cc9bbabfb Author: Greg Linton Date: Wed May 11 16:37:00 2016 -0600 Add cache and core tests commit f45029fde5f4e0ac6ad1c583d3d1056d692275b1 Author: Greg Linton Date: Wed May 11 12:07:29 2016 -0600 Layer caching back in commit 219be4ce476d0eefe1433899d1690cfcadba41e6 Author: Greg Linton Date: Tue May 10 17:11:04 2016 -0600 Always read the config file commit c7bb52cacb5af6eea354815244dae04a19326727 Author: Greg Linton Date: Tue May 10 16:58:42 2016 -0600 Cli finished and tested commit affe30377a2050059e0936428eb517f34ab43fa5 Author: Greg Linton Date: Tue May 10 11:23:08 2016 -0600 Api to finalized and tested state commit cd5cde4453cee7e2649101857aa9be8a916ba2a3 Author: Greg Linton Date: Mon May 9 16:10:23 2016 -0600 Finish api and functions, minor cleaning, test it all manually commit 62a187bbaa7348dbb0b9b489b139ada5c7cefce2 Author: Greg Linton Date: Fri May 6 17:53:54 2016 -0600 Initial port of shaman to new API and simpler dns response --- api/api_test.go | 2 +- cache/scribble_test.go | 6 +-- commands/commands_test.go | 2 +- core/common/common.go | 4 +- server/dns_test.go | 95 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 7 deletions(-) create mode 100644 server/dns_test.go diff --git a/api/api_test.go b/api/api_test.go index 3eb45f3..daa19c7 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -32,7 +32,7 @@ func TestMain(m *testing.M) { // start api go api.Start() - <-time.After(3 * time.Second) + <-time.After(time.Second) rtn := m.Run() os.Exit(rtn) diff --git a/cache/scribble_test.go b/cache/scribble_test.go index b901fb7..1dead1e 100644 --- a/cache/scribble_test.go +++ b/cache/scribble_test.go @@ -17,9 +17,9 @@ func TestScribbleInitialize(t *testing.T) { config.L2Connect = "scribble:///roots/file" // unable to init? (test no sudo) err3 := cache.Initialize() config.L2Connect = "scribble:///" // defaulting to "/var/db" - err4 := cache.Initialize() - if err != nil || err2 == nil || err3 == nil || err4 != nil { - t.Errorf("Failed to initalize scribble cacher - %v%v%v%v", err, err2, err3, err4) + cache.Initialize() + if err != nil || err2 == nil || err3 == nil { + t.Errorf("Failed to initalize scribble cacher - %v%v%v", err, err2, err3) } } diff --git a/commands/commands_test.go b/commands/commands_test.go index cf29594..08a7f8a 100644 --- a/commands/commands_test.go +++ b/commands/commands_test.go @@ -48,7 +48,7 @@ func TestMain(m *testing.M) { // start api go api.Start() - <-time.After(1 * time.Second) + <-time.After(time.Second) rtn := m.Run() os.Exit(rtn) diff --git a/core/common/common.go b/core/common/common.go index be4634f..d914b1a 100644 --- a/core/common/common.go +++ b/core/common/common.go @@ -35,7 +35,7 @@ func (self Resource) StringSlice() []string { // SanitizeDomain ensures the domain ends with a `.` func SanitizeDomain(domain *string) { t := []byte(*domain) - if t[len(t)-1] != '.' { + if len(t) > 0 && t[len(t)-1] != '.' { *domain = string(append(t, '.')) } } @@ -43,7 +43,7 @@ func SanitizeDomain(domain *string) { // UnsanitizeDomain ensures the domain ends with a `.` func UnsanitizeDomain(domain *string) { t := []byte(*domain) - if t[len(t)-1] == '.' { + if len(t) > 0 && t[len(t)-1] == '.' { *domain = string(t[:len(t)-1]) } } diff --git a/server/dns_test.go b/server/dns_test.go new file mode 100644 index 0000000..5a72efb --- /dev/null +++ b/server/dns_test.go @@ -0,0 +1,95 @@ +package server_test + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/jcelliott/lumber" + "github.com/miekg/dns" + + "github.com/nanopack/shaman/config" + "github.com/nanopack/shaman/core" + "github.com/nanopack/shaman/server" + sham "github.com/nanopack/shaman/core/common" +) + +var nanopack = sham.Resource{Domain: "nanopack.io.", Records: []sham.Record{sham.Record{Address: "127.0.0.1"}}} + +func TestMain(m *testing.M) { + // manually configure + config.DnsListen = "127.0.0.1:8053" + config.Log = lumber.NewConsoleLogger(lumber.LvlInt("FATAL")) + + // start dns server + go server.Start() + <-time.After(time.Second) + + // run tests + rtn := m.Run() + + os.Exit(rtn) +} + +func TestDNS(t *testing.T) { + err := shaman.AddRecord(&nanopack) + if err != nil { + t.Errorf("Failed to add record - %v", err) + t.FailNow() + } + + r, err := ResolveIt("nanopack.io", dns.TypeA) + if err != nil { + t.Errorf("Failed to get record - %v", err) + } + if len(r.Answer) == 0 { + t.Error("No record found") + } + if len(r.Answer) > 0 && r.Answer[0].String() != "nanopack.io.\t60\tIN\tA\t127.0.0.1" { + t.Errorf("Response doesn't match expected - %+q", r.Answer[0].String()) + } + + r, err = ResolveIt("nanobox.io", dns.TypeA) + if err != nil { + t.Errorf("Failed to get record - %v", err) + } + if len(r.Answer) != 0 { + t.Error("Found non-existant record") + } + + r, err = ResolveIt("nanopack.io", dns.TypeMX, true) + if err != nil { + t.Errorf("Failed to get record - %v", err) + } + if len(r.Answer) != 0 { + t.Error("Found non-existant record") + } +} + + +func ResolveIt(domain string, rType uint16, badop ...bool) (*dns.Msg, error) { + // root domain if not already + root(&domain) + m := new(dns.Msg) + m.SetQuestion(domain, rType) + + if len(badop) > 0 { + m.Opcode = dns.OpcodeStatus + } + + // ask the dns server + r, err := dns.Exchange(m, config.DnsListen) + if err != nil { + return nil, fmt.Errorf("Failed to exchange - %v", err) + } + + return r, nil +} + +func root(domain *string) { + t := []byte(*domain) + if len(t) > 0 && t[len(t)-1] != '.' { + *domain = string(append(t, '.')) + } +} From 2e61202e01a5b244599fd16af14e53366ad0f484 Mon Sep 17 00:00:00 2001 From: Greg Linton Date: Thu, 12 May 2016 11:53:10 -0600 Subject: [PATCH 3/6] Allow server to start insecure --- README.md | 46 ++++++++++++++++++++- api/api.go | 10 +++-- server/dns.go | 1 + server/dns_test.go | 101 ++++++++++++++++++++++----------------------- version.go | 2 +- 5 files changed, 103 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 13eb73b..7845558 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ Small, clusterable, lightweight, api-driven dns server. + ## Quickstart: ```sh # Start shaman with defaults (requires admin privileges (port 53)) @@ -15,12 +16,15 @@ shaman -s shaman add -d nanopack.io -A 127.0.0.1 # perform dns lookup +# OR `nslookup -port=53 nanopack.io 127.0.0.1` dig @localhost nanopack.io +short # 127.0.0.1 + # Congratulations! ``` + ## Usage: ### As a CLI @@ -88,6 +92,20 @@ An optional config file can also be passed on startup: } ``` +#### L2 connection strings + +##### Scribble Cacher +The connection string looks like `scribble://localhost/path/to/data/store`. + + + + ## API: | Route | Description | Payload | Output | @@ -99,8 +117,28 @@ An optional config file can also be passed on startup: | **GET** /records/{domain} | Returns the records for that domain | nil | json domain object | | **DELETE** /records/{domain} | Delete a domain | nil | success message | +**note:** The API requires a token to be passed for authentication by default and is configurable at server start (`--token`). The token is passed in as a custom header: `X-AUTH-TOKEN`. + For examples, see [the api's readme](api/README.md) + +## Overview + +```sh ++------------+ +----------+ +-----------------+ +| +-----> +-----> | +| API Server | | | | Short-Term | +| <-----+ Caching <-----+ (in-memory) | ++------------+ | And | +-----------------+ + | Database | ++------------+ | Manager | +-----------------+ +| +-----> +-----> | +| DNS Server | | | | Long-Term (L2) | +| <-----+ <-----+ | ++------------+ +----------+ +-----------------+ +``` + + ## Data types: ### Domain (Resource): json: @@ -159,14 +197,18 @@ json: Fields: - **msg**: Success message + ## Todo -- tests for server/dns -- start server insecure - atomic local cache updates - export in hosts file format + ## Changelog - v0.0.2 (May 11, 2016) - Refactor to allow multiple records per domain and more fully utilize dns library +- v0.0.3 (May 12, 2016) + - Tests for DNS server + - Start Server Insecure + [![oss logo](http://nano-assets.gopagoda.io/open-src/nanobox-open-src.png)](http://nanobox.io/open-source) diff --git a/api/api.go b/api/api.go index 36d2ecc..6a20c19 100644 --- a/api/api.go +++ b/api/api.go @@ -32,6 +32,12 @@ var ( // Start starts shaman's http api func Start() error { + // handle config.Insecure + if config.Insecure { + config.Log.Info("Shaman listening at http://%s...", config.ApiListen) + return fmt.Errorf("API stopped - %v", http.ListenAndServe(config.ApiListen, routes())) + } + var cert *tls.Certificate var err error if config.ApiCrt == "" { @@ -45,9 +51,7 @@ func Start() error { auth.Certificate = cert auth.Header = "X-AUTH-TOKEN" - config.Log.Info("Shaman listening on https://%v", config.ApiListen) - - // todo: handle config.Insecure + config.Log.Info("Shaman listening at https://%v", config.ApiListen) return fmt.Errorf("API stopped - %v", auth.ListenAndServeTLS(config.ApiListen, config.ApiToken, routes())) } diff --git a/server/dns.go b/server/dns.go index b0019f0..7d51aac 100644 --- a/server/dns.go +++ b/server/dns.go @@ -14,6 +14,7 @@ import ( func Start() error { dns.HandleFunc(".", handlerFunc) udpListener := &dns.Server{Addr: config.DnsListen, Net: "udp"} + config.Log.Info("DNS listening at udp://%v", config.DnsListen) return fmt.Errorf("DNS listener stopped - %v", udpListener.ListenAndServe()) } diff --git a/server/dns_test.go b/server/dns_test.go index 5a72efb..57b280c 100644 --- a/server/dns_test.go +++ b/server/dns_test.go @@ -11,20 +11,20 @@ import ( "github.com/nanopack/shaman/config" "github.com/nanopack/shaman/core" - "github.com/nanopack/shaman/server" sham "github.com/nanopack/shaman/core/common" + "github.com/nanopack/shaman/server" ) -var nanopack = sham.Resource{Domain: "nanopack.io.", Records: []sham.Record{sham.Record{Address: "127.0.0.1"}}} +var nanopack = sham.Resource{Domain: "nanopack.io.", Records: []sham.Record{sham.Record{Address: "127.0.0.1"}}} func TestMain(m *testing.M) { // manually configure - config.DnsListen = "127.0.0.1:8053" + config.DnsListen = "127.0.0.1:8053" config.Log = lumber.NewConsoleLogger(lumber.LvlInt("FATAL")) - // start dns server - go server.Start() - <-time.After(time.Second) + // start dns server + go server.Start() + <-time.After(time.Second) // run tests rtn := m.Run() @@ -36,59 +36,58 @@ func TestDNS(t *testing.T) { err := shaman.AddRecord(&nanopack) if err != nil { t.Errorf("Failed to add record - %v", err) - t.FailNow() + t.FailNow() } - r, err := ResolveIt("nanopack.io", dns.TypeA) - if err != nil { - t.Errorf("Failed to get record - %v", err) - } - if len(r.Answer) == 0 { - t.Error("No record found") - } - if len(r.Answer) > 0 && r.Answer[0].String() != "nanopack.io.\t60\tIN\tA\t127.0.0.1" { - t.Errorf("Response doesn't match expected - %+q", r.Answer[0].String()) - } - - r, err = ResolveIt("nanobox.io", dns.TypeA) - if err != nil { - t.Errorf("Failed to get record - %v", err) - } - if len(r.Answer) != 0 { - t.Error("Found non-existant record") - } - - r, err = ResolveIt("nanopack.io", dns.TypeMX, true) - if err != nil { - t.Errorf("Failed to get record - %v", err) - } - if len(r.Answer) != 0 { - t.Error("Found non-existant record") - } -} + r, err := ResolveIt("nanopack.io", dns.TypeA) + if err != nil { + t.Errorf("Failed to get record - %v", err) + } + if len(r.Answer) == 0 { + t.Error("No record found") + } + if len(r.Answer) > 0 && r.Answer[0].String() != "nanopack.io.\t60\tIN\tA\t127.0.0.1" { + t.Errorf("Response doesn't match expected - %+q", r.Answer[0].String()) + } + r, err = ResolveIt("nanobox.io", dns.TypeA) + if err != nil { + t.Errorf("Failed to get record - %v", err) + } + if len(r.Answer) != 0 { + t.Error("Found non-existant record") + } + + r, err = ResolveIt("nanopack.io", dns.TypeMX, true) + if err != nil { + t.Errorf("Failed to get record - %v", err) + } + if len(r.Answer) != 0 { + t.Error("Found non-existant record") + } +} func ResolveIt(domain string, rType uint16, badop ...bool) (*dns.Msg, error) { - // root domain if not already - root(&domain) - m := new(dns.Msg) - m.SetQuestion(domain, rType) - - if len(badop) > 0 { - m.Opcode = dns.OpcodeStatus - } - - // ask the dns server - r, err := dns.Exchange(m, config.DnsListen) - if err != nil { - return nil, fmt.Errorf("Failed to exchange - %v", err) - } - - return r, nil + // root domain if not already + root(&domain) + m := new(dns.Msg) + m.SetQuestion(domain, rType) + + if len(badop) > 0 { + m.Opcode = dns.OpcodeStatus + } + + // ask the dns server + r, err := dns.Exchange(m, config.DnsListen) + if err != nil { + return nil, fmt.Errorf("Failed to exchange - %v", err) + } + + return r, nil } func root(domain *string) { - t := []byte(*domain) + t := []byte(*domain) if len(t) > 0 && t[len(t)-1] != '.' { *domain = string(append(t, '.')) } diff --git a/version.go b/version.go index b8776c5..d2a9433 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package main -const VERSION = "0.0.2" +const VERSION = "0.0.3" From 88f706c01556d3894962d2f1db534486e83ae021 Mon Sep 17 00:00:00 2001 From: Greg Linton Date: Thu, 12 May 2016 14:17:17 -0600 Subject: [PATCH 4/6] Add test for main --- cache/cache.go | 3 +- main.go | 35 +++++++++++---------- main_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 17 deletions(-) create mode 100644 main_test.go diff --git a/cache/cache.go b/cache/cache.go index 5ec4266..1e13c49 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -46,7 +46,8 @@ func Initialize() error { err = storage.initialize() if err != nil { storage = nil - err = fmt.Errorf("Failed to initialize cache, turning off - %v", err) + config.Log.Info("Failed to initialize cache, turning off - %v", err) + err = nil } } diff --git a/main.go b/main.go index 140d340..72aae0a 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,6 @@ package main import ( "fmt" - "os" "github.com/jcelliott/lumber" "github.com/spf13/cobra" @@ -54,12 +53,14 @@ import ( var ( // shaman provides the shaman cli/server functionality shamanTool = &cobra.Command{ - Use: "shaman", - Short: "shaman - api driven dns server", - Long: ``, - PersistentPreRun: readConfig, - PreRun: preFlight, - Run: startShaman, + Use: "shaman", + Short: "shaman - api driven dns server", + Long: ``, + PersistentPreRunE: readConfig, + PreRunE: preFlight, + RunE: startShaman, + SilenceErrors: true, + SilenceUsage: true, } ) @@ -79,33 +80,35 @@ func main() { shamanTool.Execute() } -func readConfig(ccmd *cobra.Command, args []string) { +func readConfig(ccmd *cobra.Command, args []string) error { if err := config.LoadConfigFile(); err != nil { - fmt.Println(err) - os.Exit(1) + fmt.Printf("Error: %v\n", err) + return err } + return nil } -func preFlight(ccmd *cobra.Command, args []string) { +func preFlight(ccmd *cobra.Command, args []string) error { if config.Version { fmt.Printf("shaman %s\n", VERSION) - os.Exit(0) + return fmt.Errorf("") } if !config.Server { ccmd.HelpFunc()(ccmd, args) - os.Exit(0) + return fmt.Errorf("") } + return nil } -func startShaman(ccmd *cobra.Command, args []string) { +func startShaman(ccmd *cobra.Command, args []string) error { config.Log = lumber.NewConsoleLogger(lumber.LvlInt(config.LogLevel)) // initialize cache err := cache.Initialize() if err != nil { config.Log.Fatal(err.Error()) - os.Exit(1) + return err } // make channel for errors @@ -121,6 +124,6 @@ func startShaman(ccmd *cobra.Command, args []string) { // break if any of them return an error (blocks exit) if err := <-errors; err != nil { config.Log.Fatal(err.Error()) - os.Exit(1) } + return err } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..c86eaab --- /dev/null +++ b/main_test.go @@ -0,0 +1,82 @@ +package main + +import ( + "io" + "os" + "strings" + "testing" + "time" + + "github.com/nanopack/shaman/config" +) + +var discard io.Writer = devNull(0) + +// dummy writer type +type devNull int + +// dummy write method +func (devNull) Write(p []byte) (int, error) { + return len(p), nil +} + +func TestMain(m *testing.M) { + config.LogLevel = "fatal" + args := strings.Split("-O 127.0.0.1:8053 -2 none:// -s", " ") + shamanTool.SetArgs(args) + shamanTool.SetOutput(discard) + + go main() + <-time.After(time.Second) + + // run tests + rtn := m.Run() + + os.Exit(rtn) +} + +func TestShowHelp(t *testing.T) { + config.Server = false + shamanTool.SetArgs([]string{""}) + + shamanTool.Execute() +} + +func TestBadConfig(t *testing.T) { + args := strings.Split("-c /tmp/nowaythisexists list", " ") + shamanTool.SetArgs(args) + + shamanTool.Execute() + config.ConfigFile = "" +} + +func TestShowVersion(t *testing.T) { + args := strings.Split("-v", " ") + shamanTool.SetArgs(args) + + shamanTool.Execute() + config.Version = false +} + +func TestBadCache(t *testing.T) { + config.L2Connect = "!@#$%^&" + args := strings.Split("-s", " ") + shamanTool.SetArgs(args) + + shamanTool.Execute() + config.L2Connect = "none://" +} + +func TestBadDNSListen(t *testing.T) { + config.L2Connect = "none://" + config.DnsListen = "127.0.0.1:53" + args := strings.Split("-s", " ") + shamanTool.SetArgs(args) + + go shamanTool.Execute() + <-time.After(time.Second) + + // port already in use, will fail here + shamanTool.Execute() + config.DnsListen = "127.0.0.1:8053" +} From 0a9a6b29b5caf16bccc222779cb34e163c6916c8 Mon Sep 17 00:00:00 2001 From: Greg Linton Date: Thu, 12 May 2016 14:26:29 -0600 Subject: [PATCH 5/6] Minor update to tests --- cache/scribble_test.go | 2 +- main_test.go | 19 +++++++------------ 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/cache/scribble_test.go b/cache/scribble_test.go index 1dead1e..e553da6 100644 --- a/cache/scribble_test.go +++ b/cache/scribble_test.go @@ -18,7 +18,7 @@ func TestScribbleInitialize(t *testing.T) { err3 := cache.Initialize() config.L2Connect = "scribble:///" // defaulting to "/var/db" cache.Initialize() - if err != nil || err2 == nil || err3 == nil { + if err != nil || err2 == nil || err3 != nil { t.Errorf("Failed to initalize scribble cacher - %v%v%v", err, err2, err3) } } diff --git a/main_test.go b/main_test.go index c86eaab..da3a593 100644 --- a/main_test.go +++ b/main_test.go @@ -1,7 +1,7 @@ package main import ( - "io" + "bytes" "os" "strings" "testing" @@ -10,22 +10,17 @@ import ( "github.com/nanopack/shaman/config" ) -var discard io.Writer = devNull(0) - -// dummy writer type -type devNull int - -// dummy write method -func (devNull) Write(p []byte) (int, error) { - return len(p), nil -} - func TestMain(m *testing.M) { + // manually configure config.LogLevel = "fatal" + discard := &bytes.Buffer{} + shamanTool.SetOutput(discard) + + // set args for shaman args := strings.Split("-O 127.0.0.1:8053 -2 none:// -s", " ") shamanTool.SetArgs(args) - shamanTool.SetOutput(discard) + // run shaman server go main() <-time.After(time.Second) From 67d5d59dd8e09d71853010cb4f9c62bb9a89284a Mon Sep 17 00:00:00 2001 From: Greg Linton Date: Fri, 13 May 2016 10:05:32 -0600 Subject: [PATCH 6/6] Search upper level domains if domain not found --- README.md | 1 - core/common/common.go | 2 +- server/dns.go | 46 ++++++++++++++++++++++++++++++++++++++----- server/dns_test.go | 2 +- 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 7845558..12ef52d 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,6 @@ shaman add -d nanopack.io -A 127.0.0.1 dig @localhost nanopack.io +short # 127.0.0.1 - # Congratulations! ``` diff --git a/core/common/common.go b/core/common/common.go index d914b1a..1882f96 100644 --- a/core/common/common.go +++ b/core/common/common.go @@ -40,7 +40,7 @@ func SanitizeDomain(domain *string) { } } -// UnsanitizeDomain ensures the domain ends with a `.` +// UnsanitizeDomain ensures the domain does not end with a `.` func UnsanitizeDomain(domain *string) { t := []byte(*domain) if len(t) > 0 && t[len(t)-1] == '.' { diff --git a/server/dns.go b/server/dns.go index 7d51aac..1b738b4 100644 --- a/server/dns.go +++ b/server/dns.go @@ -3,6 +3,7 @@ package server import ( "fmt" + "strings" "github.com/miekg/dns" @@ -45,13 +46,18 @@ func handlerFunc(res dns.ResponseWriter, req *dns.Msg) { // answerQuestion returns resource record answers for the domain in question func answerQuestion(question dns.Question) []dns.RR { answers := make([]dns.RR, 0) - r, _ := shaman.GetRecord(question.Name) - records := r.StringSlice() - // fmt.Printf("Records received - %+q\n", records) - for _, record := range records { + + // get the resource (check memory, cache, and (todo:) upstream) + r, err := shaman.GetRecord(question.Name) + if err != nil { + config.Log.Trace("Failed to get records for '%s' - %v", question.Name, err) + } + + // validate the records and append correct type to answers[] + for _, record := range r.StringSlice() { entry, err := dns.NewRR(record) if err != nil { - config.Log.Trace("Failed to create RR from record - %v", err) + config.Log.Debug("Failed to create RR from record - %v", err) continue } entry.Header().Name = question.Name @@ -60,5 +66,35 @@ func answerQuestion(question dns.Question) []dns.RR { } } + // todo: should `shaman.GetRecord` be wildcard aware (*.domain.com) or is this ok + // recursively resolve if no records found + if len(answers) == 0 { + question.Name = stripSubdomain(question.Name) + if len(question.Name) > 0 { + config.Log.Trace("Checking again with '%v'", question.Name) + return answerQuestion(question) + } + } + return answers } + +// stripSubdomain strips off the subbest domain, returning the domain (won't return TLD) +func stripSubdomain(name string) string { + words := 3 // assume rooted domain (end with '.') + // handle edge case of unrooted domain + t := []byte(name) + if len(t) > 0 && t[len(t)-1] != '.' { + words = 2 + } + + config.Log.Trace("Stripping subdomain from '%v'", name) + names := strings.Split(name, ".") + + // prevent searching for just 'com.' (["domain", "com", ""]) + if len(names) > words { + return strings.Join(names[1:], ".") + } else { + return "" + } +} diff --git a/server/dns_test.go b/server/dns_test.go index 57b280c..c011a33 100644 --- a/server/dns_test.go +++ b/server/dns_test.go @@ -50,7 +50,7 @@ func TestDNS(t *testing.T) { t.Errorf("Response doesn't match expected - %+q", r.Answer[0].String()) } - r, err = ResolveIt("nanobox.io", dns.TypeA) + r, err = ResolveIt("a.b.nanobox.io", dns.TypeA) if err != nil { t.Errorf("Failed to get record - %v", err) }