-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: allow alphabets over larger alphabets
- Loading branch information
Showing
10 changed files
with
798 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
// ----------------------------------------------------------------------------------------------------- | ||
// Copyright (c) 2006-2023, Knut Reinert & Freie Universität Berlin | ||
// Copyright (c) 2016-2023, Knut Reinert & MPI für molekulare Genetik | ||
// This file may be used, modified and/or redistributed under the terms of the 3-clause BSD-License | ||
// shipped with this file. | ||
// ----------------------------------------------------------------------------------------------------- | ||
#pragma once | ||
|
||
#include "CSA_32.h" | ||
#include "occtable/concepts.h" | ||
#include "utils.h" | ||
|
||
#include <algorithm> | ||
|
||
namespace fmindex_collection { | ||
|
||
template <OccTable_32 Table, typename TCSA = CSA_32> | ||
struct BiFMIndex_32 { | ||
static size_t constexpr Sigma = Table::Sigma; | ||
|
||
using TTable = Table; | ||
|
||
Table occ; | ||
Table occRev; | ||
TCSA csa; | ||
|
||
//private: | ||
BiFMIndex_32(std::span<uint32_t const> bwt, std::vector<uint32_t> const& bwtRev, TCSA _csa) | ||
: occ{bwt} | ||
, occRev{bwtRev} | ||
, csa{std::move(_csa)} | ||
{ | ||
assert(bwt.size() == bwtRev.size()); | ||
assert(occ.size() == occRev.size()); | ||
if (bwt.size() != bwtRev.size()) { | ||
throw std::runtime_error("bwt don't have the same size: " + std::to_string(bwt.size()) + " " + std::to_string(bwtRev.size())); | ||
} | ||
if (occ.size() != occRev.size()) { | ||
throw std::runtime_error("occ don't have the same size: " + std::to_string(occ.size()) + " " + std::to_string(occRev.size())); | ||
} | ||
// compute last row | ||
auto ct = std::array<uint64_t, Sigma>{}; | ||
for (auto v : bwt) { | ||
ct[v] += 1; | ||
} | ||
for (size_t i{1}; i < ct.size(); ++i) { | ||
ct[i] = ct[i-1] + ct[i]; | ||
} | ||
// check last row is correct | ||
for (size_t sym{0}; sym < Sigma; ++sym) { | ||
if (occ.rank(occ.size(), sym) != ct[sym]) { | ||
auto e = std::string{"Wrong rank for the last entry."} | ||
+ " Got different values for forward index." | ||
+ " sym: " + std::to_string(sym) | ||
+ " got: " + std::to_string(occ.rank(occ.size(), sym)) | ||
+ " expected: " + std::to_string(ct[sym]); | ||
throw std::runtime_error(e); | ||
} | ||
if (occRev.rank(occRev.size(), sym) != ct[sym]) { | ||
auto e = std::string{"Wrong rank for the last entry."} | ||
+ " Got different values for reverse index." | ||
+ " sym: " + std::to_string(sym) | ||
+ " got: " + std::to_string(occRev.rank(occRev.size(), sym)) | ||
+ " expected: " + std::to_string(ct[sym]); | ||
throw std::runtime_error(e); | ||
} | ||
} | ||
if constexpr (requires(Table t) {{ t.hasValue(size_t{}) }; }) { | ||
for (size_t i{0}; i < occ.size(); ++i) { | ||
if (csa.value(i).has_value()) { | ||
occ.setValue(i); | ||
} | ||
} | ||
} | ||
} | ||
|
||
public: | ||
/**!\brief Creates a BiFMIndex with a specified sampling rate | ||
* | ||
* \param _input a list of sequences | ||
* \param samplingRate rate of the sampling | ||
*/ | ||
BiFMIndex_32(Sequences_32 auto const& _input, size_t samplingRate, size_t threadNbr) | ||
: occ{cereal_tag{}} | ||
, occRev{cereal_tag{}} | ||
, csa{cereal_tag{}} | ||
{ | ||
auto [totalSize, inputText, inputSizes] = createSequences_32(_input, samplingRate); | ||
|
||
// create BurrowsWheelerTransform and CompressedSuffixArray | ||
auto [bwt, csa] = [&, &inputText=inputText, &inputSizes=inputSizes] () { | ||
auto sa = createSA_32(inputText, threadNbr); | ||
auto bwt = createBWT_32(inputText, sa); | ||
auto csa = TCSA(std::move(sa), samplingRate, inputSizes); | ||
return std::make_tuple(std::move(bwt), std::move(csa)); | ||
}(); | ||
|
||
// create BurrowsWheelerTransform on reversed text | ||
auto bwtRev = [&, &inputText=inputText]() { | ||
std::ranges::reverse(inputText); | ||
auto saRev = createSA_32(inputText, threadNbr); | ||
auto bwtRev = createBWT_32(inputText, saRev); | ||
return bwtRev; | ||
}(); | ||
|
||
decltype(inputText){}.swap(inputText); // inputText memory can be deleted | ||
|
||
*this = BiFMIndex_32{bwt, bwtRev, std::move(csa)}; | ||
} | ||
|
||
|
||
/*!\brief Specific c'tor for serialization use | ||
*/ | ||
BiFMIndex_32(cereal_tag) | ||
: occ{cereal_tag{}} | ||
, occRev{cereal_tag{}} | ||
, csa{cereal_tag{}} | ||
{} | ||
|
||
size_t memoryUsage() const requires OccTableMemoryUsage<Table> { | ||
return occ.memoryUsage() + occRev.memoryUsage() + csa.memoryUsage(); | ||
} | ||
|
||
size_t size() const { | ||
return occ.size(); | ||
} | ||
|
||
auto locate(size_t idx) const -> std::tuple<size_t, size_t> { | ||
if constexpr (requires(Table t) {{ t.hasValue(size_t{}) }; }) { | ||
bool v = occ.hasValue(idx); | ||
uint64_t steps{}; | ||
while(!v) { | ||
idx = occ.rank_symbol(idx); | ||
steps += 1; | ||
v = occ.hasValue(idx); | ||
} | ||
auto [chr, pos] = csa.value(idx); | ||
return {chr, pos+steps}; | ||
|
||
} else { | ||
auto opt = csa.value(idx); | ||
uint64_t steps{}; | ||
while(!opt) { | ||
if constexpr (requires(Table t) { { t.rank_symbol(size_t{}) }; }) { | ||
idx = occ.rank_symbol(idx); | ||
} else { | ||
idx = occ.rank(idx, occ.symbol(idx)); | ||
} | ||
steps += 1; | ||
opt = csa.value(idx); | ||
} | ||
auto [chr, pos] = *opt; | ||
return {chr, pos+steps}; | ||
} | ||
} | ||
|
||
auto locate(size_t idx, size_t maxSteps) const -> std::optional<std::tuple<size_t, size_t>> { | ||
auto opt = csa.value(idx); | ||
uint64_t steps{}; | ||
for (;!opt and maxSteps > 0; --maxSteps) { | ||
idx = occ.rank(idx, occ.symbol(idx)); | ||
steps += 1; | ||
opt = csa.value(idx); | ||
} | ||
if (opt) { | ||
std::get<1>(*opt) += steps; | ||
} | ||
return opt; | ||
} | ||
|
||
|
||
auto single_locate_step(size_t idx) const -> std::optional<std::tuple<size_t, size_t>> { | ||
return csa.value(idx); | ||
} | ||
|
||
|
||
template <typename Archive> | ||
void serialize(Archive& ar) { | ||
ar(occ, occRev, csa); | ||
} | ||
}; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
// ----------------------------------------------------------------------------------------------------- | ||
// Copyright (c) 2006-2023, Knut Reinert & Freie Universität Berlin | ||
// Copyright (c) 2016-2023, Knut Reinert & MPI für molekulare Genetik | ||
// This file may be used, modified and/or redistributed under the terms of the 3-clause BSD-License | ||
// shipped with this file. | ||
// ----------------------------------------------------------------------------------------------------- | ||
#pragma once | ||
|
||
#include "BitStack.h" | ||
#include "Bitvector.h" | ||
#include "BitvectorCompact.h" | ||
#include "cereal_tag.h" | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <numeric> | ||
#include <optional> | ||
#include <tuple> | ||
|
||
|
||
namespace fmindex_collection { | ||
|
||
struct CSA_32 { | ||
std::vector<uint32_t> ssa; | ||
BitvectorCompact bv; | ||
size_t samplingRate; // distance between two samples (inside one sequence) | ||
size_t bitsForPosition; // bits reserved for position | ||
size_t bitPositionMask; | ||
|
||
|
||
CSA_32(std::vector<uint32_t> _ssa, BitStack const& bitstack, size_t _samplingRate, size_t _bitsForPosition) | ||
: ssa{std::move(_ssa)} | ||
, bv{bitstack.size, [&](size_t idx) { | ||
return bitstack.value(idx); | ||
}} | ||
, samplingRate{_samplingRate} | ||
, bitsForPosition{_bitsForPosition} | ||
, bitPositionMask{(1ull<<bitsForPosition)-1} | ||
{} | ||
CSA_32(CSA_32 const&) = delete; | ||
CSA_32(CSA_32&& _other) noexcept = default; | ||
|
||
CSA_32(cereal_tag) | ||
: bv {cereal_tag{}} | ||
{} | ||
|
||
CSA_32(std::span<int32_t const> sa, size_t _samplingRate, std::span<std::tuple<size_t, size_t> const> _inputSizes, bool reverse=false) | ||
: samplingRate{_samplingRate} | ||
{ | ||
size_t bitsForSeqId = std::max(size_t{1}, size_t(std::ceil(std::log2(_inputSizes.size())))); | ||
assert(bitsForSeqId < 64); | ||
|
||
bitsForPosition = 64 - bitsForSeqId; | ||
bitPositionMask = (1ull<<bitsForPosition)-1; | ||
|
||
// Generate accumulated input | ||
auto accInputSizes = std::vector<uint32_t>{}; | ||
accInputSizes.reserve(_inputSizes.size()+1); | ||
accInputSizes.emplace_back(0); | ||
for (size_t i{0}; i < _inputSizes.size(); ++i) { | ||
auto [len, delCt] = _inputSizes[i]; | ||
accInputSizes.emplace_back(accInputSizes.back() + len + delCt); | ||
} | ||
|
||
// Annotate text with labels, naming the correct sequence id | ||
auto labels = std::vector<uint32_t>{}; | ||
labels.reserve(sa.size() / samplingRate); | ||
|
||
for (size_t i{0}, subjId{0}; i < sa.size(); i += samplingRate) { | ||
while (i >= accInputSizes[subjId]) { | ||
subjId += 1; | ||
} | ||
labels.emplace_back(subjId-1); | ||
} | ||
|
||
// Construct sampled suffix array | ||
auto ssa = std::vector<uint32_t>{}; | ||
ssa.reserve(sa.size() / _samplingRate); | ||
for (size_t i{0}; i < sa.size(); ++i) { | ||
bool sample = (sa[i] % samplingRate) == 0; | ||
if (sample) { | ||
auto subjId = labels[sa[i] / samplingRate]; | ||
auto subjPos = sa[i] - accInputSizes[subjId]; | ||
if (reverse) { | ||
auto [len, delCt] = _inputSizes[subjId]; | ||
if (subjPos < len) { | ||
subjPos = len - subjPos; | ||
} else { | ||
subjPos = len+1; | ||
} | ||
} | ||
ssa.emplace_back(subjPos | (subjId << bitsForPosition)); | ||
} | ||
} | ||
this->ssa = std::move(ssa); | ||
this->bv = BitvectorCompact{sa.size(), [&](size_t idx) { | ||
return (sa[idx] % samplingRate) == 0; | ||
}}; | ||
} | ||
|
||
|
||
auto operator=(CSA_32 const&) -> CSA_32& = delete; | ||
auto operator=(CSA_32&& _other) noexcept -> CSA_32& = default; | ||
|
||
size_t memoryUsage() const { | ||
return sizeof(ssa) + ssa.size() * sizeof(ssa.back()) | ||
+ bv.memoryUsage(); | ||
} | ||
|
||
auto value(size_t idx) const -> std::optional<std::tuple<uint64_t, uint64_t>> { | ||
if (!bv.value(idx)) { | ||
return std::nullopt; | ||
} | ||
auto v = ssa[bv.rank(idx)]; | ||
auto chr = v >> bitsForPosition; | ||
auto pos = v & bitPositionMask; | ||
|
||
return std::make_tuple(chr, pos); | ||
} | ||
|
||
template <typename Archive> | ||
void serialize(Archive& ar) { | ||
ar(ssa, bv, samplingRate, bitsForPosition, bitPositionMask); | ||
} | ||
}; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.