Last active
April 21, 2020 02:00
-
-
Save korc/66fe38ad3f285c691648af3d91044cef to your computer and use it in GitHub Desktop.
S3 proxy test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"crypto/tls" | |
"crypto/x509" | |
"flag" | |
"fmt" | |
"io" | |
"io/ioutil" | |
"log" | |
"net/http" | |
"net/http/httptrace" | |
"net/url" | |
"os" | |
"time" | |
"github.com/aws/aws-sdk-go/aws/credentials" | |
"github.com/aws/aws-sdk-go/aws/signer/v4" | |
) | |
const (
	// envKey is the environment variable holding the AWS access key ID.
	envKey = "AWS_ACCESS_KEY_ID"
	// envSecret is the environment variable holding the AWS secret access key.
	envSecret = "AWS_SECRET_ACCESS_KEY"
	// preSignExpireDuration is the validity period requested for the
	// pre-signed URL that is generated for diagnostic logging.
	preSignExpireDuration = time.Minute * 5
	// amzDateFormat is the time layout used for the -sign-time flag and for
	// parsing an incoming X-Amz-Date header.
	amzDateFormat = "20060102T150405Z"
)
func main() { | |
listenAddr := flag.String("listen", ":8080", "Listen address") | |
remoteAddr := flag.String("remote", "", "Remote URL") | |
awsId := flag.String("key", os.Getenv(envKey), "override $"+envKey) | |
awsSecret := flag.String("secret", os.Getenv(envSecret), "override $"+envSecret) | |
caFile := flag.String("cafile", "", "CA bundle") | |
region := flag.String("region", "us-east-1", "Default region") | |
service := flag.String("service", "s3", "default service") | |
fixedSignTimeFlag := flag.String("sign-time", "", fmt.Sprintf( | |
"Fixed signing time in UTC / "+amzDateFormat+" format (now: %s)", time.Now().UTC().Format(amzDateFormat))) | |
flag.Parse() | |
if *remoteAddr == "" { | |
log.Fatal("Need -remote <url> option") | |
} | |
if *awsId == "" || *awsSecret == "" { | |
log.Fatal("Need to set -key or $" + envKey + ", and -secret or $" + envSecret) | |
} | |
if err := os.Setenv(envKey, *awsId); err != nil { | |
log.Fatal("Cannot set environment: ", err) | |
} | |
if err := os.Setenv(envSecret, *awsSecret); err != nil { | |
log.Fatal("Cannot set environment: ", err) | |
} | |
remoteUrl, err := url.Parse(*remoteAddr) | |
if err != nil { | |
log.Fatal("Cannot parse remote URL: ", err) | |
} | |
var fixedSignTime time.Time | |
if *fixedSignTimeFlag != "" { | |
var err error | |
if fixedSignTime, err = time.Parse(amzDateFormat, *fixedSignTimeFlag); err != nil { | |
log.Fatalf("Cannot parse time %#v as "+amzDateFormat, *fixedSignTimeFlag) | |
} | |
fixedSignTime = fixedSignTime.UTC() | |
} | |
signer := v4.NewSigner(credentials.NewEnvCredentials()) | |
tr := http.DefaultTransport.(*http.Transport) | |
if *caFile != "" { | |
caData, err := ioutil.ReadFile(*caFile) | |
if err != nil { | |
log.Fatal("Cannot read CA certs: ", err) | |
} | |
if tr.TLSClientConfig == nil { | |
tr.TLSClientConfig = &tls.Config{} | |
} | |
tr.TLSClientConfig.RootCAs = x509.NewCertPool() | |
if !tr.TLSClientConfig.RootCAs.AppendCertsFromPEM(caData) { | |
log.Fatalf("Could not add CA cert data from %#v", *caFile) | |
} | |
} | |
reportError := func(w http.ResponseWriter, statusCode int, message string, err error) { | |
log.Printf("Error: %s: %s", message, err) | |
w.WriteHeader(statusCode) | |
if _, err := w.Write([]byte(message)); err != nil { | |
log.Printf("Could not send error to client: %s", err) | |
} | |
} | |
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | |
log.Printf("Got request: %s", r.RequestURI) | |
r.RequestURI = "" | |
var bodyReader io.ReadSeeker | |
if r.Body != nil { | |
if b, ok := r.Body.(io.ReadSeeker); ok { | |
bodyReader = b | |
} else { | |
data, err := ioutil.ReadAll(r.Body) | |
if err != nil { | |
reportError(w, http.StatusBadRequest, "Could not read body", err) | |
return | |
} | |
bodyReader = bytes.NewReader(data) | |
} | |
} | |
r.URL.Scheme = remoteUrl.Scheme | |
r.URL.Host = remoteUrl.Host | |
r.Host = remoteUrl.Host | |
var signTime time.Time | |
if fixedSignTime.IsZero() { | |
signTime = time.Now().UTC() | |
} else { | |
signTime = fixedSignTime | |
} | |
if timeHeader := r.Header.Get("X-Amz-Date"); timeHeader != "" { | |
signTime, err = time.Parse(amzDateFormat, timeHeader) | |
if err != nil { | |
reportError(w, http.StatusBadRequest, "Bad time format", err) | |
return | |
} | |
} | |
pr := r.Clone(r.Context()) | |
preSignHdr, err := signer.Presign(pr, nil, *service, *region, preSignExpireDuration, signTime) | |
if err != nil { | |
log.Printf("Cannot do presign: %s", err) | |
} else { | |
log.Printf("Pre-Signed header/request: %#v, %s", preSignHdr, pr.URL) | |
} | |
signHdr, err := signer.Sign(r, bodyReader, *service, *region, signTime) | |
if err != nil { | |
reportError(w, http.StatusInternalServerError, "Could not sign", err) | |
return | |
} | |
log.Printf("Signed headers: %#v", signHdr) | |
buf := bytes.NewBufferString("") | |
if err := r.Write(buf); err != nil { | |
log.Printf("Could not write to buffer: %s", err) | |
} | |
if bodyReader != nil { | |
if _, err := bodyReader.Seek(0, io.SeekStart); err != nil { | |
log.Printf("Could not seek body: %s", err) | |
} | |
r.Body = ioutil.NopCloser(bodyReader) | |
} | |
log.Printf("Sending request to %s:\n%s\n", r.URL.Host, buf.String()) | |
resp, err := http.DefaultClient.Do(r.Clone(httptrace.WithClientTrace(r.Context(), &httptrace.ClientTrace{ | |
WroteRequest: func(info httptrace.WroteRequestInfo) { | |
log.Printf("Wrote request (error: %v)", info.Err) | |
}, | |
DNSDone: func(info httptrace.DNSDoneInfo) { | |
log.Printf("DNS done: %s", info.Addrs) | |
}, | |
WroteHeaderField: func(key string, value []string) { | |
for _, v := range value { | |
log.Printf("Wrote header %s: %s", key, v) | |
} | |
}, | |
}))) | |
if err != nil { | |
reportError(w, http.StatusBadGateway, "Could not connect to remote", err) | |
return | |
} | |
log.Printf("Response status %d %s", resp.StatusCode, resp.Status) | |
respBody, err := ioutil.ReadAll(resp.Body) | |
if err != nil { | |
reportError(w, http.StatusInternalServerError, "Cannot read response body", err) | |
return | |
} | |
wHdr := w.Header() | |
for hdr := range resp.Header { | |
wHdr[hdr] = resp.Header[hdr] | |
log.Printf("Response header %s: %s", hdr, resp.Header[hdr]) | |
} | |
w.WriteHeader(resp.StatusCode) | |
log.Printf("Response body: %#v", string(respBody)) | |
if _, err := w.Write(respBody); err != nil { | |
log.Printf("Could not send response body to client: %s", err) | |
return | |
} | |
}) | |
if err := http.ListenAndServe(*listenAddr, nil); err != nil { | |
log.Fatal("Could not listen: ", err) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment