Skip to content

Commit

Permalink
Merge pull request #10 from climate-machine/histograms
Browse files Browse the repository at this point in the history
Histograms.jl: Implement Wasserstein-1 distance
  • Loading branch information
dburov190 authored Sep 10, 2019
2 parents 6e63890 + e3bf4a2 commit e622818
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 13 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
12 changes: 12 additions & 0 deletions src/ConvenienceFunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

# Common width to which message prefixes are right-padded, so that the text
# printed after a prefix lines up across messages.
const RPAD = 25

"""
    name(name::AbstractString)

Return `name` followed by a colon, right-padded with spaces to width `RPAD`.
A result that is already wider than `RPAD` is returned as is (never truncated).
"""
name(name::AbstractString) = rpad(string(name, ":"), RPAD)

"""
    warn(name::AbstractString)

Return the warning prefix `"WARNING (<name>):"`, right-padded with spaces to
width `RPAD`. A result that is already wider than `RPAD` is returned as is
(never truncated).
"""
warn(name::AbstractString) = rpad(string("WARNING (", name, "):"), RPAD)


15 changes: 2 additions & 13 deletions src/GPR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Parameters # lets you have defaults for fields
using EllipsisNotation # adds '..' to refer to the rest of array
import ScikitLearn
import StatsBase
include("ConvenienceFunctions.jl")

const sklearn = ScikitLearn

sklearn.@sk_import gaussian_process : GaussianProcessRegressor
Expand Down Expand Up @@ -324,19 +326,6 @@ function plot_fit(gprw::Wrap, plt; plot_95 = false, label = nothing)
end
end

################################################################################
# convenience functions ########################################################
################################################################################
# Common width to which message prefixes are right-padded, so that the text
# printed after a prefix lines up across messages.
const RPAD = 25

"""
    name(name::AbstractString)

Return `name` followed by a colon, right-padded with spaces to width `RPAD`.
A result that is already wider than `RPAD` is returned as is (never truncated).
"""
name(name::AbstractString) = rpad(string(name, ":"), RPAD)

"""
    warn(name::AbstractString)

Return the warning prefix `"WARNING (<name>):"`, right-padded with spaces to
width `RPAD`. A result that is already wider than `RPAD` is returned as is
(never truncated).
"""
warn(name::AbstractString) = rpad(string("WARNING (", name, "):"), RPAD)

end # module


80 changes: 80 additions & 0 deletions src/Histograms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
    Histograms

This module is mostly a convenient wrapper of Python functions (numpy, scipy).

Functions in this module:
- `W1` (2 methods)
"""
module Histograms

import PyCall
include("ConvenienceFunctions.jl")

# Handle to Python's `scipy.stats`. It is declared as a typed `const` placeholder
# and filled in by `__init__`, because a Python object created at module top
# level is not valid after precompilation (and a non-const global is untyped).
const scsta = PyCall.PyNULL()

"""
    __init__()

Import `scipy.stats` into `scsta` when the module is loaded.
"""
function __init__()
    copy!(scsta, PyCall.pyimport("scipy.stats"))
    return nothing
end

################################################################################
# distance functions ###########################################################
################################################################################
"""
    W1(u_samples::AbstractVector, v_samples::AbstractVector; normalize = true)

Compute the Wasserstein-1 distance between two distributions from their samples.

Parameters:
- `u_samples`: array-like; samples from the 1st distribution
- `v_samples`: array-like; samples from the 2nd distribution
- `normalize`: boolean; whether to normalize the distance by 1/(max-min), where
               max and min are taken over both sample sets together

Returns:
- `w1_uv`: number; the Wasserstein-1 distance

If `normalize` is set and all samples are identical (max == min), the distance
is zero and is returned without dividing, instead of producing `NaN` from 0/0.
"""
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
            normalize = true)
    w1_uv = scsta.wasserstein_distance(u_samples, v_samples)
    normalize || return w1_uv

    # Range of the pooled samples, computed without concatenating the inputs.
    lo = min(minimum(u_samples), minimum(v_samples))
    hi = max(maximum(u_samples), maximum(v_samples))
    L = hi - lo

    # Degenerate case: every sample has the same value, so the distance is 0.
    iszero(L) && return w1_uv
    return w1_uv / L
end

"""
    W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix; normalize = true)

Compute the pairwise Wasserstein-1 distances between two sets of distributions
from their samples.

Parameters:
- `U_samples`: matrix-like; samples from distributions (u1, u2, ...)
- `V_samples`: matrix-like; samples from distributions (v1, v2, ...)
- `normalize`: boolean; whether to normalize the distances by 1/(max-min)

`U_samples` and `V_samples` should have samples in the 2nd dimension (along
rows) and have the same 1st dimension (same number of rows). If not, a warning
is printed and the minimum of the two (minimum number of rows) is used.

`normalize` induces *pairwise* normalization, i.e. the maxima and minima are
computed for each pair (u_j, v_j) individually.

Returns:
- `w1_UV`: array-like; the pairwise Wasserstein-1 distances:
           w1(u1, v1)
           w1(u2, v2)
           ...
           w1(u_K, v_K)
"""
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
            normalize = true)
    if size(U_samples, 1) != size(V_samples, 1)
        println(warn("W1"), "sizes of U_samples & V_samples don't match; ",
                "will use the minimum of the two")
    end
    K = min(size(U_samples, 1), size(V_samples, 1))
    w1_UV = zeros(K)
    for k in 1:K
        # The distance does not depend on the order of the samples, so the rows
        # are passed on as they are (no pre-sorting or full-matrix copies).
        w1_UV[k] = W1(U_samples[k, :], V_samples[k, :]; normalize = normalize)
    end
    return w1_UV
end

end # module


19 changes: 19 additions & 0 deletions test/ConvenienceFunctions/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Test

include("../../src/ConvenienceFunctions.jl")

################################################################################
# unit testing #################################################################
################################################################################
@testset "unit testing" begin
    # The padding width must be defined at top level once the file is included.
    @test isdefined(Main, :RPAD)

    # A string that on its own already fills the whole padding width.
    wide = repeat("a", RPAD)

    # Both helpers return plain `String`s.
    @test name("a") isa String
    @test warn("a") isa String

    # Short inputs are padded to exactly `RPAD` characters.
    @test RPAD == length(name("a"))
    @test RPAD == length(warn("a"))

    # Long inputs are not truncated: `name` adds 1 character (":") and
    # `warn` adds 11 ("WARNING (" and "):").
    @test (RPAD + 1) == length(name(wide))
    @test (RPAD + 11) == length(warn(wide))
end
println("")


Binary file added test/Histograms/data/x1_bal.npy
Binary file not shown.
Binary file added test/Histograms/data/x1_dns.npy
Binary file not shown.
Binary file added test/Histograms/data/x1_onl.npy
Binary file not shown.
36 changes: 36 additions & 0 deletions test/Histograms/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Test
import NPZ

include("../../src/Histograms.jl")
const Hgm = Histograms

# Directory with the binary `.npy` fixtures, located next to this test file.
const data_dir = joinpath(@__DIR__, "data")
# Sample arrays read from the fixtures; they are the inputs to `Hgm.W1` in the
# tests below (presumably 1-D sample vectors, since the results are compared
# with scalars -- TODO confirm).
const x1_bal = NPZ.npzread(joinpath(data_dir, "x1_bal.npy"))
const x1_dns = NPZ.npzread(joinpath(data_dir, "x1_dns.npy"))
const x1_onl = NPZ.npzread(joinpath(data_dir, "x1_onl.npy"))
# Expected values of `Hgm.W1` with the default `normalize = true` for each pair
# of fixtures; used as reference values in the tests below.
const w1_dns_bal = 0.03755967829782972
const w1_dns_onl = 0.004489688974663949
const w1_bal_onl = 0.037079734072606625
# Expected value of `Hgm.W1(x1_dns, x1_bal, normalize = false)`.
const w1_dns_bal_unnorm = 0.8190688772401341

################################################################################
# unit testing #################################################################
################################################################################
@testset "unit testing" begin
    # Small hand-checkable case: the empirical CDFs differ by 11/72, 4/72 and
    # 3/72 on the unit-length intervals [1,2], [2,3] and [3,4], i.e. 0.25 total.
    arr1 = [1, 1, 1, 2, 3, 4, 4, 4]
    arr2 = [1, 1, 2, 2, 3, 3, 4, 4, 4]
    # The results are floating-point numbers computed by SciPy, so they are
    # compared with `≈` rather than bit-exact `==`.
    @test Hgm.W1(arr1, arr2, normalize = false) ≈ 0.25
    @test Hgm.W1(arr2, arr1, normalize = false) ≈ 0.25
    # The distance is symmetric in its arguments.
    @test Hgm.W1(arr1, arr2) ≈ Hgm.W1(arr2, arr1)

    # Regression checks against reference values for the stored fixtures.
    @test isapprox(Hgm.W1(x1_dns, x1_bal), w1_dns_bal)
    @test isapprox(Hgm.W1(x1_dns, x1_onl), w1_dns_onl)
    @test isapprox(Hgm.W1(x1_bal, x1_onl), w1_bal_onl)
    @test isapprox(Hgm.W1(x1_dns, x1_bal, normalize = false), w1_dns_bal_unnorm)

    # Matrix method: one distance per row; with mismatched row counts the
    # smaller count is used. Only the shape is checked, so the random inputs
    # do not need a fixed seed.
    @test size(Hgm.W1(rand(3,100), rand(3,100))) == (3,)
    @test size(Hgm.W1(rand(9,100), rand(3,100))) == (3,)
end
println("")


2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ include("neki.jl")

for submodule in ["L96m",
"GPR",
"Histograms",
"ConvenienceFunctions",
]

println("Starting tests for $submodule")
Expand Down

0 comments on commit e622818

Please sign in to comment.