Created
September 24, 2017 19:31
-
-
Save khafatech/169e665826c050e1252c2295cbef15a1 to your computer and use it in GitHub Desktop.
Simple dns server in go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
// Build and run: go build dns.go && ./dns
// Query it with: dig +qr +tries=1 +time=1 @localhost -p 9000 foo.com
import ( | |
"bytes" | |
"encoding/binary" | |
"encoding/hex" | |
"errors" | |
"fmt" | |
"io" | |
"log" | |
"net" | |
) | |
// HeaderSize is the length in bytes of a DNS message header: six
// big-endian uint16 fields (see DNSHeader).
const HeaderSize = 12
// DNSHeader is the fixed-size header at the start of every DNS message.
// The fields are declared in wire order so encoding/binary can read and
// write the struct directly (big-endian); do not reorder them.
type DNSHeader struct {
	// Id is the transaction ID; handleRequest copies it into the response.
	Id uint16
	// Flags is the 16-bit flags word (QR, opcode, AA, TC, RD, RA, Z, RCODE).
	Flags uint16
	// Entry counts for the question, answer, authority and additional
	// sections, in that order.
	Queries, Answers, Auths, Additional uint16
}
// DNSQuery is a single entry from the question section of a DNS message.
type DNSQuery struct {
	// Name is the question name in dotted form without a trailing dot,
	// e.g. "foo.com".
	Name []byte
	// NameRaw is the same name in wire form: length-prefixed labels
	// followed by the terminating zero byte.
	NameRaw []byte
	// Type and Class are the QTYPE and QCLASS codes that follow the name.
	Type, Class uint16
}
// DNSResourceRecord is a resource record as it is written into a response;
// Bytes serializes the fields in this order.
type DNSResourceRecord struct {
	Name        []byte // owner name in wire form; Bytes copies it verbatim
	Type, Class uint16
	TTL         uint32 // seconds the record may be cached
	Rdatalen    uint16 // RDLENGTH; Bytes writes it as-is, keep it == len(Rdata)
	Rdata       []byte
}
// RType is a DNS resource record type code, e.g. A (1), MX (15) or
// AAAA (28).
type RType uint16
// DNSRequest is the part of an incoming query that handleRequest needs:
// the transaction ID and the type and name of the decoded question.
type DNSRequest struct {
	Id            uint16 // transaction ID from the request header
	Type          uint16 // QTYPE of the decoded question
	Name, NameRaw []byte // dotted and wire forms of the question name
}
func newARecord(name []byte, ip net.IP) DNSResourceRecord { | |
return DNSResourceRecord{ | |
Name: name, | |
Type: 1, // A | |
Class: 1, | |
TTL: 0, | |
Rdatalen: 4, | |
Rdata: []byte(ip), | |
} | |
} | |
func decodeRequest(request_bytes []byte) (DNSRequest, error) { | |
var request DNSRequest | |
header, err := decodeHeader(request_bytes) | |
if err != nil { | |
fmt.Println("failed decoding header:", err) | |
return request, err | |
} | |
// FIXME - check counts from header | |
if len(request_bytes) <= HeaderSize { | |
fmt.Println("only header") | |
return request, err | |
} | |
if header.Queries > 0 { | |
query, err := decodeQuery(request_bytes[HeaderSize:]) | |
if err != nil { | |
fmt.Println("error decoding query:", err) | |
return request, err | |
} | |
return DNSRequest{ | |
Id: header.Id, | |
Name: query.Name, | |
NameRaw: query.NameRaw, | |
Type: query.Type}, nil | |
} | |
return request, fmt.Errorf("No queries in request") | |
} | |
func decodeHeader(b []byte) (DNSHeader, error) { | |
var header DNSHeader | |
buf := bytes.NewReader(b) | |
err := binary.Read(buf, binary.BigEndian, &header) | |
if err != nil { | |
return header, err | |
} | |
fmt.Printf("Header: %#v", header) | |
fmt.Printf("request id: %d\n", header.Id) | |
return header, nil | |
} | |
func decodeQuery(b []byte) (DNSQuery, error) { | |
var query DNSQuery | |
fmt.Printf("Query:\n%s", hex.Dump(b)) | |
buf := bytes.NewReader(b) | |
query.Name = parseRequestName(buf) | |
i := bytes.IndexByte(b, 0) | |
query.NameRaw = b[:i+1] | |
fmt.Printf("name: '%s' len(query.NameRaw): %d\n", query.Name, len(query.NameRaw)) | |
err := binary.Read(buf, binary.BigEndian, &query.Type) | |
if err != nil { | |
return query, err | |
} | |
err = binary.Read(buf, binary.BigEndian, &query.Class) | |
if err != nil { | |
return query, err | |
} | |
fmt.Printf("type: %d, class: %d\n", query.Type, query.Class) | |
return query, nil | |
} | |
// parseRequestNameSimple converts a wire-form DNS name into dotted form. The
// wire form is a sequence of length-prefixed labels ending in a zero-length
// label, e.g. "\x03foo\x03com\x00" -> "foo.com".
//
// The root name ("\x00") yields an empty, non-nil slice, and any bytes after
// the terminating zero are ignored. The result is a freshly allocated slice,
// so b is never modified.
//
// It returns an error if b is empty, if a label length runs past the end of
// b, or if b ends before the terminating zero-length label.
func parseRequestNameSimple(b []byte) ([]byte, error) {
	if len(b) == 0 {
		return nil, errors.New("Name must not be empty")
	}
	name := make([]byte, 0, len(b))
	for len(b) > 0 {
		// Widen to int before doing arithmetic: count+1 in byte arithmetic
		// wraps to 0 for a length byte of 255.
		count := int(b[0])
		if count == 0 {
			return name, nil
		}
		// The label occupies b[1:count+1], so count must be < len(b).
		if count >= len(b) {
			return nil, fmt.Errorf("Wrong count: %d", count)
		}
		if len(name) > 0 {
			name = append(name, '.')
		}
		name = append(name, b[1:count+1]...)
		b = b[count+1:]
	}
	return nil, errors.New("name is missing its terminating zero-length label")
}
func parseRequestName(reader *bytes.Reader) []byte { | |
var name []byte = make([]byte, 512) | |
debReader := &blockReader{r: reader} | |
// FIXME - handle errors | |
io.ReadFull(debReader, name) | |
fmt.Println("len name:", len(name)) | |
fmt.Println("len string(name):", len(string(name))) | |
nameStr := fmt.Sprintf("%s", string(name)) | |
fmt.Println("len nameStr:", len(nameStr)) | |
// remove findal dot | |
// see TruncateAtFinalSlash() in https://blog.golang.org/slices | |
i := bytes.LastIndex(name, []byte(".")) | |
if i >= 0 { | |
name = name[0:i] | |
} | |
return name | |
} | |
// blockReader is an io.Reader over a wire-form DNS name. It reads
// length-prefixed labels from r and yields each label's bytes followed by
// '.'. When it consumes a zero-length label it returns io.EOF, leaving r
// positioned just after that byte.
//
// Adapted from https://blog.golang.org/gif-decoder-exercise-in-go-interfaces
type blockReader struct {
	r     *bytes.Reader
	slice []byte    // not-yet-returned part of the current label plus its '.'
	tmp   [256]byte // backing storage for slice
}

// Read implements io.Reader. Whenever the current label has been fully
// returned, it reads the next length byte from r.
//
// Length bytes above 63 are rejected. RFC 1035 limits labels to 63 octets,
// and a length byte with both top bits set (0xC0) starts a compression
// pointer, which this reader does not follow. Values in between are reserved.
func (b *blockReader) Read(p []byte) (int, error) {
	const maxLabelLen = 63

	if len(p) == 0 {
		return 0, nil
	}
	if len(b.slice) == 0 {
		blockLen, err := b.r.ReadByte()
		if err != nil {
			return 0, err
		}
		if blockLen == 0 {
			return 0, io.EOF
		}
		if blockLen > maxLabelLen {
			return 0, fmt.Errorf("invalid label length byte 0x%02x", blockLen)
		}
		b.slice = b.tmp[:blockLen]
		if _, err := io.ReadFull(b.r, b.slice); err != nil {
			// Do not hand out a partially read label on a later call.
			b.slice = nil
			return 0, err
		}
		// tmp has room past the label, so this append stays inside tmp.
		b.slice = append(b.slice, '.')
	}
	n := copy(p, b.slice)
	b.slice = b.slice[n:]
	return n, nil
}
func serializeResponse(header DNSHeader, rr DNSResourceRecord) []byte { | |
buf := new(bytes.Buffer) | |
err := binary.Write(buf, binary.BigEndian, header) | |
if err != nil { | |
fmt.Println("error serializing header:", err) | |
} | |
buf.Write(rr.Bytes()) | |
return buf.Bytes() | |
} | |
func (rr *DNSResourceRecord) Bytes() []byte { | |
buf := new(bytes.Buffer) | |
buf.Write(rr.Name) | |
twobytes := make([]byte, 2) | |
fourbytes := make([]byte, 4) | |
binary.BigEndian.PutUint16(twobytes, rr.Type) | |
buf.Write(twobytes) | |
binary.BigEndian.PutUint16(twobytes, rr.Class) | |
buf.Write(twobytes) | |
binary.BigEndian.PutUint32(fourbytes, rr.TTL) | |
buf.Write(fourbytes) | |
binary.BigEndian.PutUint16(twobytes, rr.Rdatalen) | |
buf.Write(twobytes) | |
buf.Write(rr.Rdata) | |
return buf.Bytes() | |
} | |
func handleRequest(pc net.PacketConn, request DNSRequest, clientAddr net.Addr) { | |
hosts := map[string]string{ | |
"foo.com": "10.0.0.1", | |
"reddit.com": "10.0.0.2", | |
"any.thing.io": "192.178.0.3", | |
} | |
ipStr := hosts[string(request.Name)] | |
if ipStr != "" { | |
fmt.Println("found ip:", ipStr) | |
responseRR := newARecord(request.NameRaw, net.ParseIP(ipStr).To4()) | |
fmt.Printf("response: %#v\n", responseRR) | |
responseHeader := DNSHeader{ | |
Id: request.Id, | |
Flags: 0x8000, | |
Queries: 0, | |
Answers: 1, | |
Auths: 0, | |
Additional: 0, | |
} | |
responseBytes := serializeResponse(responseHeader, responseRR) | |
fmt.Printf("Response:\n%s", hex.Dump(responseBytes)) | |
pc.WriteTo(responseBytes, clientAddr) | |
} | |
} | |
func main() { | |
port := 9000 | |
pc, err := net.ListenPacket("udp", fmt.Sprintf("localhost:%d", port)) | |
if err != nil { | |
log.Fatal(err) | |
} else { | |
log.Println("listening on port", port) | |
} | |
// what does this do? | |
defer pc.Close() | |
buffer := make([]byte, 1024) | |
for { | |
n, addr, err := pc.ReadFrom(buffer) | |
if err != nil { | |
log.Fatal(err) | |
} | |
request_bytes := buffer[:n] | |
fmt.Printf("request length %d\n", len(request_bytes)) | |
fmt.Printf("Request:\n%s", hex.Dump(request_bytes)) | |
request, err := decodeRequest(request_bytes) | |
handleRequest(pc, request, addr) | |
// input := string(request_bytes) | |
// pc.WriteTo([]byte("Hello there "+strings.ToUpper(input)), addr) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment