Skip to content

Commit

Permalink
feat: periodically check auth_query creds (#464)
Browse files Browse the repository at this point in the history
  • Loading branch information
abc3 authored Oct 25, 2024
1 parent b8da64c commit de98cd6
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 3 deletions.
4 changes: 4 additions & 0 deletions lib/supavisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ defmodule Supavisor do
db_host: db_host,
db_port: db_port,
db_database: db_database,
auth_query: auth_query,
default_parameter_status: ps,
ip_version: ip_ver,
default_pool_size: def_pool_size,
Expand All @@ -345,6 +346,7 @@ defmodule Supavisor do
db_user: db_user,
db_password: db_pass,
pool_size: pool_size,
db_user_alias: alias,
# mode_type: mode_type,
max_clients: max_clients
}
Expand All @@ -363,6 +365,8 @@ defmodule Supavisor do
sni_hostname: if(sni_hostname != nil, do: to_charlist(sni_hostname)),
port: db_port,
user: db_user,
alias: alias,
auth_query: auth_query,
database: if(db_name != nil, do: db_name, else: db_database),
password: fn -> db_pass end,
application_name: "Supavisor",
Expand Down
5 changes: 4 additions & 1 deletion lib/supavisor/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ defmodule Supavisor.Application do
{Registry, keys: :unique, name: Supavisor.Registry.ManagerTables},
{Registry, keys: :unique, name: Supavisor.Registry.PoolPids},
{Registry, keys: :duplicate, name: Supavisor.Registry.TenantSups},
{Registry, keys: :duplicate, name: Supavisor.Registry.TenantClients},
{Registry,
keys: :duplicate,
name: Supavisor.Registry.TenantClients,
partitions: System.schedulers_online()},
{Cluster.Supervisor, [topologies, [name: Supavisor.ClusterSupervisor]]},
Supavisor.Repo,
# Start the Telemetry supervisor
Expand Down
1 change: 0 additions & 1 deletion lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,6 @@ defmodule Supavisor.ClientHandler do
host: to_charlist(info.tenant.db_host),
sni_hostname:
if(info.tenant.sni_hostname != nil, do: to_charlist(info.tenant.sni_hostname)),
ip_version: Helpers.ip_version(info.tenant.ip_version, info.tenant.db_host),
port: info.tenant.db_port,
user: user,
password: info.user.db_password,
Expand Down
1 change: 1 addition & 0 deletions lib/supavisor/helpers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ defmodule Supavisor.Helpers do
"""
@spec detect_ip_version(String.t()) :: :inet | :inet6
def detect_ip_version(host) when is_binary(host) do
Logger.info("Detecting IP version for #{host}")
host = String.to_charlist(host)

case :inet.gethostbyname(host) do
Expand Down
116 changes: 116 additions & 0 deletions lib/supavisor/secret_checker.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
defmodule Supavisor.SecretChecker do
@moduledoc false

use GenServer
require Logger

alias Supavisor.Helpers

@interval :timer.seconds(15)

def start_link(args) do
name = {:via, Registry, {Supavisor.Registry.Tenants, {:secret_checker, args.id}}}

GenServer.start_link(__MODULE__, args, name: name)
end

def init(args) do
Logger.debug("SecretChecker: Starting secret checker")
tenant = Supavisor.tenant(args.id)

%{auth: auth, user: user} = Enum.find(args.replicas, fn e -> e.replica_type == :write end)

state = %{
tenant: tenant,
auth: auth,
user: user,
key: {:secrets, tenant, user},
ttl: args[:ttl] || :timer.hours(24),
conn: nil,
check_ref: check()
}

Logger.metadata(project: tenant, user: user)
{:ok, state, {:continue, :init_conn}}
end

def handle_continue(:init_conn, %{auth: auth} = state) do
ssl_opts =
if auth.upstream_ssl and auth.upstream_verify == "peer" do
[
{:verify, :verify_peer},
{:cacerts, [Helpers.upstream_cert(auth.upstream_tls_ca)]},
{:server_name_indication, auth.host},
{:customize_hostname_check, [{:match_fun, fn _, _ -> true end}]}
]
end

{:ok, conn} =
Postgrex.start_link(
hostname: auth.host,
port: auth.port,
database: auth.database,
password: auth.password.(),
username: auth.user,
parameters: [application_name: "Supavisor auth_query"],
ssl: auth.upstream_ssl,
socket_options: [
auth.ip_version
],
queue_target: 1_000,
queue_interval: 5_000,
ssl_opts: ssl_opts || []
)

# kill the postgrex connection if the current process exits unexpectedly
Process.link(conn)
{:noreply, %{state | conn: conn}}
end

def handle_info(:check, state) do
Logger.debug("Checking secrets")
check_secrets(state)
Logger.debug("Secrets checked")
{:noreply, %{state | check_ref: check()}}
end

def handle_info(msg, state) do
Logger.error("Unexpected message: #{inspect(msg)}")
{:noreply, state}
end

def terminate(_, state) do
:gen_statem.stop(state.conn)
:ok
end

def check(interval \\ @interval),
do: Process.send_after(self(), :check, interval)

def check_secrets(%{auth: auth, user: user, conn: conn} = state) do
case Helpers.get_user_secret(conn, auth.auth_query, user) do
{:ok, secret} ->
method = if secret.digest == :md5, do: :auth_query_md5, else: :auth_query
secrets = Map.put(secret, :alias, auth.alias)

update_cache =
case Cachex.get(Supavisor.Cache, state.key) do
{:ok, {:cached, {_, {old_method, old_secrets}}}} ->
method != old_method or secrets != old_secrets.()

other ->
Logger.error("Failed to get cache: #{inspect(other)}")
true
end

if update_cache do
Logger.info("Secrets changed or not present, updating cache")
value = {:ok, {method, fn -> secrets end}}
Cachex.put(Supavisor.Cache, state.key, {:cached, value}, expire: :timer.hours(24))
end

other ->
Logger.error("Failed to get secret: #{inspect(other)}")
end
end
end
3 changes: 2 additions & 1 deletion lib/supavisor/tenant_supervisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Supavisor.TenantSupervisor do

require Logger
alias Supavisor.Manager
alias Supavisor.SecretChecker

def start_link(%{replicas: [%{mode: mode} = single]} = args)
when mode in [:transaction, :session] do
Expand Down Expand Up @@ -33,7 +34,7 @@ defmodule Supavisor.TenantSupervisor do
}
end)

children = [{Manager, args} | pools]
children = [{Manager, args}, {SecretChecker, args} | pools]

{{type, tenant}, user, mode, db_name, search_path} = args.id
map_id = %{user: user, mode: mode, type: type, db_name: db_name, search_path: search_path}
Expand Down

0 comments on commit de98cd6

Please sign in to comment.