-
-
Save 3manuek/b889d43678d12f950aee to your computer and use it in GitHub Desktop.
Simple HTTP forward proxy (plain HTTP requests and CONNECT tunnels)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"io" | |
"io/ioutil" | |
"net" | |
"net/http" | |
"sync" | |
"time" | |
) | |
// TIMEOUT bounds how long the proxy waits on the network: it is the dial
// timeout for CONNECT destinations, the timeout of the client used for
// proxied requests, and the server's read/write timeout.
const TIMEOUT time.Duration = 10 * time.Second
func handleReq(w http.ResponseWriter, r *http.Request) { | |
if r.Method != "CONNECT" { | |
proxyReq(w, r) | |
} else { | |
tunnelReq(w, r) | |
} | |
} | |
func tunnelReq(w http.ResponseWriter, r *http.Request) { | |
fmt.Printf("recv tunnel req: %+v\n", r.URL.String()) | |
hj, ok := w.(http.Hijacker) | |
if !ok { | |
http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) | |
return | |
} | |
conn, bufrw, err := hj.Hijack() | |
if err != nil { | |
fmt.Println("hijack err:", err.Error()) | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
defer conn.Close() | |
connDest, err := net.DialTimeout("tcp", r.URL.Host, TIMEOUT) | |
if err != nil { | |
fmt.Println("dial err:", err.Error()) | |
bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n\r\n") | |
return | |
} | |
defer connDest.Close() | |
bufrw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n") | |
bufrw.Flush() | |
var wg sync.WaitGroup | |
wg.Add(2) | |
go func() { | |
_, err := io.Copy(connDest, conn) | |
if err != nil { | |
fmt.Println("copy err:", err) | |
} | |
conn.Close() | |
connDest.Close() | |
fmt.Println("src -> dest close") | |
wg.Done() | |
}() | |
go func() { | |
_, err := io.Copy(conn, connDest) | |
if err != nil { | |
fmt.Println("copy err:", err) | |
} | |
conn.Close() | |
connDest.Close() | |
fmt.Println("dest -> src close") | |
wg.Done() | |
}() | |
wg.Wait() | |
fmt.Println("disconnect tunnel") | |
} | |
func proxyReq(w http.ResponseWriter, r *http.Request) { | |
fmt.Printf("recv proxy req: %+v\n", r.URL.String()) | |
client := &http.Client{Timeout: TIMEOUT} | |
req, e := http.NewRequest(r.Method, r.URL.String(), r.Body) | |
if e != nil { | |
fmt.Println("create request err: ", e) | |
return | |
} | |
req.Header = r.Header | |
resp, e := client.Do(req) | |
if e != nil { | |
fmt.Println("do client err: ", e) | |
return | |
} | |
for k, v := range resp.Header { | |
w.Header()[k] = v | |
} | |
w.WriteHeader(resp.StatusCode) | |
defer resp.Body.Close() | |
body, e := ioutil.ReadAll(resp.Body) | |
if e != nil { | |
fmt.Println("read body err:", e) | |
return | |
} | |
_, e = w.Write(body) | |
if e != nil { | |
fmt.Println("write body err:", e) | |
return | |
} | |
fmt.Println("end proxy req") | |
} | |
func main() { | |
s := &http.Server{Addr: ":12345", Handler: http.HandlerFunc(handleReq), ReadTimeout: TIMEOUT, WriteTimeout: TIMEOUT} | |
err := s.ListenAndServe() | |
if err != nil { | |
fmt.Println("ListenAndServe: ", err) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment