Skip to content

Commit

Permalink
matrix vector compatability
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Mar 10, 2024
1 parent ca07e21 commit 7e2ab1b
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/epca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,13 @@ function CompressedBeliefMDPs.compress(epca::EPCA, X; verbose=false, maxiter::In
end
CompressedBeliefMDPs.compress(epca::EPCA, X::Vector; verbose=false, maxiter::Integer=50) = vec(compress(epca, X'; verbose=verbose, maxiter=maxiter))

CompressedBeliefMDPs.decompress(epca::EPCA, A) = epca.g(A * epca.V)
CompressedBeliefMDPs.decompress(epca::EPCA, A::Vector) = vec(epca.g(A' * epca.V))
function CompressedBeliefMDPs.decompress(epca::EPCA, A)
if ndims(A) == 1
return vec(epca.g((A' * epca.V)))
else
return epca.g(A * epca.V)
end
end

# CompressedBeliefMDPs.decompress(epca::EPCA, A) = epca.g((ndims(A) == 1 ? A' : A) * epca.V)
# CompressedBeliefMDPs.decompress(epca::EPCA, A::Vector) = vec(epca.g(A' * epca.V))

0 comments on commit 7e2ab1b

Please sign in to comment.