Skip to content

Commit

Permalink
Add new matcher structure
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Jul 10, 2023
1 parent 9de54bc commit b541181
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#pragma once

#include "cache/cache.hpp"

#include "cache/meta/input_info.hpp"
#include "matchers/subgraph/subgraph.hpp"
#include "matchers/subgraph/fused_names.hpp"
#include "matchers/subgraph/repeat_pattern.hpp"
Expand Down Expand Up @@ -42,6 +44,9 @@ class GraphCache final : public virtual ICache {
};
m_manager.set_matchers(matchers);
}

void update_cache(const std::shared_ptr<ov::Model>& model, const std::string& model_path,
const std::map<std::string, InputInfo>& input_info, size_t model_op_cnt);
};

} // namespace subgraph_dumper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ struct InputInfo {
}
};

using ExtractedPattern = std::pair<std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>>;

} // namespace subgraph_dumper
} // namespace tools
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ namespace subgraph_dumper {
class BaseMatcher {
public:
using Ptr = std::shared_ptr<BaseMatcher>;
using ExtractedPattern = std::pair<std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>>;

BaseMatcher() = default;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MatchersManager {
bool match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model);

std::list<BaseMatcher::ExtractedPattern> run_extractors(const std::shared_ptr<ov::Model> &model);
std::list<ExtractedPattern> run_extractors(const std::shared_ptr<ov::Model> &model);

void set_matchers(const MatchersMap& matchers = {}) { m_matchers = matchers; }
const MatchersMap& get_matchers() { return m_matchers; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace subgraph_dumper {

class FusedNamesMatcher : public SubgraphMatcher {
public:
std::list<BaseMatcher::ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
};

} // namespace subgraph_dumper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace subgraph_dumper {

class RepeatPatternMatcher : public SubgraphMatcher {
public:
std::list<BaseMatcher::ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
};

} // namespace subgraph_dumper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,38 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& model, const std
if (extracted_patterns.empty()) {
return;
}
for (const auto& cached_pattern : m_graph_cache) {
while (!extracted_patterns.empty()) {
auto it = extracted_patterns.begin();
while (it != extracted_patterns.end()) {
if (m_manager.match(cached_pattern.first, it->first)) {
break;
}
++it;
}
if (it == extracted_patterns.end()) {
continue;
}
auto cached_model_size = cached_pattern.first->get_graph_size();
auto pattern_model_size = it->first->get_graph_size();
if (pattern_model_size < cached_model_size) {
auto meta = cached_pattern.second;
meta.update(model_meta_data, it->second, model_total_op);
m_graph_cache.erase(cached_pattern.first);
m_graph_cache.insert({it->first, meta});
} else {
m_graph_cache[cached_pattern.first].update(model_meta_data, it->second, model_total_op);
}
extracted_patterns.erase(it);
update_cache(it->first, model_meta_data, it->second, model_total_op);
extracted_patterns.pop_front();
}
return;
}

for (const auto& extracted_pattern : extracted_patterns) {
auto meta = MetaInfo(model_meta_data, extracted_pattern.second, model_total_op);
m_graph_cache.insert({model, meta});
void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model, const std::string& model_path,
const std::map<std::string, InputInfo>& input_info, size_t model_op_cnt) {
std::shared_ptr<ov::Model> model_to_update = nullptr;
for (const auto& cached_model : m_graph_cache) {
if (m_manager.match(cached_model.first, extracted_model)) {
model_to_update = cached_model.first;
break;
}
}
if (model_to_update == nullptr) {
auto meta = MetaInfo(model_path, input_info, model_op_cnt);
m_graph_cache.insert({ extracted_model, meta });
return;
}
auto cached_model_size = model_to_update->get_graph_size();
auto pattern_model_size = extracted_model->get_graph_size();
if (pattern_model_size < cached_model_size) {
auto meta = m_graph_cache[model_to_update];
meta.update(model_path, input_info, model_op_cnt);
m_graph_cache.erase(model_to_update);
m_graph_cache.insert({extracted_model, meta});
} else {
m_graph_cache[model_to_update].update(model_path, input_info, model_op_cnt);
}
return;
}

void GraphCache::serialize_cache() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ bool MatchersManager::match(const std::shared_ptr<ov::Model> &model,
return false;
}

std::list<BaseMatcher::ExtractedPattern>
std::list<ExtractedPattern>
MatchersManager::run_extractors(const std::shared_ptr<ov::Model> &model) {
std::list<BaseMatcher::ExtractedPattern> result;
std::list<ExtractedPattern> result;
for (const auto &it : m_matchers) {
auto extracted_patterns = it.second->extract(model);
result.insert(result.end(), extracted_patterns.begin(), extracted_patterns.end());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@

using namespace ov::tools::subgraph_dumper;

std::list<BaseMatcher::ExtractedPattern>
std::list<ExtractedPattern>
FusedNamesMatcher::extract(const std::shared_ptr<ov::Model> &model) {
std::list<BaseMatcher::ExtractedPattern> matched_patterns;
std::list<ExtractedPattern> matched_patterns;
auto core = ov::test::utils::PluginCache::get().core();
auto compiled_model = core->compile_model(model);
bool is_graph_started = false;
std::map<std::string, std::shared_ptr<ov::Node>> model_map;
std::unordered_set<std::string> compiled_op_name;
for (const auto& compiled_op : compiled_model.get_runtime_model()->get_ordered_ops()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

using namespace ov::tools::subgraph_dumper;

std::list<BaseMatcher::ExtractedPattern>
std::list<ExtractedPattern>
RepeatPatternMatcher::extract(const std::shared_ptr<ov::Model> &model) {
return {};
}
Expand Down

0 comments on commit b541181

Please sign in to comment.