-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from climate-machine/histograms
Histograms.jl: Implement Wasserstein-1 distance
- Loading branch information
Showing
10 changed files
with
152 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
const RPAD = 25 | ||
|
||
function name(name::AbstractString) | ||
return rpad(name * ":", RPAD) | ||
end | ||
|
||
function warn(name::AbstractString) | ||
return rpad("WARNING (" * name * "):", RPAD) | ||
end | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
module Histograms | ||
""" | ||
This module is mostly a convenient wrapper of Python functions (numpy, scipy). | ||
Functions in this module: | ||
- W1 (2 methods) | ||
""" | ||
|
||
import PyCall | ||
include("ConvenienceFunctions.jl") | ||
|
||
scsta = PyCall.pyimport("scipy.stats") | ||
|
||
################################################################################ | ||
# distance functions ########################################################### | ||
################################################################################ | ||
""" | ||
Compute the Wasserstein-1 distance between two distributions from their samples | ||
Parameters: | ||
- u_samples: array-like; samples from the 1st distribution | ||
- v_samples: array-like; samples from the 2nd distribution | ||
- normalize: boolean; whether to normalize the distance by 1/(max-min) | ||
Returns: | ||
- w1_uv: number; the Wasserstein-1 distance | ||
""" | ||
function W1(u_samples::AbstractVector, v_samples::AbstractVector; | ||
normalize = true) | ||
L = maximum([u_samples; v_samples]) - minimum([u_samples; v_samples]) | ||
return if !normalize | ||
scsta.wasserstein_distance(u_samples, v_samples) | ||
else | ||
scsta.wasserstein_distance(u_samples, v_samples) / L | ||
end | ||
end | ||
|
||
""" | ||
Compute the pairwise Wasserstein-1 distances between two sets of distributions | ||
from their samples | ||
Parameters: | ||
- U_samples: matrix-like; samples from distributions (u1, u2, ...) | ||
- V_samples: matrix-like; samples from distributions (v1, v2, ...) | ||
- normalize: boolean; whether to normalize the distances by 1/(max-min) | ||
`U_samples` and `V_samples` should have samples in the 2nd dimension (along | ||
rows) and have the same 1st dimension (same number of rows). If not, the minimum | ||
of the two (minimum number of rows) will be taken. | ||
`normalize` induces *pairwise* normalization, i.e. it max's and min's are | ||
computed for each pair (u_j, v_j) individually. | ||
Returns: | ||
- w1_UV: array-like; the pairwise Wasserstein-1 distances: | ||
w1(u1, v1) | ||
w1(u2, v2) | ||
... | ||
w1(u_K, v_K) | ||
""" | ||
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix; | ||
normalize = true) | ||
if size(U_samples, 1) != size(V_samples, 1) | ||
println(warn("W1"), "sizes of U_samples & V_samples don't match; ", | ||
"will use the minimum of the two") | ||
end | ||
K = min(size(U_samples, 1), size(V_samples, 1)) | ||
w1_UV = zeros(K) | ||
U_sorted = sort(U_samples[1:K, 1:end], dims = 2) | ||
V_sorted = sort(V_samples[1:K, 1:end], dims = 2) | ||
for k in 1:K | ||
w1_UV[k] = W1(U_sorted[k, 1:end], V_sorted[k, 1:end]; normalize = normalize) | ||
end | ||
return w1_UV | ||
end | ||
|
||
end # module | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
using Test | ||
|
||
include("../../src/ConvenienceFunctions.jl") | ||
|
||
################################################################################ | ||
# unit testing ################################################################# | ||
################################################################################ | ||
@testset "unit testing" begin | ||
@test isdefined(Main, :RPAD) | ||
@test length(name("a")) == RPAD | ||
@test length(name("a" ^ RPAD)) == (RPAD + 1) | ||
@test length(warn("a")) == RPAD | ||
@test length(warn("a" ^ RPAD)) == (RPAD + 11) | ||
@test isa(name("a"), String) | ||
@test isa(warn("a"), String) | ||
end | ||
println("") | ||
|
||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
using Test | ||
import NPZ | ||
|
||
include("../../src/Histograms.jl") | ||
const Hgm = Histograms | ||
|
||
const data_dir = joinpath(@__DIR__, "data") | ||
const x1_bal = NPZ.npzread(joinpath(data_dir, "x1_bal.npy")) | ||
const x1_dns = NPZ.npzread(joinpath(data_dir, "x1_dns.npy")) | ||
const x1_onl = NPZ.npzread(joinpath(data_dir, "x1_onl.npy")) | ||
const w1_dns_bal = 0.03755967829782972 | ||
const w1_dns_onl = 0.004489688974663949 | ||
const w1_bal_onl = 0.037079734072606625 | ||
const w1_dns_bal_unnorm = 0.8190688772401341 | ||
|
||
################################################################################ | ||
# unit testing ################################################################# | ||
################################################################################ | ||
@testset "unit testing" begin | ||
arr1 = [1, 1, 1, 2, 3, 4, 4, 4] | ||
arr2 = [1, 1, 2, 2, 3, 3, 4, 4, 4] | ||
@test Hgm.W1(arr1, arr2, normalize = false) == 0.25 | ||
@test Hgm.W1(arr2, arr1, normalize = false) == 0.25 | ||
@test Hgm.W1(arr1, arr2) == Hgm.W1(arr2, arr1) | ||
|
||
@test isapprox(Hgm.W1(x1_dns, x1_bal), w1_dns_bal) | ||
@test isapprox(Hgm.W1(x1_dns, x1_onl), w1_dns_onl) | ||
@test isapprox(Hgm.W1(x1_bal, x1_onl), w1_bal_onl) | ||
@test isapprox(Hgm.W1(x1_dns, x1_bal, normalize = false), w1_dns_bal_unnorm) | ||
|
||
@test size(Hgm.W1(rand(3,100), rand(3,100))) == (3,) | ||
@test size(Hgm.W1(rand(9,100), rand(3,100))) == (3,) | ||
end | ||
println("") | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters