Skip to content

Commit

Permalink
refactor functions
Browse files Browse the repository at this point in the history
  • Loading branch information
allnes committed May 24, 2024
1 parent b8a7ea8 commit 7d3ba52
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ bool ACLCommonExecutor::update(const MemoryArgs &memory) {
acl_tensors_layouts_list[cpu_mem_ptr.first]);
}

auto status = this->prepare_tensors_info();
if (!status) {
DEBUG_LOG("ACL operator validation was failed: ", status.error_description());
this->prepareTensorsInfo();
if (!tensorsInfoValidateStatus) {
DEBUG_LOG("ACL operator validation was failed: ", tensorsInfoValidateStatus.error_description());
return false;
}

Expand All @@ -47,7 +47,7 @@ bool ACLCommonExecutor::update(const MemoryArgs &memory) {
aclMemoryArgs[acl_tensor_info.first]->allocator()->init(*acl_tensor_info.second);
}

configureThreadSafe([&] { this->configure_function(); });
configureThreadSafe([&] { this->configureFunction(); });
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ struct ACLTensorAttrs {

class ACLCommonExecutor : public Executor {
public:
virtual arm_compute::Status prepare_tensors_info() {
OPENVINO_THROW_NOT_IMPLEMENTED("This version of the 'prepare_tensors_info' method is not implemented by executor");
virtual void prepareTensorsInfo() {
OPENVINO_THROW_NOT_IMPLEMENTED("This version of the 'prepareTensorsInfo' method is not implemented by executor");
}
virtual void configure_function() {
OPENVINO_THROW_NOT_IMPLEMENTED("This version of the 'configure_function' method is not implemented by executor");
virtual void configureFunction() {
OPENVINO_THROW_NOT_IMPLEMENTED("This version of the 'configureFunction' method is not implemented by executor");
}
impl_desc_type implType() const override {
return impl_desc_type::acl;
Expand All @@ -35,6 +35,7 @@ class ACLCommonExecutor : public Executor {
bool update(const MemoryArgs& memory) override;

protected:
arm_compute::Status tensorsInfoValidateStatus;
ACLFunction iFunction = nullptr;
ACLMemoryArgs aclMemoryArgs;
ACLMemoryInfoArgs aclMemoryInfoArgs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ bool ACLFullyConnectedExecutor::supports(const FCConfig &config) {
return true;
}

arm_compute::Status ACLFullyConnectedExecutor::prepare_tensors_info() {
void ACLFullyConnectedExecutor::prepareTensorsInfo() {
auto wei_shape = aclMemoryInfoArgs.at(ARG_WEI)->tensor_shape();
if (wei_shape.num_dimensions() == 3) {
aclMemoryInfoArgs.at(ARG_WEI)->set_tensor_shape({wei_shape[0] * wei_shape[1], wei_shape[2]});
Expand All @@ -74,15 +74,15 @@ arm_compute::Status ACLFullyConnectedExecutor::prepare_tensors_info() {
aclMemoryInfoArgs.at(ARG_WEI)->tensor_shape().total_size(),
false, expected_weight_format);

auto opt_impl_status = arm_compute::NEFullyConnectedLayer::has_opt_impl(
tensorsInfoValidateStatus = arm_compute::NEFullyConnectedLayer::has_opt_impl(
expected_weight_format,
aclMemoryInfoArgs.at(ARG_SRC).get(),
aclMemoryInfoArgs.at(ARG_WEI).get(),
withBias ? aclMemoryInfoArgs.at(ARG_BIAS).get() : nullptr,
aclMemoryInfoArgs.at(ARG_DST).get(),
fullyConnectedLayerInfo,
weightsInfo);
if (!opt_impl_status) { return opt_impl_status; }
if (!tensorsInfoValidateStatus) { return; }
fullyConnectedLayerInfo.enable_fast_math = arm_compute::is_fixed_format_fast_math(expected_weight_format);

if (!fullyConnectedLayerInfo.transpose_weights) {
Expand All @@ -91,15 +91,16 @@ arm_compute::Status ACLFullyConnectedExecutor::prepare_tensors_info() {
aclMemoryInfoArgs.at(ARG_WEI)->set_tensor_shape(temp_weights_shape);
}

return arm_compute::NEFullyConnectedLayer::validate(aclMemoryInfoArgs.at(ARG_SRC).get(),
aclMemoryInfoArgs.at(ARG_WEI).get(),
withBias ? aclMemoryInfoArgs.at(ARG_BIAS).get() : nullptr,
aclMemoryInfoArgs.at(ARG_DST).get(),
fullyConnectedLayerInfo,
weightsInfo);
tensorsInfoValidateStatus = arm_compute::NEFullyConnectedLayer::validate(
aclMemoryInfoArgs.at(ARG_SRC).get(),
aclMemoryInfoArgs.at(ARG_WEI).get(),
withBias ? aclMemoryInfoArgs.at(ARG_BIAS).get() : nullptr,
aclMemoryInfoArgs.at(ARG_DST).get(),
fullyConnectedLayerInfo,
weightsInfo);
}

void ACLFullyConnectedExecutor::configure_function() {
void ACLFullyConnectedExecutor::configureFunction() {
iFunction = std::make_unique<arm_compute::NEFullyConnectedLayer>();
reinterpret_cast<arm_compute::NEFullyConnectedLayer*>(iFunction.get())->configure(
aclMemoryArgs.at(ARG_SRC).get(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class ACLFullyConnectedExecutor : public ACLCommonExecutor {

static bool supports(const FCConfig& config);

arm_compute::Status prepare_tensors_info() override;
void prepareTensorsInfo() override;

void configure_function() override;
void configureFunction() override;

impl_desc_type implType() const override {
return impl_desc_type::gemm_acl;
Expand Down

0 comments on commit 7d3ba52

Please sign in to comment.