Skip to content

Instantly share code, notes, and snippets.

@bcwaldon
Created April 30, 2015 01:53
Show Gist options
  • Save bcwaldon/42ec473c093311c4bf15 to your computer and use it in GitHub Desktop.
Save bcwaldon/42ec473c093311c4bf15 to your computer and use it in GitHub Desktop.
TCP proxy testing
package main
import (
"fmt"
"io"
"net"
"os"
"strings"
"sync"
)
// gwg counts proxied connections that are still being serviced: the proxy
// accept loop adds one per accepted connection and marks it done when
// proxyTCP returns; main waits on it before exiting.
var gwg sync.WaitGroup
// maybeAddNewline returns s unchanged when it already ends in a newline,
// and s with a single "\n" appended otherwise.
func maybeAddNewline(s string) string {
	if strings.HasSuffix(s, "\n") {
		return s
	}
	return s + "\n"
}
// stderr writes a printf-style message to standard error. The format string
// is passed through maybeAddNewline first, so a format that does not end in
// a newline gets one appended.
func stderr(format string, args ...interface{}) {
	line := maybeAddNewline(format)
	fmt.Fprintf(os.Stderr, line, args...)
}
func main() {
srv, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 59001})
if err != nil {
panic(err)
}
go func() {
for {
conn, err := srv.AcceptTCP()
if err != nil {
panic(err)
}
go io.Copy(conn, conn)
}
}()
proxy, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 59000})
if err != nil {
panic(err)
}
go func() {
for {
cliConn, err := proxy.AcceptTCP()
if err != nil {
panic(err)
}
srvConn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 59001})
if err != nil {
panic(err)
}
gwg.Add(1)
go func() {
proxyTCP(cliConn, srvConn)
gwg.Done()
}()
}
}()
cli, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 59000})
if err != nil {
panic(err)
}
n, err := cli.Write([]byte("FOO"))
if err != nil {
panic(err)
}
stderr("client wrote %d bytes", n)
cli.Close()
data := make([]byte, 3)
_, err = cli.Read(data)
if err != nil {
stderr("client failed to read: %v", err)
}
stderr("client got data: %s", data)
gwg.Wait()
stderr("global wait group done")
}
func proxyTCP(src, dst *net.TCPConn) {
stderr("connecting %s <-> %s", src.RemoteAddr(), dst.RemoteAddr())
var wg sync.WaitGroup
wg.Add(2)
go copyBytes(src, dst, &wg)
go copyBytes(dst, src, &wg)
wg.Wait()
stderr("closing %s <-> %s", src.RemoteAddr(), dst.RemoteAddr())
}
func copyBytes(dst, src *net.TCPConn, wg *sync.WaitGroup) {
defer wg.Done()
n, err := io.Copy(dst, src)
if err != nil {
stderr("i/o error: %v", err)
}
stderr("copied %d bytes %s-> %s", n, src.RemoteAddr(), dst.RemoteAddr())
//dst.Close()
//src.Close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment