Skip to content

Instantly share code, notes, and snippets.

@venkatd
Last active November 20, 2024 17:18
Show Gist options
  • Save venkatd/18f1cbe15fa53c21e63f65e2771f44c2 to your computer and use it in GitHub Desktop.
Save venkatd/18f1cbe15fa53c21e63f65e2771f44c2 to your computer and use it in GitHub Desktop.
Minimal S3 lib based on Req - MIT license, or do whatever you want with it; I copied big chunks from ReqS3/ExAws anyway.
defmodule Ex.S3.Auth do
@doc """
Req request step that signs the request with AWS Signature V4 for S3.

Does nothing unless the request carries an `:s3` option. Otherwise the `:s3`
keyword list supplies the credentials; `:region` defaults to `"us-east-1"` and
`:datetime` to the current UTC time. The signed headers are merged into the
request and a response step that decodes XML bodies is appended.

Raises `ArgumentError` when `:access_key_id` or `:secret_access_key` is missing,
and a `RuntimeError` when a streamed body has no explicit `content-length` header.
"""
def put_aws_sigv4(request) do
  case request.options[:s3] do
    disabled when disabled in [nil, false] ->
      request

    configured ->
      signing_options =
        configured
        |> Keyword.put_new(:region, "us-east-1")
        |> Keyword.put_new(:datetime, DateTime.utc_now())
        # aws_credentials returns :credential_provider; :bucket and :endpoint are
        # only used to build URLs, so none of them take part in signing.
        |> Keyword.drop([:credential_provider, :bucket, :endpoint])
        |> Keyword.put(:service, :s3)

      Req.Request.validate_options(signing_options, [
        :access_key_id,
        :secret_access_key,
        :token,
        :service,
        :region,
        :datetime,
        # for req_s3
        :expires
      ])

      if !signing_options[:access_key_id] do
        raise ArgumentError, "missing :access_key_id in :s3 option"
      end

      if !signing_options[:secret_access_key] do
        raise ArgumentError, "missing :secret_access_key in :s3 option"
      end

      {payload, digest_options} =
        cond do
          is_nil(request.body) ->
            {"", []}

          is_binary(request.body) or is_list(request.body) ->
            {request.body, []}

          true ->
            # Streaming body: it cannot be hashed up front, so sign it as unsigned payload.
            if Req.Request.get_header(request, "content-length") == [] do
              raise "content-length header must be explicitly set when streaming request body"
            end

            {"", [body_digest: "UNSIGNED-PAYLOAD"]}
        end

      request = Req.Request.put_new_header(request, "host", request.url.host)

      # Flatten %{name => [values]} into one {name, value} pair per value.
      flat_headers =
        Enum.flat_map(request.headers, fn {name, values} ->
          Enum.map(values, fn value -> {name, value} end)
        end)

      request_details = [
        method: request.method,
        url: to_string(request.url),
        headers: flat_headers,
        body: payload
      ]

      signed_headers = aws_sigv4_headers(signing_options ++ request_details ++ digest_options)

      request
      |> Req.merge(headers: signed_headers)
      |> Req.Request.append_response_steps(s3_decode: &decode_body/1)
  end
end
# Response step: parses XML bodies of GET/HEAD responses into Elixir terms, unless
# decoding was disabled with `decode_body: false` or `raw: true`.
defp decode_body({request, response}) do
  should_decode? =
    request.method in [:get, :head] and
      request.options[:decode_body] != false and
      request.options[:raw] != true and
      match?(["application/xml" <> _], response.headers["content-type"])

  if should_decode? do
    {request, %{response | body: Ex.S3.XML.parse_s3(response.body)}}
  else
    {request, response}
  end
end
@doc """
Builds the headers for a header-based AWS Signature V4 request.

https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html

## Options

  * `:access_key_id`, `:secret_access_key` - credentials (required)
  * `:token` - optional session token, sent and signed as `x-amz-security-token`
  * `:region`, `:service`, `:datetime` - credential scope (required)
  * `:method` - request method as an atom (required)
  * `:url` - request URL (required)
  * `:headers` - list of `{name, value}` headers to sign (required)
  * `:body` - iodata that is hashed into `x-amz-content-sha256` (required)
  * `:body_digest` - optional digest used instead of hashing `:body`

Returns the `authorization` header followed by the `x-amz-*` headers and the
given headers. Raises `KeyError` when a required option is missing and
`ArgumentError` on unknown options.
"""
def aws_sigv4_headers(options) do
  {access_key_id, options} = Keyword.pop!(options, :access_key_id)
  {secret_access_key, options} = Keyword.pop!(options, :secret_access_key)
  {security_token, options} = Keyword.pop(options, :token)
  {region, options} = Keyword.pop!(options, :region)
  {service, options} = Keyword.pop!(options, :service)
  {datetime, options} = Keyword.pop!(options, :datetime)
  {method, options} = Keyword.pop!(options, :method)
  {url, options} = Keyword.pop!(options, :url)
  {headers, options} = Keyword.pop!(options, :headers)
  {body, options} = Keyword.pop!(options, :body)
  Keyword.validate!(options, [:body_digest])

  datetime = DateTime.truncate(datetime, :second)
  datetime_string = DateTime.to_iso8601(datetime, :basic)
  date_string = Date.to_iso8601(datetime, :basic)
  url = URI.parse(url)
  # A URL without a path parses to `path: nil`; the canonical URI is then "/".
  path = url.path || "/"
  body_digest = options[:body_digest] || hex(sha256(body))
  service = to_string(service)
  method = method |> Atom.to_string() |> String.upcase()

  aws_headers = [
    {"x-amz-content-sha256", body_digest},
    {"x-amz-date", datetime_string}
  ]

  aws_headers =
    if security_token do
      aws_headers ++ [{"x-amz-security-token", security_token}]
    else
      aws_headers
    end

  # Headers must be sorted for the canonical request; the same order is reused
  # for the SignedHeaders list.
  sorted_headers = Enum.sort(headers ++ aws_headers)

  signed_headers =
    Enum.map_intersperse(sorted_headers, ";", &String.downcase(elem(&1, 0), :ascii))

  canonical_headers =
    Enum.map_intersperse(sorted_headers, "\n", fn {name, value} -> [name, ":", value] end)

  canonical_request =
    IO.iodata_to_binary([
      method,
      "\n",
      path,
      "\n",
      canonical_query(url.query),
      "\n",
      canonical_headers,
      "\n",
      "\n",
      signed_headers,
      "\n",
      body_digest
    ])

  string_to_sign =
    IO.iodata_to_binary([
      "AWS4-HMAC-SHA256",
      "\n",
      datetime_string,
      "\n",
      "#{date_string}/#{region}/#{service}/aws4_request",
      "\n",
      hex(sha256(canonical_request))
    ])

  signature = aws_sigv4(string_to_sign, date_string, region, service, secret_access_key)

  credential = "#{access_key_id}/#{date_string}/#{region}/#{service}/aws4_request"

  authorization =
    "AWS4-HMAC-SHA256 Credential=#{credential},SignedHeaders=#{signed_headers},Signature=#{signature}"

  [{"authorization", authorization}] ++ aws_headers ++ headers
end
# Builds the canonical query string (as iodata) for a SigV4 canonical request.
#
# The raw query is split on "&", every item is normalised to `name=value`
# (a bare `name` becomes `name=`), and the items are sorted. Returns "" for `nil`.
defp canonical_query(nil), do: ""

defp canonical_query(query) do
  query
  |> String.split("&", trim: true)
  |> Enum.map(fn item ->
    # Split on the first "=" only: a value may itself contain "=" (e.g. base64 padding).
    case String.split(item, "=", parts: 2) do
      [name, value] -> [name, "=", value]
      [name] -> [name, "="]
    end
  end)
  |> Enum.sort()
  |> Enum.intersperse("&")
end
@doc """
Builds a presigned URL using AWS Signature V4 query-string authentication.

https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html

## Options

  * `:access_key_id`, `:secret_access_key` - credentials (required)
  * `:region`, `:service`, `:datetime` - credential scope (required)
  * `:method` - request method as an atom (required)
  * `:url` - URL to sign (required); query parameters already present are kept
  * `:expires` - validity in seconds, defaults to `86400`
  * `:headers` - extra `{name, value}` headers to sign besides `host`
  * `:query_params` - extra `{name, value}` query parameters to sign

Any other option fails the `[] = options` match. Returns a `%URI{}` whose query
carries the `X-Amz-*` parameters and the signature.
"""
def aws_sigv4_url(options) do
  {access_key_id, options} = Keyword.pop!(options, :access_key_id)
  {secret_access_key, options} = Keyword.pop!(options, :secret_access_key)
  {region, options} = Keyword.pop!(options, :region)
  {service, options} = Keyword.pop!(options, :service)
  {datetime, options} = Keyword.pop!(options, :datetime)
  {method, options} = Keyword.pop!(options, :method)
  {url, options} = Keyword.pop!(options, :url)
  {expires, options} = Keyword.pop(options, :expires, 86400)
  {headers, options} = Keyword.pop(options, :headers, [])
  {query_params, options} = Keyword.pop(options, :query_params, [])
  [] = options

  datetime = DateTime.truncate(datetime, :second)
  datetime_string = DateTime.to_iso8601(datetime, :basic)
  date_string = Date.to_iso8601(datetime, :basic)
  url = URI.parse(url)
  service = to_string(service)
  # A URL without a path parses to `path: nil`; the canonical URI is then "/".
  path = url.path || "/"
  method = method |> Atom.to_string() |> String.upcase()

  # Header names are lowercased before sorting so that the canonical headers
  # and the SignedHeaders list agree on both spelling and order.
  combined_headers =
    [host_header(url) | headers]
    |> Enum.map(fn {name, value} -> {String.downcase(name, :ascii), value} end)
    |> Enum.sort(&compare_pair/2)

  signed_headers = Enum.map_intersperse(combined_headers, ";", &elem(&1, 0))

  canonical_headers =
    Enum.map_intersperse(combined_headers, "\n", fn {name, value} -> [name, ":", value] end)

  aws_query_params =
    [
      {"X-Amz-Algorithm", "AWS4-HMAC-SHA256"},
      {"X-Amz-Credential", "#{access_key_id}/#{date_string}/#{region}/#{service}/aws4_request"},
      {"X-Amz-Date", datetime_string},
      {"X-Amz-Expires", to_string(expires)},
      {"X-Amz-SignedHeaders", signed_headers}
    ]
    |> Enum.sort(&compare_pair/2)

  # Query parameters already on the URL plus the ones passed explicitly.
  custom_query_params =
    Enum.sort(
      Enum.to_list(URI.decode_query(url.query || "")) ++ query_params,
      &compare_pair/2
    )

  # The canonical query string covers all parameters in one sorted sequence.
  query_to_sign =
    (custom_query_params ++ aws_query_params)
    |> Enum.sort(&compare_pair/2)
    |> encode_query_params()

  canonical_request =
    IO.iodata_to_binary([
      method,
      "\n",
      path,
      "\n",
      query_to_sign,
      "\n",
      canonical_headers,
      "\n",
      "\n",
      signed_headers,
      "\n",
      "UNSIGNED-PAYLOAD"
    ])

  string_to_sign =
    IO.iodata_to_binary([
      "AWS4-HMAC-SHA256",
      "\n",
      datetime_string,
      "\n",
      "#{date_string}/#{region}/#{service}/aws4_request",
      "\n",
      hex(sha256(canonical_request))
    ])

  signature = aws_sigv4(string_to_sign, date_string, region, service, secret_access_key)

  final_query_string =
    case custom_query_params do
      [] ->
        encode_query_params(aws_query_params)

      _ ->
        encode_query_params(custom_query_params) <> "&" <> encode_query_params(aws_query_params)
    end

  # Append the signature to the URL's query string
  put_in(url.query, final_query_string <> "&X-Amz-Signature=#{signature}")
end
# Builds the `host` header that is signed for a presigned URL.
#
# The port is left out when it is absent or is the default port of the URL's
# scheme (80 for http, 443 for https), because clients omit it from `Host` in
# those cases. For a scheme with no registered default port, 80 and 443 are
# still treated as implied. Any other port is appended as `host:port`.
defp host_header(%URI{scheme: scheme, host: host, port: port}) do
  default_port = if is_binary(scheme), do: URI.default_port(scheme)

  cond do
    is_nil(port) -> {"host", host}
    port == default_port -> {"host", host}
    is_nil(default_port) and port in [80, 443] -> {"host", host}
    true -> {"host", "#{host}:#{port}"}
  end
end
@doc """
Computes the lowercase hex AWS Signature V4 signature of `string_to_sign`.

The signing key is derived by chaining HMAC-SHA256 over the date, region,
service and the literal `"aws4_request"`, starting from `"AWS4"` followed by
the secret access key.
"""
def aws_sigv4(string_to_sign, date_string, region, service, secret_access_key) do
  signing_key =
    Enum.reduce(
      [date_string, region, service, "aws4_request"],
      ["AWS4", secret_access_key],
      fn scope_part, key -> hmac(key, scope_part) end
    )

  hex(hmac(signing_key, string_to_sign))
end
# Lowercase hexadecimal encoding of a binary.
defp hex(data), do: Base.encode16(data, case: :lower)
# Raw (binary) SHA-256 digest of iodata.
defp sha256(data), do: :crypto.hash(:sha256, data)
# Raw (binary) HMAC-SHA256 of `data` under `key`; key first so calls can be piped.
defp hmac(key, data), do: :crypto.mac(:hmac, :sha256, key, data)
# Encodes a list of `{key, value}` pairs as `k1=v1&k2=v2`, keeping the given order.
defp encode_query_params(params) do
  Enum.map_join(params, "&", fn param -> pair(param) end)
end
# Sort predicate for `{key, value}` pairs: order by key, and by value when the
# keys are identical (strict match, as in a repeated pattern variable).
defp compare_pair({left_key, left_value}, {right_key, right_value}) do
  if left_key === right_key do
    left_value < right_value
  else
    left_key < right_key
  end
end
# Encodes one `{key, value}` pair as `key=value` for a SigV4 query string.
#
# Key and value use the same AWS-style percent-encoding (a space becomes "%20",
# never "+"), because SigV4 canonicalisation URI-encodes both the same way.
defp pair({k, v}) do
  aws_encode_www_form(Kernel.to_string(k)) <> "=" <> aws_encode_www_form(Kernel.to_string(v))
end
# Basically the same as URI.encode_www_form/1, except that a space is encoded
# as "%20" instead of "+". Every byte that is not an unreserved URI character
# is percent-encoded with uppercase hex digits.
defp aws_encode_www_form(str) when is_binary(str) do
  str
  |> :binary.bin_to_list()
  |> Enum.map_join(fn byte ->
    if URI.char_unreserved?(byte) do
      <<byte>>
    else
      # High nibble first, then low nibble.
      "%" <> char_hex(div(byte, 16)) <> char_hex(rem(byte, 16))
    end
  end)
end
# Uppercase hex digit (as a one-byte binary) for a nibble in 0..15.
defp char_hex(n) do
  offset = if n <= 9, do: ?0, else: ?A - 10
  <<n + offset>>
end
end
defmodule Ex.S3 do
defstruct [:client, :bucket]
alias Ex.S3
require Logger
@doc """
Builds an `Ex.S3` handle from an `s3://` URL.

The URL is parsed into signing options, and a Req client is created whose
base URL is `<endpoint>/<bucket>` and which signs every request via
`Ex.S3.Auth.put_aws_sigv4/1`.
"""
def new(url) do
  s3_opts = opts_from_url(url)
  bucket = Keyword.fetch!(s3_opts, :bucket)
  endpoint = Keyword.fetch!(s3_opts, :endpoint)

  base_request = Req.new(base_url: "#{endpoint}/#{bucket}")

  signing_client =
    base_request
    |> Ex.Req.Log.attach()
    |> Req.merge(log: false)
    |> Req.Request.register_options([:s3])
    |> Req.Request.append_request_steps(s3: &Ex.S3.Auth.put_aws_sigv4/1)
    |> Req.Request.merge_options(s3: s3_opts)

  %Ex.S3{client: signing_client, bucket: bucket}
end
# Parses an `s3://ACCESS_KEY:SECRET@host[:port]/bucket[?ssl=false&region=...]`
# URL into the keyword list used for signing and for building the endpoint.
#
# * `ssl=false` selects "http"; anything else selects "https".
# * `region` is only included when present in the query.
# * The port is only added to the endpoint when the URL specifies one: "s3" has
#   no default port, so `URI.parse/1` returns `port: nil` otherwise.
#
# Raises `MatchError` when the URL does not have the expected shape.
defp opts_from_url(s3_url) do
  %URI{
    scheme: "s3",
    userinfo: userinfo,
    host: host,
    port: port,
    path: "/" <> bucket,
    query: query,
    fragment: nil
  } = URI.parse(s3_url)

  opts = URI.decode_query(query || "")
  endpoint_protocol = if opts["ssl"] != "false", do: "https", else: "http"
  authority = if is_nil(port), do: host, else: "#{host}:#{port}"

  extra_opts =
    case opts do
      %{"region" => region} -> [region: region]
      _ -> []
    end

  [access_key_id, secret_access_key] = String.split(userinfo, ":", parts: 2)

  [
    bucket: bucket,
    access_key_id: access_key_id,
    secret_access_key: secret_access_key,
    endpoint: "#{endpoint_protocol}://#{authority}"
  ] ++ extra_opts
end
@doc """
Downloads the object stored under `key` and returns the `Req.Response`.

The body is returned undecoded. Receive and pool timeouts are one minute.
Raises on request errors (`Req.get!/2`).
"""
def get(%S3{client: client}, key) do
  one_minute = 60 * 1000

  request_options = [
    url: "/" <> encode_key(key),
    decode_body: false,
    receive_timeout: one_minute,
    pool_timeout: one_minute
  ]

  Req.get!(client, request_options)
end
@doc """
Issues a HEAD request for `key` and returns the `Req.Response`.

An explicit `content-length: 0` header is sent with the request.
Raises on request errors (`Req.head!/2`).
"""
def head(%S3{client: client}, key) do
  object_path = "/" <> encode_key(key)
  Req.head!(client, url: object_path, headers: [{"content-length", "0"}])
end
@doc """
Uploads `body` under `key` and returns the `Req.Response`.

`opts[:headers]` (default `[]`) are sent with the request. Receive, pool and
connect timeouts are one minute. The request is retried only on
`Req.TransportError`, logging a warning each time. Raises on request errors
(`Req.put!/2`).
"""
def put(%S3{client: client}, key, body, opts \\ []) do
  one_minute = 60 * 1000

  retry_on_transport_error = fn
    _req, %Req.TransportError{} ->
      Logger.warning("S3 put request timed out retrying #{key}")
      true

    _, _ ->
      false
  end

  request_options = [
    url: "/" <> encode_key(key),
    body: body,
    receive_timeout: one_minute,
    pool_timeout: one_minute,
    connect_options: [timeout: one_minute],
    headers: Keyword.get(opts, :headers, []),
    retry: retry_on_transport_error
  ]

  Req.put!(client, request_options)
end
@doc """
Returns a lazy stream of all objects under `prefix`.

Pages are fetched with `list/3` on demand, following continuation tokens
until a page comes back without one.
"""
def stream_list(%S3{} = s3, prefix) do
  # The accumulator is the continuation token for the next page: `nil` for the
  # first page, `:done` once the last page has been emitted.
  next_page = fn
    :done ->
      {:halt, nil}

    token ->
      case list(s3, prefix, continuation_token: token) do
        # last page
        {:ok, objects, nil} -> {objects, :done}
        {:ok, objects, following_token} -> {objects, following_token}
      end
  end

  Stream.resource(fn -> nil end, next_page, fn _ -> :ok end)
end
@doc """
Fetches one page (up to 1000 keys) of objects under `prefix`.

Pass `continuation_token: token` in `opts` to fetch a following page.
Returns `{:ok, items, next_token}`, where `next_token` is `nil` on the last
page. Raises `CaseClauseError` when the response has an unexpected shape.
"""
def list(%S3{client: client}, prefix, opts \\ []) do
  base_params = [
    {"list-type", 2},
    # this is the limit anyway
    {"max-keys", 1000},
    {"prefix", prefix}
  ]

  token = opts[:continuation_token]

  params =
    if is_nil(token) do
      base_params
    else
      base_params ++ [{"continuation-token", token}]
    end

  response = Req.get!(client, url: "/", params: params)

  case response do
    %Req.Response{
      status: 200,
      body: %{"ListBucketResult" => %{"Contents" => items, "NextContinuationToken" => next_token}}
    } ->
      {:ok, items, next_token}

    %Req.Response{status: 200, body: %{"ListBucketResult" => %{"KeyCount" => "0"}}} ->
      {:ok, [], nil}

    %Req.Response{status: 200, body: %{"ListBucketResult" => %{"Contents" => items}}} ->
      {:ok, items, nil}
  end
end
# Percent-encodes an object key for use in a URL path: "/" and unreserved URI
# characters are kept, every other byte is escaped.
defp encode_key(key) do
  URI.encode(key, fn char -> char == ?/ or URI.char_unreserved?(char) end)
end
@doc """
Returns a presigned URL (string) for a `:get` or `:put` of `key`.

Credentials, endpoint and bucket come from the client's `:s3` option. The
region defaults to `"us-east-1"` when the client was built without one, the
same default used when signing ordinary requests. `opts` are appended to the
options given to `Ex.S3.Auth.aws_sigv4_url/1` (e.g. `:expires`, `:headers`,
`:query_params`).
"""
def presign_url(%S3{client: client}, method, key, opts \\ [])
    when method in [:get, :put] do
  encoded_key = encode_key(key)
  s3_opts = Req.Request.get_option(client, :s3)
  endpoint = Keyword.fetch!(s3_opts, :endpoint)
  bucket = Keyword.fetch!(s3_opts, :bucket)

  sign_opts =
    [
      method: method,
      service: :s3,
      datetime: DateTime.utc_now(),
      access_key_id: Keyword.fetch!(s3_opts, :access_key_id),
      secret_access_key: Keyword.fetch!(s3_opts, :secret_access_key),
      # `region` is optional in the s3:// URL, so fall back to the signing default.
      region: Keyword.get(s3_opts, :region, "us-east-1"),
      url: "#{endpoint}/#{bucket}/#{encoded_key}"
    ] ++ opts

  sign_opts
  |> Ex.S3.Auth.aws_sigv4_url()
  |> URI.to_string()
end
end
defmodule Ex.S3.XML do
@moduledoc false
# Straight copy from ReqS3
# xmerl_sax_parser :disallow_entities requires OTP 25+.
# The release is compared as an integer: a string comparison would order
# "100" before "25". A release string that does not start with a number is
# rejected as well.
case Integer.parse(System.otp_release()) do
  {major, _rest} when major >= 25 -> :ok
  _ -> raise "req_s3 requires OTP 25+"
end
@list_fields [
{"ListBucketResult", "Contents"},
{"ListVersionsResult", "Version"}
]
@list_fields_skip [
{"ListAllMyBucketsResult", "Buckets", "Bucket"}
]
@doc """
Parses S3 XML into maps, lists, and strings.

This is a best effort parser, trying to return the most convenient representation. This is
tricky because collections can be represented in multiple ways:

    <ListAllMyBucketsResult>
      <Buckets>
        <Bucket><Name>bucket1</Name></Bucket>
        <Bucket><Name>bucket2</Name></Bucket>
      </Buckets>
    </ListAllMyBucketsResult>

    <ListBucketResult>
      <Name>bucket1</Name>
      <Contents><Key>key1</Key></Contents>
      <Contents><Key>key2</Key></Contents>
    </ListBucketResult>

We handle `ListBucketResult/Contents`, `ListVersionsResult/Version`,
`ListAllMyBucketsResult/Buckets/Bucket` (and possibly others in the future) in a particular
way and have a best effort fallback.

Returns a single-key map of the root element name to its parsed value.
"""
def parse_s3(xml) do
# The accumulator is {root_element_name, stack}; the stack holds {name, value}
# entries for the currently open elements, innermost first.
parse(xml, {nil, []}, fn
# Opening tag: remember the first element as root and push an empty entry.
{:start_element, name, _attributes}, {root, stack} ->
{root || name, [{name, nil} | stack]}
# Collect e.g. <ListBucketResult><Contents>...</Contents> into a "Contents" _list_.
{:end_element, name}, {root, [{name, val}, {parent_name, parent_val} | stack]}
when {root, name} in @list_fields ->
parent_val = Map.update(parent_val || %{}, name, [val], &(&1 ++ [val]))
{root, [{parent_name, parent_val} | stack]}
# Collect e.g. <ListAllMyBucketsResult><Buckets><Bucket>...</Bucket> into a "Buckets" _list_
# skipping "Bucket".
{:end_element, name}, {root, [{name, val}, {parent_name, parent_val} | stack]}
when {root, parent_name, name} in @list_fields_skip ->
parent_val = (parent_val || []) ++ [val]
{root, [{parent_name, parent_val} | stack]}
{:end_element, name}, {root, stack} ->
case stack do
# Best effort: by default simply put name/value into parent map. If the parent
# map already contains name, turn parent[name] into a list and keep appending.
# The obvious caveat is we'd only turn parent[name] into a list on the second element,
# hence if XML contained just one element for what is semantically a list, it will be
# represented as a map, not a list with single map element. As we discover these,
# let's update @list_fields and @list_fields_skip.
[{^name, val}, {parent_name, parent_val} | stack] ->
parent_val = Map.update(parent_val || %{}, name, val, &(List.wrap(&1) ++ [val]))
{root, [{parent_name, parent_val} | stack]}
# Only one entry left: the stack is replaced by the final result map.
[{name, val}] ->
{root, %{name => val}}
other ->
raise """
unexpected :end_element state:
#{inspect(other, pretty: true)}
"""
end
# Text content replaces the value of the innermost open element.
{:characters, string}, {root, [{name, _} | stack]} ->
{root, [{name, string} | stack]}
other, {root, stack} ->
raise """
unexpected event:
#{inspect(other, pretty: true)}
root: #{root}
stack:
#{inspect(stack, pretty: true)}
"""
end)
# Drop the root name and keep the value built from the stack.
|> elem(1)
end
@doc """
Runs the SAX parser over `xml`, folding simplified events into `state` with `fun`.

`fun` receives `{:start_element, name, attributes}`, `{:end_element, name}` or
`{:characters, string}` together with the current state and returns the new
state. Entities are disallowed. Returns the final state; raises `MatchError`
when the parser does not return `{:ok, _, _}`.
"""
def parse(xml, state, fun) do
  sax_options = [
    :disallow_entities,
    event_fun: &process/3,
    event_state: %{state: state, fun: fun},
    external_entities: :none,
    fail_undeclared_ref: false
  ]

  {:ok, %{state: final_state}, _leftover} = :xmerl_sax_parser.stream(xml, sax_options)
  final_state
end
# Translates raw xmerl SAX events into the simplified events given to the
# user callback; charlists are converted to strings and all other events are ignored.
# https://www.erlang.org/doc/apps/xmerl/xmerl_sax_parser.html#t:event/0
defp process(event, loc, state)

defp process({:startElement, _uri, name, _qualified_name, attributes}, _loc, state) do
  string_attributes =
    Enum.map(attributes, fn attribute ->
      {_, _, attr_name, attr_value} = attribute
      {List.to_string(attr_name), List.to_string(attr_value)}
    end)

  event = {:start_element, List.to_string(name), string_attributes}
  %{state | state: state.fun.(event, state.state)}
end

defp process({:endElement, _uri, name, _qualified_name}, _loc, state) do
  event = {:end_element, List.to_string(name)}
  %{state | state: state.fun.(event, state.state)}
end

defp process({:characters, charlist}, _loc, state) do
  event = {:characters, List.to_string(charlist)}
  %{state | state: state.fun.(event, state.state)}
end

defp process(_event, _loc, state), do: state
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment