Skip to content
This repository has been archived by the owner on Jul 8, 2024. It is now read-only.

Commit

Permalink
Refactor Rust FFI
Browse files Browse the repository at this point in the history
Reorganize FFI namespace and rename SwervePathBuilderImpl to
SwervePathBuilder.
  • Loading branch information
calcmogul committed Jul 3, 2024
1 parent bfc4e8f commit 982db55
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 335 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ option(BUILD_EXAMPLES "Build examples" OFF)
include(CompilerFlags)

file(GLOB_RECURSE TrajoptLib_src src/*.cpp)
list(FILTER TrajoptLib_src EXCLUDE REGEX trajoptlibrust.cpp)
list(FILTER TrajoptLib_src EXCLUDE REGEX RustFFI.cpp)

add_library(TrajoptLib ${TrajoptLib_src})
compiler_flags(TrajoptLib)
Expand Down
6 changes: 3 additions & 3 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn main() {
let mut bridge_build = cxx_build::bridge("src/lib.rs");

bridge_build
.file("src/trajoptlibrust.cpp")
.file("src/RustFFI.cpp")
.include("src")
.include(format!("{}/include", cmake_dest.display()))
.include(format!("{}/include/eigen3", cmake_dest.display()))
Expand All @@ -45,7 +45,7 @@ fn main() {
println!("cargo:rustc-link-lib=Sleipnir");
println!("cargo:rustc-link-lib=fmt");

println!("cargo:rerun-if-changed=src/trajoptlibrust.hpp");
println!("cargo:rerun-if-changed=src/trajoptlibrust.cpp");
println!("cargo:rerun-if-changed=src/RustFFI.hpp");
println!("cargo:rerun-if-changed=src/RustFFI.cpp");
println!("cargo:rerun-if-changed=src/lib.rs");
}
251 changes: 251 additions & 0 deletions src/RustFFI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
// Copyright (c) TrajoptLib contributors

#include "RustFFI.hpp"

#include <stdint.h>

#include <cstddef>
#include <stdexcept>
#include <vector>

#include "trajopt/SwerveTrajectoryGenerator.hpp"
#include "trajopt/constraint/AngularVelocityMaxMagnitudeConstraint.hpp"
#include "trajopt/constraint/LinearAccelerationMaxMagnitudeConstraint.hpp"
#include "trajopt/constraint/LinearVelocityDirectionConstraint.hpp"
#include "trajopt/constraint/LinearVelocityMaxMagnitudeConstraint.hpp"
#include "trajopt/constraint/PointAtConstraint.hpp"
#include "trajopt/drivetrain/SwerveModule.hpp"
#include "trajopt/trajectory/HolonomicTrajectory.hpp"
#include "trajopt/trajectory/HolonomicTrajectorySample.hpp"
#include "trajoptlib/src/lib.rs.h"

namespace trajopt::rsffi {

void SwervePathBuilder::set_drivetrain(const SwerveDrivetrain& drivetrain) {
std::vector<trajopt::SwerveModule> cppModules;
for (const auto& module : drivetrain.modules) {
cppModules.push_back(
trajopt::SwerveModule{{module.x, module.y},
module.wheel_radius,
module.wheel_max_angular_velocity,
module.wheel_max_torque});
}

path_builder.SetDrivetrain(trajopt::SwerveDrivetrain{
drivetrain.mass, drivetrain.moi, std::move(cppModules)});
}

void SwervePathBuilder::set_control_interval_counts(
const rust::Vec<size_t> counts) {
std::vector<size_t> cppCounts;
for (const auto& count : counts) {
cppCounts.emplace_back(count);
}

path_builder.ControlIntervalCounts(std::move(cppCounts));
}

void SwervePathBuilder::set_bumpers(double length, double width) {
path_builder.AddBumpers(
trajopt::Bumpers{.safetyDistance = 0.01,
.points = {{+length / 2, +width / 2},
{-length / 2, +width / 2},
{-length / 2, -width / 2},
{+length / 2, -width / 2}}});
}

void SwervePathBuilder::pose_wpt(size_t index, double x, double y,
double heading) {
path_builder.PoseWpt(index, x, y, heading);
}

void SwervePathBuilder::translation_wpt(size_t index, double x, double y,
double heading_guess) {
path_builder.TranslationWpt(index, x, y, heading_guess);
}

void SwervePathBuilder::empty_wpt(size_t index, double x_guess, double y_guess,
double heading_guess) {
path_builder.WptInitialGuessPoint(index, {x_guess, y_guess, heading_guess});
}

void SwervePathBuilder::sgmt_initial_guess_points(
size_t from_index, const rust::Vec<Pose2d>& guess_points) {
std::vector<trajopt::Pose2d> cppGuessPoints;
for (const auto& guess_point : guess_points) {
cppGuessPoints.emplace_back(guess_point.x, guess_point.y,
guess_point.heading);
}

path_builder.SgmtInitialGuessPoints(from_index, std::move(cppGuessPoints));
}

void SwervePathBuilder::wpt_linear_velocity_direction(size_t index,
double angle) {
path_builder.WptConstraint(index,
trajopt::LinearVelocityDirectionConstraint{angle});
}

void SwervePathBuilder::wpt_linear_velocity_max_magnitude(size_t index,
double magnitude) {
path_builder.WptConstraint(
index, trajopt::LinearVelocityMaxMagnitudeConstraint{magnitude});
}

void SwervePathBuilder::wpt_angular_velocity_max_magnitude(
size_t index, double angular_velocity) {
path_builder.WptConstraint(
index, trajopt::AngularVelocityMaxMagnitudeConstraint{angular_velocity});
}

void SwervePathBuilder::wpt_linear_acceleration_max_magnitude(
size_t index, double magnitude) {
path_builder.WptConstraint(
index, trajopt::LinearAccelerationMaxMagnitudeConstraint{magnitude});
}

void SwervePathBuilder::wpt_point_at(size_t index, double field_point_x,
double field_point_y,
double heading_tolerance) {
path_builder.WptConstraint(
index, trajopt::PointAtConstraint{
trajopt::Translation2d{field_point_x, field_point_y},
heading_tolerance});
}

void SwervePathBuilder::sgmt_linear_velocity_direction(size_t from_index,
size_t to_index,
double angle) {
path_builder.SgmtConstraint(
from_index, to_index, trajopt::LinearVelocityDirectionConstraint{angle});
}

void SwervePathBuilder::sgmt_linear_velocity_max_magnitude(size_t from_index,
size_t to_index,
double magnitude) {
path_builder.SgmtConstraint(
from_index, to_index,
trajopt::LinearVelocityMaxMagnitudeConstraint{magnitude});
}

void SwervePathBuilder::sgmt_angular_velocity_max_magnitude(
size_t from_index, size_t to_index, double angular_velocity) {
path_builder.SgmtConstraint(
from_index, to_index,
trajopt::AngularVelocityMaxMagnitudeConstraint{angular_velocity});
}

void SwervePathBuilder::sgmt_linear_acceleration_max_magnitude(
size_t from_index, size_t to_index, double magnitude) {
path_builder.SgmtConstraint(
from_index, to_index,
trajopt::LinearAccelerationMaxMagnitudeConstraint{magnitude});
}

void SwervePathBuilder::sgmt_point_at(size_t from_index, size_t to_index,
double field_point_x,
double field_point_y,
double heading_tolerance) {
path_builder.SgmtConstraint(
from_index, to_index,
trajopt::PointAtConstraint{{field_point_x, field_point_y},
heading_tolerance});
}

void SwervePathBuilder::sgmt_circle_obstacle(size_t from_index, size_t to_index,
double x, double y,
double radius) {
path_builder.SgmtObstacle(from_index, to_index, {radius, {{x, y}}});
}

void SwervePathBuilder::sgmt_polygon_obstacle(size_t from_index,
size_t to_index,
const rust::Vec<double> x,
const rust::Vec<double> y,
double radius) {
if (x.size() != y.size()) [[unlikely]] {
return;
}

std::vector<trajopt::Translation2d> cppPoints;
for (size_t i = 0; i < x.size(); ++i) {
cppPoints.emplace_back(x.at(i), y.at(i));
}

path_builder.SgmtObstacle(from_index, to_index,
trajopt::Obstacle{.safetyDistance = radius,
.points = std::move(cppPoints)});
}

HolonomicTrajectory SwervePathBuilder::generate(bool diagnostics,
int64_t handle) const {
trajopt::SwerveTrajectoryGenerator generator{path_builder, handle};
if (auto sol = generator.Generate(diagnostics); sol.has_value()) {
trajopt::HolonomicTrajectory cppTrajectory{sol.value()};

rust::Vec<HolonomicTrajectorySample> rustSamples;
for (const auto& cppSample : cppTrajectory.samples) {
rust::Vec<double> fx;
std::copy(cppSample.moduleForcesX.begin(), cppSample.moduleForcesX.end(),
std::back_inserter(fx));

rust::Vec<double> fy;
std::copy(cppSample.moduleForcesY.begin(), cppSample.moduleForcesY.end(),
std::back_inserter(fy));

rustSamples.push_back(HolonomicTrajectorySample{
cppSample.timestamp, cppSample.x, cppSample.y, cppSample.heading,
cppSample.velocityX, cppSample.velocityY, cppSample.angularVelocity,
std::move(fx), std::move(fy)});
}

return HolonomicTrajectory{std::move(rustSamples)};
} else {
throw std::runtime_error{sol.error()};
}
}

/**
* Add a callback that will be called on each iteration of the solver.
*
* @param callback: a `fn` (not a closure) to be executed. The callback's
* first parameter will be a `trajopt::HolonomicTrajectory`, and the second
* parameter will be an `i64` equal to the handle passed in `generate()`
*
* This function can be called multiple times to add multiple callbacks.
*/
void SwervePathBuilder::add_progress_callback(
rust::Fn<void(HolonomicTrajectory, int64_t)> callback) {
path_builder.AddIntermediateCallback(
[=](trajopt::SwerveSolution& solution, int64_t handle) {
trajopt::HolonomicTrajectory cppTrajectory{solution};

rust::Vec<HolonomicTrajectorySample> rustSamples;
for (const auto& cppSample : cppTrajectory.samples) {
rust::Vec<double> fx;
std::copy(cppSample.moduleForcesX.begin(),
cppSample.moduleForcesX.end(), std::back_inserter(fx));

rust::Vec<double> fy;
std::copy(cppSample.moduleForcesY.begin(),
cppSample.moduleForcesY.end(), std::back_inserter(fy));

rustSamples.push_back(HolonomicTrajectorySample{
cppSample.timestamp, cppSample.x, cppSample.y, cppSample.heading,
cppSample.velocityX, cppSample.velocityY,
cppSample.angularVelocity, std::move(fx), std::move(fy)});
}

callback(HolonomicTrajectory{rustSamples}, handle);
});
}

void SwervePathBuilder::cancel_all() {
path_builder.CancelAll();
}

std::unique_ptr<SwervePathBuilder> swerve_path_builder_new() {
return std::make_unique<SwervePathBuilder>();
}

} // namespace trajopt::rsffi
12 changes: 6 additions & 6 deletions src/trajoptlibrust.hpp → src/RustFFI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

#include "trajopt/path/SwervePathBuilder.hpp"

namespace trajoptlibrust {
namespace trajopt::rsffi {

struct HolonomicTrajectory;
struct Pose2d;
struct SwerveDrivetrain;

class SwervePathBuilderImpl {
class SwervePathBuilder {
public:
SwervePathBuilderImpl() = default;
SwervePathBuilder() = default;

void set_drivetrain(const SwerveDrivetrain& drivetrain);
void set_bumpers(double length, double width);
Expand Down Expand Up @@ -68,9 +68,9 @@ class SwervePathBuilderImpl {
void cancel_all();

private:
trajopt::SwervePathBuilder path;
trajopt::SwervePathBuilder path_builder;
};

std::unique_ptr<SwervePathBuilderImpl> new_swerve_path_builder_impl();
std::unique_ptr<SwervePathBuilder> swerve_path_builder_new();

} // namespace trajoptlibrust
} // namespace trajopt::rsffi
Loading

0 comments on commit 982db55

Please sign in to comment.