summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Turner <jt@jtnet.co.uk>2017-11-29 10:58:51 -0800
committerJonathan Turner <jt@jtnet.co.uk>2017-11-29 10:58:51 -0800
commita58124b8ce973a500cfdade1eecda03031775eb7 (patch)
tree72857ecd3f623a176d4c32653544dd2ef8e8e722
parent051f7eaa7d61ef52f7da333d142c883be331db84 (diff)
update for override ns
-rw-r--r--srv.go35
1 files changed, 32 insertions, 3 deletions
diff --git a/srv.go b/srv.go
index 5c25d07..94b0537 100644
--- a/srv.go
+++ b/srv.go
@@ -1,8 +1,11 @@
package dnsutils
import (
+ "context"
+ "fmt"
"math/rand"
"net"
+ "os"
"sort"
)
@@ -31,13 +34,39 @@ func (s SRVRecords) Less(i, j int) bool {
// // Do something such as dial this SRV. If fails move on the the next
// i += 1
// }
+//
+// To override the operating system's name server set the DNSUTILS_OVERRIDE_NS environment variable.
+// The override should also include the port. For example 127.0.0.1:53
func OrderedSRV(service, proto, name string) (int, map[int]*net.SRV, error) {
- _, addrs, err := net.LookupSRV(service, proto, name)
+ var addrs []*net.SRV
+ var err error
+ ns := os.Getenv("DNSUTILS_OVERRIDE_NS")
+ if ns != "" {
+ res := net.Resolver{Dial: overrideNS}
+ _, addrs, err = res.LookupSRV(context.Background(), service, proto, name)
+ } else {
+ _, addrs, err = net.LookupSRV(service, proto, name)
+ }
if err != nil {
return 0, make(map[int]*net.SRV), err
}
- index, os := orderSRV(addrs)
- return index, os, nil
+ index, osrv := orderSRV(addrs)
+ return index, osrv, nil
+}
+
+func overrideNS(ctx context.Context, network, address string) (conn net.Conn, err error) {
+ // Ignore the address provided and override with an environment variable if it is defined
+ ns := os.Getenv("DNSUTILS_OVERRIDE_NS")
+ if ns != "" {
+ address = ns
+ }
+ if network == "tcp" || network == "udp" {
+ var d net.Dialer
+ conn, err = d.DialContext(ctx, network, address)
+ return
+ }
+ err = fmt.Errorf("unsupported network protocol %s for DNS lookup", network)
+ return
}
func orderSRV(addrs []*net.SRV) (int, map[int]*net.SRV) {