From d88bd42394e658e2538c1eb611cbe119ca1314dc Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Mon, 14 Oct 2024 10:43:31 -0700 Subject: [PATCH] [wpimath] Merge .inc files into headers Splitting the files didn't help readability or save compilation time and it confused contributors. Merging them is also in line with how C++ modules will be written. --- .../native/cpp/geometry/Translation3d.cpp | 38 --- .../frc/controller/LinearQuadraticRegulator.h | 77 +++++- .../controller/LinearQuadraticRegulator.inc | 110 -------- .../proto/SimpleMotorFeedforwardProto.h | 36 ++- .../proto/SimpleMotorFeedforwardProto.inc | 45 ---- .../struct/SimpleMotorFeedforwardStruct.h | 35 ++- .../struct/SimpleMotorFeedforwardStruct.inc | 45 ---- .../frc/estimator/ExtendedKalmanFilter.h | 132 ++++++++- .../frc/estimator/ExtendedKalmanFilter.inc | 171 ------------ .../include/frc/estimator/KalmanFilter.h | 91 ++++++- .../include/frc/estimator/KalmanFilter.inc | 110 -------- .../include/frc/estimator/PoseEstimator.h | 206 ++++++++++++-- .../include/frc/estimator/PoseEstimator.inc | 246 ----------------- .../frc/estimator/SteadyStateKalmanFilter.h | 71 ++++- .../frc/estimator/SteadyStateKalmanFilter.inc | 89 ------- .../frc/estimator/UnscentedKalmanFilter.h | 145 +++++++++- .../frc/estimator/UnscentedKalmanFilter.inc | 196 -------------- .../main/native/include/frc/geometry/Pose2d.h | 31 ++- .../native/include/frc/geometry/Pose2d.inc | 43 --- .../native/include/frc/geometry/Rotation2d.h | 55 +++- .../include/frc/geometry/Rotation2d.inc | 70 ----- .../native/include/frc/geometry/Transform2d.h | 19 +- .../include/frc/geometry/Transform2d.inc | 30 --- .../include/frc/geometry/Translation2d.h | 50 +++- .../include/frc/geometry/Translation2d.inc | 68 ----- .../include/frc/geometry/Translation3d.h | 66 +++-- .../include/frc/geometry/Translation3d.inc | 48 ---- .../native/include/frc/kinematics/Odometry.h | 28 +- .../include/frc/kinematics/Odometry.inc | 38 --- .../frc/kinematics/SwerveDriveKinematics.h | 148 ++++++++++- .../frc/kinematics/SwerveDriveKinematics.inc | 176 ------------ .../frc/kinematics/SwerveDriveOdometry.h | 14 +- .../frc/kinematics/SwerveDriveOdometry.inc | 23 -- .../proto/SwerveDriveKinematicsProto.h | 32 ++- .../proto/SwerveDriveKinematicsProto.inc | 44 --- .../struct/SwerveDriveKinematicsStruct.h | 16 +- .../struct/SwerveDriveKinematicsStruct.inc | 25 -- .../native/include/frc/proto/MatrixProto.h | 44 ++- .../native/include/frc/proto/MatrixProto.inc | 60 ----- .../native/include/frc/proto/VectorProto.h | 36 ++- .../native/include/frc/proto/VectorProto.inc | 49 ---- .../native/include/frc/struct/MatrixStruct.h | 23 +- .../include/frc/struct/MatrixStruct.inc | 35 --- .../native/include/frc/struct/VectorStruct.h | 23 +- .../include/frc/struct/VectorStruct.inc | 33 --- .../frc/system/proto/LinearSystemProto.h | 42 ++- .../frc/system/proto/LinearSystemProto.inc | 55 ---- .../frc/system/struct/LinearSystemStruct.h | 33 ++- .../frc/system/struct/LinearSystemStruct.inc | 42 --- .../frc/trajectory/ExponentialProfile.h | 208 +++++++++++++-- .../frc/trajectory/ExponentialProfile.inc | 251 ------------------ .../include/frc/trajectory/TrapezoidProfile.h | 143 +++++++++- .../frc/trajectory/TrapezoidProfile.inc | 156 ----------- .../SwerveDriveKinematicsConstraint.h | 23 +- .../SwerveDriveKinematicsConstraint.inc | 42 --- 55 files changed, 1640 insertions(+), 2525 deletions(-) delete mode 100644 wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.inc delete mode 100644 wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.inc delete mode 100644 wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.inc delete mode 100644 wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.inc delete mode 100644 wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc delete mode 100644 wpimath/src/main/native/include/frc/estimator/PoseEstimator.inc delete mode 100644 wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.inc delete mode 100644 wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.inc delete mode 100644 wpimath/src/main/native/include/frc/geometry/Pose2d.inc delete mode 100644 wpimath/src/main/native/include/frc/geometry/Rotation2d.inc delete mode 100644 wpimath/src/main/native/include/frc/geometry/Transform2d.inc delete mode 100644 wpimath/src/main/native/include/frc/geometry/Translation2d.inc delete mode 100644 wpimath/src/main/native/include/frc/geometry/Translation3d.inc delete mode 100644 wpimath/src/main/native/include/frc/kinematics/Odometry.inc delete mode 100644 wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc delete mode 100644 wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc delete mode 100644 wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.inc delete mode 100644 wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.inc delete mode 100644 wpimath/src/main/native/include/frc/proto/MatrixProto.inc delete mode 100644 wpimath/src/main/native/include/frc/proto/VectorProto.inc delete mode 100644 wpimath/src/main/native/include/frc/struct/MatrixStruct.inc delete mode 100644 wpimath/src/main/native/include/frc/struct/VectorStruct.inc delete mode 100644 wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.inc delete mode 100644 wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.inc delete mode 100644 wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.inc delete mode 100644 wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc delete mode 100644 wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc diff --git a/wpimath/src/main/native/cpp/geometry/Translation3d.cpp b/wpimath/src/main/native/cpp/geometry/Translation3d.cpp index 69757f30801..1609f434e8b 100644 --- a/wpimath/src/main/native/cpp/geometry/Translation3d.cpp +++ b/wpimath/src/main/native/cpp/geometry/Translation3d.cpp @@ -6,46 +6,8 @@ #include -#include "units/length.h" -#include "units/math.h" - using namespace frc; -Translation3d::Translation3d(units::meter_t distance, const Rotation3d& angle) { - auto rectangular = Translation3d{distance, 0_m, 0_m}.RotateBy(angle); - m_x = rectangular.X(); - m_y = rectangular.Y(); - m_z = rectangular.Z(); -} - -Translation3d::Translation3d(const Eigen::Vector3d& vector) - : m_x{units::meter_t{vector.x()}}, - m_y{units::meter_t{vector.y()}}, - m_z{units::meter_t{vector.z()}} {} - -units::meter_t Translation3d::Distance(const Translation3d& other) const { - return units::math::sqrt(units::math::pow<2>(other.m_x - m_x) + - units::math::pow<2>(other.m_y - m_y) + - units::math::pow<2>(other.m_z - m_z)); -} - -units::meter_t Translation3d::Norm() const { - return units::math::sqrt(m_x * m_x + m_y * m_y + m_z * m_z); -} - -Translation3d Translation3d::RotateBy(const Rotation3d& other) const { - Quaternion p{0.0, m_x.value(), m_y.value(), m_z.value()}; - auto qprime = other.GetQuaternion() * p * other.GetQuaternion().Inverse(); - return Translation3d{units::meter_t{qprime.X()}, units::meter_t{qprime.Y()}, - units::meter_t{qprime.Z()}}; -} - -bool Translation3d::operator==(const Translation3d& other) const { - return units::math::abs(m_x - other.m_x) < 1E-9_m && - units::math::abs(m_y - other.m_y) < 1E-9_m && - units::math::abs(m_z - other.m_z) < 1E-9_m; -} - void frc::to_json(wpi::json& json, const Translation3d& translation) { json = wpi::json{{"x", translation.X().value()}, {"y", translation.Y().value()}, diff --git a/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h b/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h index 2671b3dccb8..263a2a769ed 100644 --- a/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h +++ b/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h @@ -4,12 +4,22 @@ #pragma once +#include +#include + +#include +#include #include #include +#include "frc/DARE.h" #include "frc/EigenCore.h" +#include "frc/StateSpaceUtil.h" +#include "frc/fmt/Eigen.h" +#include "frc/system/Discretization.h" #include "frc/system/LinearSystem.h" #include "units/time.h" +#include "wpimath/MathShared.h" namespace frc { @@ -50,7 +60,8 @@ class LinearQuadraticRegulator { template LinearQuadraticRegulator(const LinearSystem& plant, const StateArray& Qelems, const InputArray& Relems, - units::second_t dt); + units::second_t dt) + : LinearQuadraticRegulator(plant.A(), plant.B(), Qelems, Relems, dt) {} /** * Constructs a controller with the given coefficients and plant. @@ -69,7 +80,9 @@ class LinearQuadraticRegulator { LinearQuadraticRegulator(const Matrixd& A, const Matrixd& B, const StateArray& Qelems, const InputArray& Relems, - units::second_t dt); + units::second_t dt) + : LinearQuadraticRegulator(A, B, MakeCostMatrix(Qelems), + MakeCostMatrix(Relems), dt) {} /** * Constructs a controller with the given coefficients and plant. @@ -85,7 +98,30 @@ class LinearQuadraticRegulator { const Matrixd& B, const Matrixd& Q, const Matrixd& R, - units::second_t dt); + units::second_t dt) { + Matrixd discA; + Matrixd discB; + DiscretizeAB(A, B, dt, &discA, &discB); + + if (!IsStabilizable(discA, discB)) { + std::string msg = fmt::format( + "The system passed to the LQR is unstabilizable!\n\nA =\n{}\nB " + "=\n{}\n", + discA, discB); + + wpi::math::MathSharedStore::ReportError(msg); + throw std::invalid_argument(msg); + } + + Matrixd S = DARE(discA, discB, Q, R); + + // K = (BᵀSB + R)⁻¹BᵀSA + m_K = (discB.transpose() * S * discB + R) + .llt() + .solve(discB.transpose() * S * discA); + + Reset(); + } /** * Constructs a controller with the given coefficients and plant. @@ -103,7 +139,20 @@ class LinearQuadraticRegulator { const Matrixd& Q, const Matrixd& R, const Matrixd& N, - units::second_t dt); + units::second_t dt) { + Matrixd discA; + Matrixd discB; + DiscretizeAB(A, B, dt, &discA, &discB); + + Matrixd S = DARE(discA, discB, Q, R, N); + + // K = (BᵀSB + R)⁻¹(BᵀSA + Nᵀ) + m_K = (discB.transpose() * S * discB + R) + .llt() + .solve(discB.transpose() * S * discA + N.transpose()); + + Reset(); + } LinearQuadraticRegulator(LinearQuadraticRegulator&&) = default; LinearQuadraticRegulator& operator=(LinearQuadraticRegulator&&) = default; @@ -166,7 +215,10 @@ class LinearQuadraticRegulator { * * @param x The current state x. */ - InputVector Calculate(const StateVector& x); + InputVector Calculate(const StateVector& x) { + m_u = m_K * (m_r - x); + return m_u; + } /** * Returns the next output of the controller. @@ -174,7 +226,10 @@ class LinearQuadraticRegulator { * @param x The current state x. * @param nextR The next reference vector r. */ - InputVector Calculate(const StateVector& x, const StateVector& nextR); + InputVector Calculate(const StateVector& x, const StateVector& nextR) { + m_r = nextR; + return Calculate(x); + } /** * Adjusts LQR controller gain to compensate for a pure time delay in the @@ -194,7 +249,13 @@ class LinearQuadraticRegulator { */ template void LatencyCompensate(const LinearSystem& plant, - units::second_t dt, units::second_t inputDelay); + units::second_t dt, units::second_t inputDelay) { + Matrixd discA; + Matrixd discB; + DiscretizeAB(plant.A(), plant.B(), dt, &discA, &discB); + + m_K = m_K * (discA - discB * m_K).pow(inputDelay / dt); + } private: // Current reference @@ -215,5 +276,3 @@ extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) LinearQuadraticRegulator<2, 2>; } // namespace frc - -#include "LinearQuadraticRegulator.inc" diff --git a/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.inc b/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.inc deleted file mode 100644 index 333181ce497..00000000000 --- a/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.inc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include - -#include -#include - -#include "frc/DARE.h" -#include "frc/StateSpaceUtil.h" -#include "frc/controller/LinearQuadraticRegulator.h" -#include "frc/fmt/Eigen.h" -#include "frc/system/Discretization.h" -#include "wpimath/MathShared.h" - -namespace frc { - -template -template -LinearQuadraticRegulator::LinearQuadraticRegulator( - const LinearSystem& plant, - const StateArray& Qelems, const InputArray& Relems, units::second_t dt) - : LinearQuadraticRegulator(plant.A(), plant.B(), Qelems, Relems, dt) {} - -template -LinearQuadraticRegulator::LinearQuadraticRegulator( - const Matrixd& A, const Matrixd& B, - const StateArray& Qelems, const InputArray& Relems, units::second_t dt) - : LinearQuadraticRegulator(A, B, MakeCostMatrix(Qelems), - MakeCostMatrix(Relems), dt) {} - -template -LinearQuadraticRegulator::LinearQuadraticRegulator( - const Matrixd& A, const Matrixd& B, - const Matrixd& Q, const Matrixd& R, - units::second_t dt) { - Matrixd discA; - Matrixd discB; - DiscretizeAB(A, B, dt, &discA, &discB); - - if (!IsStabilizable(discA, discB)) { - std::string msg = fmt::format( - "The system passed to the LQR is unstabilizable!\n\nA =\n{}\nB =\n{}\n", - discA, discB); - - wpi::math::MathSharedStore::ReportError(msg); - throw std::invalid_argument(msg); - } - - Matrixd S = DARE(discA, discB, Q, R); - - // K = (BᵀSB + R)⁻¹BᵀSA - m_K = (discB.transpose() * S * discB + R) - .llt() - .solve(discB.transpose() * S * discA); - - Reset(); -} - -template -LinearQuadraticRegulator::LinearQuadraticRegulator( - const Matrixd& A, const Matrixd& B, - const Matrixd& Q, const Matrixd& R, - const Matrixd& N, units::second_t dt) { - Matrixd discA; - Matrixd discB; - DiscretizeAB(A, B, dt, &discA, &discB); - - Matrixd S = DARE(discA, discB, Q, R, N); - - // K = (BᵀSB + R)⁻¹(BᵀSA + Nᵀ) - m_K = (discB.transpose() * S * discB + R) - .llt() - .solve(discB.transpose() * S * discA + N.transpose()); - - Reset(); -} - -template -typename LinearQuadraticRegulator::InputVector -LinearQuadraticRegulator::Calculate(const StateVector& x) { - m_u = m_K * (m_r - x); - return m_u; -} - -template -typename LinearQuadraticRegulator::InputVector -LinearQuadraticRegulator::Calculate(const StateVector& x, - const StateVector& nextR) { - m_r = nextR; - return Calculate(x); -} - -template -template -void LinearQuadraticRegulator::LatencyCompensate( - const LinearSystem& plant, units::second_t dt, - units::second_t inputDelay) { - Matrixd discA; - Matrixd discB; - DiscretizeAB(plant.A(), plant.B(), dt, &discA, &discB); - - m_K = m_K * (discA - discB * m_K).pow(inputDelay / dt); -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.h b/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.h index ad57d132d7b..4d763f0b195 100644 --- a/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.h +++ b/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.h @@ -4,8 +4,10 @@ #pragma once +#include #include +#include "controller.pb.h" #include "frc/controller/SimpleMotorFeedforward.h" #include "units/length.h" @@ -14,11 +16,35 @@ template struct wpi::Protobuf> { - static google::protobuf::Message* New(google::protobuf::Arena* arena); + static google::protobuf::Message* New(google::protobuf::Arena* arena) { + return wpi::CreateMessage( + arena); + } + static frc::SimpleMotorFeedforward Unpack( - const google::protobuf::Message& msg); + const google::protobuf::Message& msg) { + auto m = + static_cast(&msg); + return {units::volt_t{m->ks()}, + units::unit_t::kv_unit>{ + m->kv()}, + units::unit_t::ka_unit>{ + m->ka()}, + units::second_t{m->dt()}}; + } + static void Pack(google::protobuf::Message* msg, - const frc::SimpleMotorFeedforward& value); + const frc::SimpleMotorFeedforward& value) { + auto m = static_cast(msg); + m->set_ks(value.GetKs().value()); + m->set_kv( + units::unit_t::kv_unit>{ + value.GetKv()} + .value()); + m->set_ka( + units::unit_t::ka_unit>{ + value.GetKa()} + .value()); + m->set_dt(units::second_t{value.GetDt()}.value()); + } }; - -#include "frc/controller/proto/SimpleMotorFeedforwardProto.inc" diff --git a/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.inc b/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.inc deleted file mode 100644 index 8cda505a3db..00000000000 --- a/wpimath/src/main/native/include/frc/controller/proto/SimpleMotorFeedforwardProto.inc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include "controller.pb.h" -#include "frc/controller/proto/SimpleMotorFeedforwardProto.h" - -template -google::protobuf::Message* -wpi::Protobuf>::New( - google::protobuf::Arena* arena) { - return wpi::CreateMessage(arena); -} - -template -frc::SimpleMotorFeedforward -wpi::Protobuf>::Unpack( - const google::protobuf::Message& msg) { - auto m = static_cast(&msg); - return {units::volt_t{m->ks()}, - units::unit_t::kv_unit>{ - m->kv()}, - units::unit_t::ka_unit>{ - m->ka()}, - units::second_t{m->dt()}}; -} - -template -void wpi::Protobuf>::Pack( - google::protobuf::Message* msg, - const frc::SimpleMotorFeedforward& value) { - auto m = static_cast(msg); - m->set_ks(value.GetKs().value()); - m->set_kv(units::unit_t::kv_unit>{ - value.GetKv()} - .value()); - m->set_ka(units::unit_t::ka_unit>{ - value.GetKa()} - .value()); - m->set_dt(units::second_t{value.GetDt()}.value()); -} diff --git a/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.h b/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.h index 1ebbefa3633..156eef5d4b3 100644 --- a/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.h +++ b/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.h @@ -24,14 +24,41 @@ struct wpi::Struct> { } static frc::SimpleMotorFeedforward Unpack( - std::span data); + std::span data) { + constexpr size_t kKsOff = 0; + constexpr size_t kKvOff = kKsOff + 8; + constexpr size_t kKaOff = kKvOff + 8; + constexpr size_t kDtOff = kKaOff + 8; + return {units::volt_t{wpi::UnpackStruct(data)}, + units::unit_t::kv_unit>{ + wpi::UnpackStruct(data)}, + units::unit_t::ka_unit>{ + wpi::UnpackStruct(data)}, + units::second_t{wpi::UnpackStruct(data)}}; + } + static void Pack(std::span data, - const frc::SimpleMotorFeedforward& value); + const frc::SimpleMotorFeedforward& value) { + constexpr size_t kKsOff = 0; + constexpr size_t kKvOff = kKsOff + 8; + constexpr size_t kKaOff = kKvOff + 8; + constexpr size_t kDtOff = kKaOff + 8; + wpi::PackStruct(data, value.GetKs().value()); + wpi::PackStruct( + data, + units::unit_t::kv_unit>{ + value.GetKv()} + .value()); + wpi::PackStruct( + data, + units::unit_t::ka_unit>{ + value.GetKa()} + .value()); + wpi::PackStruct(data, units::second_t{value.GetDt()}.value()); + } }; static_assert( wpi::StructSerializable>); static_assert( wpi::StructSerializable>); - -#include "frc/controller/struct/SimpleMotorFeedforwardStruct.inc" diff --git a/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.inc b/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.inc deleted file mode 100644 index 50c6760f958..00000000000 --- a/wpimath/src/main/native/include/frc/controller/struct/SimpleMotorFeedforwardStruct.inc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/controller/struct/SimpleMotorFeedforwardStruct.h" - -template -frc::SimpleMotorFeedforward -wpi::Struct>::Unpack( - std::span data) { - constexpr size_t kKsOff = 0; - constexpr size_t kKvOff = kKsOff + 8; - constexpr size_t kKaOff = kKvOff + 8; - constexpr size_t kDtOff = kKaOff + 8; - return {units::volt_t{wpi::UnpackStruct(data)}, - units::unit_t::kv_unit>{ - wpi::UnpackStruct(data)}, - units::unit_t::ka_unit>{ - wpi::UnpackStruct(data)}, - units::second_t{wpi::UnpackStruct(data)}}; -} - -template -void wpi::Struct>::Pack( - std::span data, - const frc::SimpleMotorFeedforward& value) { - constexpr size_t kKsOff = 0; - constexpr size_t kKvOff = kKsOff + 8; - constexpr size_t kKaOff = kKvOff + 8; - constexpr size_t kDtOff = kKaOff + 8; - wpi::PackStruct(data, value.GetKs().value()); - wpi::PackStruct( - data, - units::unit_t::kv_unit>{ - value.GetKv()} - .value()); - wpi::PackStruct( - data, - units::unit_t::ka_unit>{ - value.GetKa()} - .value()); - wpi::PackStruct(data, units::second_t{value.GetDt()}.value()); -} diff --git a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h index ec8a3ec7d65..89a28dc248f 100644 --- a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h @@ -5,10 +5,17 @@ #pragma once #include +#include +#include #include +#include "frc/DARE.h" #include "frc/EigenCore.h" +#include "frc/StateSpaceUtil.h" +#include "frc/system/Discretization.h" +#include "frc/system/NumericalIntegration.h" +#include "frc/system/NumericalJacobian.h" #include "units/time.h" namespace frc { @@ -68,7 +75,38 @@ class ExtendedKalmanFilter { std::function f, std::function h, const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt); + units::second_t dt) + : m_f(std::move(f)), m_h(std::move(h)) { + m_contQ = MakeCovMatrix(stateStdDevs); + m_contR = MakeCovMatrix(measurementStdDevs); + m_residualFuncY = [](const OutputVector& a, + const OutputVector& b) -> OutputVector { + return a - b; + }; + m_addFuncX = [](const StateVector& a, const StateVector& b) -> StateVector { + return a + b; + }; + m_dt = dt; + + StateMatrix contA = NumericalJacobianX( + m_f, m_xHat, InputVector::Zero()); + Matrixd C = NumericalJacobianX( + m_h, m_xHat, InputVector::Zero()); + + StateMatrix discA; + StateMatrix discQ; + DiscretizeAQ(contA, m_contQ, dt, &discA, &discQ); + + Matrixd discR = DiscretizeR(m_contR, dt); + + if (IsDetectable(discA, C) && Outputs <= States) { + m_initP = + DARE(discA.transpose(), C.transpose(), discQ, discR); + } else { + m_initP = StateMatrix::Zero(); + } + m_P = m_initP; + } /** * Constructs an extended Kalman filter. @@ -96,7 +134,34 @@ class ExtendedKalmanFilter { residualFuncY, std::function addFuncX, - units::second_t dt); + units::second_t dt) + : m_f(std::move(f)), + m_h(std::move(h)), + m_residualFuncY(std::move(residualFuncY)), + m_addFuncX(std::move(addFuncX)) { + m_contQ = MakeCovMatrix(stateStdDevs); + m_contR = MakeCovMatrix(measurementStdDevs); + m_dt = dt; + + StateMatrix contA = NumericalJacobianX( + m_f, m_xHat, InputVector::Zero()); + Matrixd C = NumericalJacobianX( + m_h, m_xHat, InputVector::Zero()); + + StateMatrix discA; + StateMatrix discQ; + DiscretizeAQ(contA, m_contQ, dt, &discA, &discQ); + + Matrixd discR = DiscretizeR(m_contR, dt); + + if (IsDetectable(discA, C) && Outputs <= States) { + m_initP = + DARE(discA.transpose(), C.transpose(), discQ, discR); + } else { + m_initP = StateMatrix::Zero(); + } + m_P = m_initP; + } /** * Returns the error covariance matrix P. @@ -159,7 +224,23 @@ class ExtendedKalmanFilter { * @param u New control input from controller. * @param dt Timestep for prediction. */ - void Predict(const InputVector& u, units::second_t dt); + void Predict(const InputVector& u, units::second_t dt) { + // Find continuous A + StateMatrix contA = + NumericalJacobianX(m_f, m_xHat, u); + + // Find discrete A and Q + StateMatrix discA; + StateMatrix discQ; + DiscretizeAQ(contA, m_contQ, dt, &discA, &discQ); + + m_xHat = RK4(m_f, m_xHat, u, dt); + + // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q + m_P = discA * m_P * discA.transpose() + discQ; + + m_dt = dt; + } /** * Correct the state estimate x-hat using the measurements in y. @@ -202,7 +283,16 @@ class ExtendedKalmanFilter { void Correct( const InputVector& u, const Vectord& y, std::function(const StateVector&, const InputVector&)> h, - const Matrixd& R); + const Matrixd& R) { + auto residualFuncY = [](const Vectord& a, + const Vectord& b) -> Vectord { + return a - b; + }; + auto addFuncX = [](const StateVector& a, + const StateVector& b) -> StateVector { return a + b; }; + Correct(u, y, std::move(h), R, std::move(residualFuncY), + std::move(addFuncX)); + } /** * Correct the state estimate x-hat using the measurements in y. @@ -228,7 +318,37 @@ class ExtendedKalmanFilter { std::function(const Vectord&, const Vectord&)> residualFuncY, std::function - addFuncX); + addFuncX) { + const Matrixd C = + NumericalJacobianX(h, m_xHat, u); + const Matrixd discR = DiscretizeR(R, m_dt); + + Matrixd S = C * m_P * C.transpose() + discR; + + // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more + // efficiently. + // + // K = PCᵀS⁻¹ + // KS = PCᵀ + // (KS)ᵀ = (PCᵀ)ᵀ + // SᵀKᵀ = CPᵀ + // + // The solution of Ax = b can be found via x = A.solve(b). + // + // Kᵀ = Sᵀ.solve(CPᵀ) + // K = (Sᵀ.solve(CPᵀ))ᵀ + Matrixd K = + S.transpose().ldlt().solve(C * m_P.transpose()).transpose(); + + // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + Kₖ₊₁(y − h(x̂ₖ₊₁⁻, uₖ₊₁)) + m_xHat = addFuncX(m_xHat, K * residualFuncY(y, h(m_xHat, u))); + + // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ + // Use Joseph form for numerical stability + m_P = (StateMatrix::Identity() - K * C) * m_P * + (StateMatrix::Identity() - K * C).transpose() + + K * discR * K.transpose(); + } private: std::function m_f; @@ -246,5 +366,3 @@ class ExtendedKalmanFilter { }; } // namespace frc - -#include "ExtendedKalmanFilter.inc" diff --git a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.inc b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.inc deleted file mode 100644 index b31d35b742f..00000000000 --- a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.inc +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include - -#include - -#include "frc/DARE.h" -#include "frc/StateSpaceUtil.h" -#include "frc/estimator/ExtendedKalmanFilter.h" -#include "frc/system/Discretization.h" -#include "frc/system/NumericalIntegration.h" -#include "frc/system/NumericalJacobian.h" - -namespace frc { - -template -ExtendedKalmanFilter::ExtendedKalmanFilter( - std::function f, - std::function h, - const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt) - : m_f(std::move(f)), m_h(std::move(h)) { - m_contQ = MakeCovMatrix(stateStdDevs); - m_contR = MakeCovMatrix(measurementStdDevs); - m_residualFuncY = [](const OutputVector& a, - const OutputVector& b) -> OutputVector { return a - b; }; - m_addFuncX = [](const StateVector& a, const StateVector& b) -> StateVector { - return a + b; - }; - m_dt = dt; - - StateMatrix contA = NumericalJacobianX( - m_f, m_xHat, InputVector::Zero()); - Matrixd C = NumericalJacobianX( - m_h, m_xHat, InputVector::Zero()); - - StateMatrix discA; - StateMatrix discQ; - DiscretizeAQ(contA, m_contQ, dt, &discA, &discQ); - - Matrixd discR = DiscretizeR(m_contR, dt); - - if (IsDetectable(discA, C) && Outputs <= States) { - m_initP = - DARE(discA.transpose(), C.transpose(), discQ, discR); - } else { - m_initP = StateMatrix::Zero(); - } - m_P = m_initP; -} - -template -ExtendedKalmanFilter::ExtendedKalmanFilter( - std::function f, - std::function h, - const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - std::function - residualFuncY, - std::function addFuncX, - units::second_t dt) - : m_f(std::move(f)), - m_h(std::move(h)), - m_residualFuncY(std::move(residualFuncY)), - m_addFuncX(std::move(addFuncX)) { - m_contQ = MakeCovMatrix(stateStdDevs); - m_contR = MakeCovMatrix(measurementStdDevs); - m_dt = dt; - - StateMatrix contA = NumericalJacobianX( - m_f, m_xHat, InputVector::Zero()); - Matrixd C = NumericalJacobianX( - m_h, m_xHat, InputVector::Zero()); - - StateMatrix discA; - StateMatrix discQ; - DiscretizeAQ(contA, m_contQ, dt, &discA, &discQ); - - Matrixd discR = DiscretizeR(m_contR, dt); - - if (IsDetectable(discA, C) && Outputs <= States) { - m_initP = - DARE(discA.transpose(), C.transpose(), discQ, discR); - } else { - m_initP = StateMatrix::Zero(); - } - m_P = m_initP; -} - -template -void ExtendedKalmanFilter::Predict( - const InputVector& u, units::second_t dt) { - // Find continuous A - StateMatrix contA = - NumericalJacobianX(m_f, m_xHat, u); - - // Find discrete A and Q - StateMatrix discA; - StateMatrix discQ; - DiscretizeAQ(contA, m_contQ, dt, &discA, &discQ); - - m_xHat = RK4(m_f, m_xHat, u, dt); - - // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q - m_P = discA * m_P * discA.transpose() + discQ; - - m_dt = dt; -} - -template -template -void ExtendedKalmanFilter::Correct( - const InputVector& u, const Vectord& y, - std::function(const StateVector&, const InputVector&)> h, - const Matrixd& R) { - auto residualFuncY = [](const Vectord& a, - const Vectord& b) -> Vectord { - return a - b; - }; - auto addFuncX = [](const StateVector& a, - const StateVector& b) -> StateVector { return a + b; }; - Correct(u, y, std::move(h), R, std::move(residualFuncY), - std::move(addFuncX)); -} - -template -template -void ExtendedKalmanFilter::Correct( - const InputVector& u, const Vectord& y, - std::function(const StateVector&, const InputVector&)> h, - const Matrixd& R, - std::function(const Vectord&, const Vectord&)> - residualFuncY, - std::function - addFuncX) { - const Matrixd C = - NumericalJacobianX(h, m_xHat, u); - const Matrixd discR = DiscretizeR(R, m_dt); - - Matrixd S = C * m_P * C.transpose() + discR; - - // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more - // efficiently. - // - // K = PCᵀS⁻¹ - // KS = PCᵀ - // (KS)ᵀ = (PCᵀ)ᵀ - // SᵀKᵀ = CPᵀ - // - // The solution of Ax = b can be found via x = A.solve(b). - // - // Kᵀ = Sᵀ.solve(CPᵀ) - // K = (Sᵀ.solve(CPᵀ))ᵀ - Matrixd K = - S.transpose().ldlt().solve(C * m_P.transpose()).transpose(); - - // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + Kₖ₊₁(y − h(x̂ₖ₊₁⁻, uₖ₊₁)) - m_xHat = addFuncX(m_xHat, K * residualFuncY(y, h(m_xHat, u))); - - // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ - // Use Joseph form for numerical stability - m_P = (StateMatrix::Identity() - K * C) * m_P * - (StateMatrix::Identity() - K * C).transpose() + - K * discR * K.transpose(); -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h index 54ee51bafc9..c0220b436a5 100644 --- a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h @@ -4,11 +4,21 @@ #pragma once +#include +#include +#include + +#include #include +#include "frc/DARE.h" #include "frc/EigenCore.h" +#include "frc/StateSpaceUtil.h" +#include "frc/fmt/Eigen.h" +#include "frc/system/Discretization.h" #include "frc/system/LinearSystem.h" #include "units/time.h" +#include "wpimath/MathShared.h" namespace frc { @@ -59,7 +69,37 @@ class KalmanFilter { */ KalmanFilter(LinearSystem& plant, const StateArray& stateStdDevs, - const OutputArray& measurementStdDevs, units::second_t dt); + const OutputArray& measurementStdDevs, units::second_t dt) { + m_plant = &plant; + + m_contQ = MakeCovMatrix(stateStdDevs); + m_contR = MakeCovMatrix(measurementStdDevs); + m_dt = dt; + + // Find discrete A and Q + Matrixd discA; + Matrixd discQ; + DiscretizeAQ(plant.A(), m_contQ, dt, &discA, &discQ); + + Matrixd discR = DiscretizeR(m_contR, dt); + + const auto& C = plant.C(); + + if (!IsDetectable(discA, C)) { + std::string msg = fmt::format( + "The system passed to the Kalman filter is undetectable!\n\n" + "A =\n{}\nC =\n{}\n", + discA, C); + + wpi::math::MathSharedStore::ReportError(msg); + throw std::invalid_argument(msg); + } + + m_initP = + DARE(discA.transpose(), C.transpose(), discQ, discR); + + Reset(); + } /** * Returns the error covariance matrix P. @@ -122,7 +162,19 @@ class KalmanFilter { * @param u New control input from controller. * @param dt Timestep for prediction. */ - void Predict(const InputVector& u, units::second_t dt); + void Predict(const InputVector& u, units::second_t dt) { + // Find discrete A and Q + StateMatrix discA; + StateMatrix discQ; + DiscretizeAQ(m_plant->A(), m_contQ, dt, &discA, &discQ); + + m_xHat = m_plant->CalculateX(m_xHat, u, dt); + + // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q + m_P = discA * m_P * discA.transpose() + discQ; + + m_dt = dt; + } /** * Correct the state estimate x-hat using the measurements in y. @@ -144,7 +196,38 @@ class KalmanFilter { * @param R Continuous measurement noise covariance matrix. */ void Correct(const InputVector& u, const OutputVector& y, - const Matrixd& R); + const Matrixd& R) { + const auto& C = m_plant->C(); + const auto& D = m_plant->D(); + + const Matrixd discR = DiscretizeR(R, m_dt); + + Matrixd S = C * m_P * C.transpose() + discR; + + // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more + // efficiently. + // + // K = PCᵀS⁻¹ + // KS = PCᵀ + // (KS)ᵀ = (PCᵀ)ᵀ + // SᵀKᵀ = CPᵀ + // + // The solution of Ax = b can be found via x = A.solve(b). + // + // Kᵀ = Sᵀ.solve(CPᵀ) + // K = (Sᵀ.solve(CPᵀ))ᵀ + Matrixd K = + S.transpose().ldlt().solve(C * m_P.transpose()).transpose(); + + // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) + m_xHat += K * (y - (C * m_xHat + D * u)); + + // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ + // Use Joseph form for numerical stability + m_P = (StateMatrix::Identity() - K * C) * m_P * + (StateMatrix::Identity() - K * C).transpose() + + K * discR * K.transpose(); + } private: LinearSystem* m_plant; @@ -163,5 +246,3 @@ extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) KalmanFilter<2, 1, 1>; } // namespace frc - -#include "KalmanFilter.inc" diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc deleted file mode 100644 index a00b455b316..00000000000 --- a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include -#include - -#include - -#include "frc/DARE.h" -#include "frc/StateSpaceUtil.h" -#include "frc/estimator/KalmanFilter.h" -#include "frc/fmt/Eigen.h" -#include "frc/system/Discretization.h" -#include "wpimath/MathShared.h" - -namespace frc { - -template -KalmanFilter::KalmanFilter( - LinearSystem& plant, - const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt) { - m_plant = &plant; - - m_contQ = MakeCovMatrix(stateStdDevs); - m_contR = MakeCovMatrix(measurementStdDevs); - m_dt = dt; - - // Find discrete A and Q - Matrixd discA; - Matrixd discQ; - DiscretizeAQ(plant.A(), m_contQ, dt, &discA, &discQ); - - Matrixd discR = DiscretizeR(m_contR, dt); - - const auto& C = plant.C(); - - if (!IsDetectable(discA, C)) { - std::string msg = fmt::format( - "The system passed to the Kalman filter is undetectable!\n\n" - "A =\n{}\nC =\n{}\n", - discA, C); - - wpi::math::MathSharedStore::ReportError(msg); - throw std::invalid_argument(msg); - } - - m_initP = - DARE(discA.transpose(), C.transpose(), discQ, discR); - - Reset(); -} - -template -void KalmanFilter::Predict(const InputVector& u, - units::second_t dt) { - // Find discrete A and Q - StateMatrix discA; - StateMatrix discQ; - DiscretizeAQ(m_plant->A(), m_contQ, dt, &discA, &discQ); - - m_xHat = m_plant->CalculateX(m_xHat, u, dt); - - // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q - m_P = discA * m_P * discA.transpose() + discQ; - - m_dt = dt; -} - -template -void KalmanFilter::Correct( - const InputVector& u, const OutputVector& y, - const Matrixd& R) { - const auto& C = m_plant->C(); - const auto& D = m_plant->D(); - - const Matrixd discR = DiscretizeR(R, m_dt); - - Matrixd S = C * m_P * C.transpose() + discR; - - // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more - // efficiently. - // - // K = PCᵀS⁻¹ - // KS = PCᵀ - // (KS)ᵀ = (PCᵀ)ᵀ - // SᵀKᵀ = CPᵀ - // - // The solution of Ax = b can be found via x = A.solve(b). - // - // Kᵀ = Sᵀ.solve(CPᵀ) - // K = (Sᵀ.solve(CPᵀ))ᵀ - Matrixd K = - S.transpose().ldlt().solve(C * m_P.transpose()).transpose(); - - // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) - m_xHat += K * (y - (C * m_xHat + D * u)); - - // Pₖ₊₁⁺ = (I−Kₖ₊₁C)Pₖ₊₁⁻(I−Kₖ₊₁C)ᵀ + Kₖ₊₁RKₖ₊₁ᵀ - // Use Joseph form for numerical stability - m_P = (StateMatrix::Identity() - K * C) * m_P * - (StateMatrix::Identity() - K * C).transpose() + - K * discR * K.transpose(); -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/PoseEstimator.h b/wpimath/src/main/native/include/frc/estimator/PoseEstimator.h index f75a5afec73..1f4a88be2b8 100644 --- a/wpimath/src/main/native/include/frc/estimator/PoseEstimator.h +++ b/wpimath/src/main/native/include/frc/estimator/PoseEstimator.h @@ -6,7 +6,6 @@ #include #include -#include #include #include @@ -23,6 +22,7 @@ #include "wpimath/MathShared.h" namespace frc { + /** * This class wraps odometry to fuse latency-compensated * vision measurements with encoder measurements. Robot code should not use this @@ -59,7 +59,14 @@ class WPILIB_DLLEXPORT PoseEstimator { PoseEstimator(Kinematics& kinematics, Odometry& odometry, const wpi::array& stateStdDevs, - const wpi::array& visionMeasurementStdDevs); + const wpi::array& visionMeasurementStdDevs) + : m_odometry(odometry), m_poseEstimate(m_odometry.GetPose()) { + for (size_t i = 0; i < 3; ++i) { + m_q[i] = stateStdDevs[i] * stateStdDevs[i]; + } + + SetVisionMeasurementStdDevs(visionMeasurementStdDevs); + } /** * Sets the pose estimator's trust in vision measurements. This might be used @@ -72,7 +79,23 @@ class WPILIB_DLLEXPORT PoseEstimator { * less. */ void SetVisionMeasurementStdDevs( - const wpi::array& visionMeasurementStdDevs); + const wpi::array& visionMeasurementStdDevs) { + wpi::array r{wpi::empty_array}; + for (size_t i = 0; i < 3; ++i) { + r[i] = visionMeasurementStdDevs[i] * visionMeasurementStdDevs[i]; + } + + // Solve for closed form Kalman gain for continuous Kalman filter with A = 0 + // and C = I. See wpimath/algorithms.md. + for (size_t row = 0; row < 3; ++row) { + if (m_q[row] == 0.0) { + m_visionK(row, row) = 0.0; + } else { + m_visionK(row, row) = + m_q[row] / (m_q[row] + std::sqrt(m_q[row] * r[row])); + } + } + } /** * Resets the robot's position on the field. @@ -85,35 +108,50 @@ class WPILIB_DLLEXPORT PoseEstimator { * @param pose The estimated pose of the robot on the field. */ void ResetPosition(const Rotation2d& gyroAngle, - const WheelPositions& wheelPositions, const Pose2d& pose); + const WheelPositions& wheelPositions, const Pose2d& pose) { + // Reset state estimate and error covariance + m_odometry.ResetPosition(gyroAngle, wheelPositions, pose); + m_odometryPoseBuffer.Clear(); + m_visionUpdates.clear(); + m_poseEstimate = m_odometry.GetPose(); + } /** * Resets the robot's pose. * * @param pose The pose to reset to. */ - void ResetPose(const Pose2d& pose); + void ResetPose(const Pose2d& pose) { + m_odometry.ResetPose(pose); + m_odometryPoseBuffer.Clear(); + } /** * Resets the robot's translation. * * @param translation The pose to translation to. */ - void ResetTranslation(const Translation2d& translation); + void ResetTranslation(const Translation2d& translation) { + m_odometry.ResetTranslation(translation); + m_odometryPoseBuffer.Clear(); + } /** * Resets the robot's rotation. * * @param rotation The rotation to reset to. */ - void ResetRotation(const Rotation2d& rotation); + void ResetRotation(const Rotation2d& rotation) { + m_odometry.ResetRotation(rotation); + m_odometryPoseBuffer.Clear(); + } /** * Gets the estimated robot pose. * * @return The estimated robot pose in meters. */ - Pose2d GetEstimatedPosition() const; + Pose2d GetEstimatedPosition() const { return m_poseEstimate; } /** * Return the pose at a given timestamp, if the buffer is not empty. @@ -122,7 +160,47 @@ class WPILIB_DLLEXPORT PoseEstimator { * @return The pose at the given timestamp (or std::nullopt if the buffer is * empty). */ - std::optional SampleAt(units::second_t timestamp) const; + std::optional SampleAt(units::second_t timestamp) const { + // Step 0: If there are no odometry updates to sample, skip. + if (m_odometryPoseBuffer.GetInternalBuffer().empty()) { + return std::nullopt; + } + + // Step 1: Make sure timestamp matches the sample from the odometry pose + // buffer. (When sampling, the buffer will always use a timestamp + // between the first and last timestamps) + units::second_t oldestOdometryTimestamp = + m_odometryPoseBuffer.GetInternalBuffer().front().first; + units::second_t newestOdometryTimestamp = + m_odometryPoseBuffer.GetInternalBuffer().back().first; + timestamp = + std::clamp(timestamp, oldestOdometryTimestamp, newestOdometryTimestamp); + + // Step 2: If there are no applicable vision updates, use the odometry-only + // information. + if (m_visionUpdates.empty() || timestamp < m_visionUpdates.begin()->first) { + return m_odometryPoseBuffer.Sample(timestamp); + } + + // Step 3: Get the latest vision update from before or at the timestamp to + // sample at. + // First, find the iterator past the sample timestamp, then go back one. + // Note that upper_bound() won't return begin() because we check begin() + // earlier. + auto floorIter = m_visionUpdates.upper_bound(timestamp); + --floorIter; + auto visionUpdate = floorIter->second; + + // Step 4: Get the pose measured by odometry at the time of the sample. + auto odometryEstimate = m_odometryPoseBuffer.Sample(timestamp); + + // Step 5: Apply the vision compensation to the odometry pose. + // TODO Replace with std::optional::transform() in C++23 + if (odometryEstimate) { + return visionUpdate.Compensate(*odometryEstimate); + } + return std::nullopt; + } /** * Adds a vision measurement to the Kalman Filter. This will correct @@ -145,7 +223,63 @@ class WPILIB_DLLEXPORT PoseEstimator { * frc::Timer::GetFPGATimestamp() as your time source in this case. */ void AddVisionMeasurement(const Pose2d& visionRobotPose, - units::second_t timestamp); + units::second_t timestamp) { + // Step 0: If this measurement is old enough to be outside the pose buffer's + // timespan, skip. + if (m_odometryPoseBuffer.GetInternalBuffer().empty() || + m_odometryPoseBuffer.GetInternalBuffer().front().first - + kBufferDuration > + timestamp) { + return; + } + + // Step 1: Clean up any old entries + CleanUpVisionUpdates(); + + // Step 2: Get the pose measured by odometry at the moment the vision + // measurement was made. + auto odometrySample = m_odometryPoseBuffer.Sample(timestamp); + + if (!odometrySample) { + return; + } + + // Step 3: Get the vision-compensated pose estimate at the moment the vision + // measurement was made. + auto visionSample = SampleAt(timestamp); + + if (!visionSample) { + return; + } + + // Step 4: Measure the twist between the old pose estimate and the vision + // pose. + auto twist = visionSample.value().Log(visionRobotPose); + + // Step 5: We should not trust the twist entirely, so instead we scale this + // twist by a Kalman gain matrix representing how much we trust vision + // measurements compared to our current pose. + Eigen::Vector3d k_times_twist = + m_visionK * Eigen::Vector3d{twist.dx.value(), twist.dy.value(), + twist.dtheta.value()}; + + // Step 6: Convert back to Twist2d. + Twist2d scaledTwist{units::meter_t{k_times_twist(0)}, + units::meter_t{k_times_twist(1)}, + units::radian_t{k_times_twist(2)}}; + + // Step 7: Calculate and record the vision update. + VisionUpdate visionUpdate{visionSample->Exp(scaledTwist), *odometrySample}; + m_visionUpdates[timestamp] = visionUpdate; + + // Step 8: Remove later vision measurements. (Matches previous behavior) + auto firstAfter = m_visionUpdates.upper_bound(timestamp); + m_visionUpdates.erase(firstAfter, m_visionUpdates.end()); + + // Step 9: Update latest pose estimate. Since we cleared all updates after + // this vision update, it's guaranteed to be the latest vision update. + m_poseEstimate = visionUpdate.Compensate(m_odometry.GetPose()); + } /** * Adds a vision measurement to the Kalman Filter. This will correct @@ -192,7 +326,10 @@ class WPILIB_DLLEXPORT PoseEstimator { * @return The estimated pose of the robot in meters. */ Pose2d Update(const Rotation2d& gyroAngle, - const WheelPositions& wheelPositions); + const WheelPositions& wheelPositions) { + return UpdateWithTime(wpi::math::MathSharedStore::GetTimestamp(), gyroAngle, + wheelPositions); + } /** * Updates the pose estimator with wheel encoder and gyro information. This @@ -206,13 +343,53 @@ class WPILIB_DLLEXPORT PoseEstimator { */ Pose2d UpdateWithTime(units::second_t currentTime, const Rotation2d& gyroAngle, - const WheelPositions& wheelPositions); + const WheelPositions& wheelPositions) { + auto odometryEstimate = m_odometry.Update(gyroAngle, wheelPositions); + + m_odometryPoseBuffer.AddSample(currentTime, odometryEstimate); + + if (m_visionUpdates.empty()) { + m_poseEstimate = odometryEstimate; + } else { + auto visionUpdate = m_visionUpdates.rbegin()->second; + m_poseEstimate = visionUpdate.Compensate(odometryEstimate); + } + + return GetEstimatedPosition(); + } private: /** * Removes stale vision updates that won't affect sampling. */ - void CleanUpVisionUpdates(); + void CleanUpVisionUpdates() { + // Step 0: If there are no odometry samples, skip. + if (m_odometryPoseBuffer.GetInternalBuffer().empty()) { + return; + } + + // Step 1: Find the oldest timestamp that needs a vision update. + units::second_t oldestOdometryTimestamp = + m_odometryPoseBuffer.GetInternalBuffer().front().first; + + // Step 2: If there are no vision updates before that timestamp, skip. + if (m_visionUpdates.empty() || + oldestOdometryTimestamp < m_visionUpdates.begin()->first) { + return; + } + + // Step 3: Find the newest vision update timestamp before or at the oldest + // timestamp. + // First, find the iterator past the oldest odometry timestamp, then go + // back one. Note that upper_bound() won't return begin() because we check + // begin() earlier. + auto newestNeededVisionUpdate = + m_visionUpdates.upper_bound(oldestOdometryTimestamp); + --newestNeededVisionUpdate; + + // Step 4: Remove all entries strictly before the newest timestamp we need. + m_visionUpdates.erase(m_visionUpdates.begin(), newestNeededVisionUpdate); + } struct VisionUpdate { // The vision-compensated pose estimate @@ -250,6 +427,5 @@ class WPILIB_DLLEXPORT PoseEstimator { Pose2d m_poseEstimate; }; -} // namespace frc -#include "frc/estimator/PoseEstimator.inc" +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/PoseEstimator.inc b/wpimath/src/main/native/include/frc/estimator/PoseEstimator.inc deleted file mode 100644 index 51bb8343062..00000000000 --- a/wpimath/src/main/native/include/frc/estimator/PoseEstimator.inc +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/estimator/PoseEstimator.h" -#include "frc/geometry/Pose2d.h" -#include "frc/geometry/Translation2d.h" - -namespace frc { - -template -PoseEstimator::PoseEstimator( - Kinematics& kinematics, - Odometry& odometry, - const wpi::array& stateStdDevs, - const wpi::array& visionMeasurementStdDevs) - : m_odometry(odometry), m_poseEstimate(m_odometry.GetPose()) { - for (size_t i = 0; i < 3; ++i) { - m_q[i] = stateStdDevs[i] * stateStdDevs[i]; - } - - SetVisionMeasurementStdDevs(visionMeasurementStdDevs); -} - -template -void PoseEstimator::SetVisionMeasurementStdDevs( - const wpi::array& visionMeasurementStdDevs) { - wpi::array r{wpi::empty_array}; - for (size_t i = 0; i < 3; ++i) { - r[i] = visionMeasurementStdDevs[i] * visionMeasurementStdDevs[i]; - } - - // Solve for closed form Kalman gain for continuous Kalman filter with A = 0 - // and C = I. See wpimath/algorithms.md. - for (size_t row = 0; row < 3; ++row) { - if (m_q[row] == 0.0) { - m_visionK(row, row) = 0.0; - } else { - m_visionK(row, row) = - m_q[row] / (m_q[row] + std::sqrt(m_q[row] * r[row])); - } - } -} - -template -void PoseEstimator::ResetPosition( - const Rotation2d& gyroAngle, const WheelPositions& wheelPositions, - const Pose2d& pose) { - // Reset state estimate and error covariance - m_odometry.ResetPosition(gyroAngle, wheelPositions, pose); - m_odometryPoseBuffer.Clear(); - m_visionUpdates.clear(); - m_poseEstimate = m_odometry.GetPose(); -} - -template -void PoseEstimator::ResetPose(const Pose2d& pose) { - m_odometry.ResetPose(pose); - m_odometryPoseBuffer.Clear(); -} - -template -void PoseEstimator::ResetTranslation( - const Translation2d& translation) { - m_odometry.ResetTranslation(translation); - m_odometryPoseBuffer.Clear(); -} - -template -void PoseEstimator::ResetRotation( - const Rotation2d& rotation) { - m_odometry.ResetRotation(rotation); - m_odometryPoseBuffer.Clear(); -} - -template -Pose2d PoseEstimator::GetEstimatedPosition() - const { - return m_poseEstimate; - if (m_visionUpdates.empty()) { - return m_odometry.GetPose(); - } - auto visionUpdate = m_visionUpdates.rbegin()->second; - return visionUpdate.Compensate(m_odometry.GetPose()); -} - -template -std::optional PoseEstimator::SampleAt( - units::second_t timestamp) const { - // Step 0: If there are no odometry updates to sample, skip. - if (m_odometryPoseBuffer.GetInternalBuffer().empty()) { - return std::nullopt; - } - - // Step 1: Make sure timestamp matches the sample from the odometry pose - // buffer. (When sampling, the buffer will always use a timestamp - // between the first and last timestamps) - units::second_t oldestOdometryTimestamp = - m_odometryPoseBuffer.GetInternalBuffer().front().first; - units::second_t newestOdometryTimestamp = - m_odometryPoseBuffer.GetInternalBuffer().back().first; - timestamp = - std::clamp(timestamp, oldestOdometryTimestamp, newestOdometryTimestamp); - - // Step 2: If there are no applicable vision updates, use the odometry-only - // information. - if (m_visionUpdates.empty() || timestamp < m_visionUpdates.begin()->first) { - return m_odometryPoseBuffer.Sample(timestamp); - } - - // Step 3: Get the latest vision update from before or at the timestamp to - // sample at. - // First, find the iterator past the sample timestamp, then go back one. Note - // that upper_bound() won't return begin() because we check begin() earlier. - auto floorIter = m_visionUpdates.upper_bound(timestamp); - --floorIter; - auto visionUpdate = floorIter->second; - - // Step 4: Get the pose measured by odometry at the time of the sample. - auto odometryEstimate = m_odometryPoseBuffer.Sample(timestamp); - - // Step 5: Apply the vision compensation to the odometry pose. - // TODO Replace with std::optional::transform() in C++23 - if (odometryEstimate) { - return visionUpdate.Compensate(*odometryEstimate); - } - return std::nullopt; -} - -template -void PoseEstimator::CleanUpVisionUpdates() { - // Step 0: If there are no odometry samples, skip. - if (m_odometryPoseBuffer.GetInternalBuffer().empty()) { - return; - } - - // Step 1: Find the oldest timestamp that needs a vision update. - units::second_t oldestOdometryTimestamp = - m_odometryPoseBuffer.GetInternalBuffer().front().first; - - // Step 2: If there are no vision updates before that timestamp, skip. - if (m_visionUpdates.empty() || - oldestOdometryTimestamp < m_visionUpdates.begin()->first) { - return; - } - - // Step 3: Find the newest vision update timestamp before or at the oldest - // timestamp. - // First, find the iterator past the oldest odometry timestamp, then go - // back one. Note that upper_bound() won't return begin() because we check - // begin() earlier. - auto newestNeededVisionUpdate = - m_visionUpdates.upper_bound(oldestOdometryTimestamp); - --newestNeededVisionUpdate; - - // Step 4: Remove all entries strictly before the newest timestamp we need. - m_visionUpdates.erase(m_visionUpdates.begin(), newestNeededVisionUpdate); -} - -template -void PoseEstimator::AddVisionMeasurement( - const Pose2d& visionRobotPose, units::second_t timestamp) { - // Step 0: If this measurement is old enough to be outside the pose buffer's - // timespan, skip. - if (m_odometryPoseBuffer.GetInternalBuffer().empty() || - m_odometryPoseBuffer.GetInternalBuffer().front().first - kBufferDuration > - timestamp) { - return; - } - - // Step 1: Clean up any old entries - CleanUpVisionUpdates(); - - // Step 2: Get the pose measured by odometry at the moment the vision - // measurement was made. - auto odometrySample = m_odometryPoseBuffer.Sample(timestamp); - - if (!odometrySample) { - return; - } - - // Step 3: Get the vision-compensated pose estimate at the moment the vision - // measurement was made. - auto visionSample = SampleAt(timestamp); - - if (!visionSample) { - return; - } - - // Step 4: Measure the twist between the old pose estimate and the vision - // pose. - auto twist = visionSample.value().Log(visionRobotPose); - - // Step 5: We should not trust the twist entirely, so instead we scale this - // twist by a Kalman gain matrix representing how much we trust vision - // measurements compared to our current pose. - Eigen::Vector3d k_times_twist = - m_visionK * - Eigen::Vector3d{twist.dx.value(), twist.dy.value(), twist.dtheta.value()}; - - // Step 6: Convert back to Twist2d. - Twist2d scaledTwist{units::meter_t{k_times_twist(0)}, - units::meter_t{k_times_twist(1)}, - units::radian_t{k_times_twist(2)}}; - - // Step 7: Calculate and record the vision update. - VisionUpdate visionUpdate{visionSample->Exp(scaledTwist), *odometrySample}; - m_visionUpdates[timestamp] = visionUpdate; - - // Step 8: Remove later vision measurements. (Matches previous behavior) - auto firstAfter = m_visionUpdates.upper_bound(timestamp); - m_visionUpdates.erase(firstAfter, m_visionUpdates.end()); - - // Step 9: Update latest pose estimate. Since we cleared all updates after - // this vision update, it's guaranteed to be the latest vision update. - m_poseEstimate = visionUpdate.Compensate(m_odometry.GetPose()); -} - -template -Pose2d PoseEstimator::Update( - const Rotation2d& gyroAngle, const WheelPositions& wheelPositions) { - return UpdateWithTime(wpi::math::MathSharedStore::GetTimestamp(), gyroAngle, - wheelPositions); -} - -template -Pose2d PoseEstimator::UpdateWithTime( - units::second_t currentTime, const Rotation2d& gyroAngle, - const WheelPositions& wheelPositions) { - auto odometryEstimate = m_odometry.Update(gyroAngle, wheelPositions); - - m_odometryPoseBuffer.AddSample(currentTime, odometryEstimate); - - if (m_visionUpdates.empty()) { - m_poseEstimate = odometryEstimate; - } else { - auto visionUpdate = m_visionUpdates.rbegin()->second; - m_poseEstimate = visionUpdate.Compensate(odometryEstimate); - } - - return GetEstimatedPosition(); -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.h index c63e5802c60..31088e93cef 100644 --- a/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.h @@ -4,12 +4,22 @@ #pragma once +#include +#include +#include + +#include #include #include +#include "frc/DARE.h" #include "frc/EigenCore.h" +#include "frc/StateSpaceUtil.h" +#include "frc/fmt/Eigen.h" +#include "frc/system/Discretization.h" #include "frc/system/LinearSystem.h" #include "units/time.h" +#include "wpimath/MathShared.h" namespace frc { @@ -63,7 +73,52 @@ class SteadyStateKalmanFilter { SteadyStateKalmanFilter(LinearSystem& plant, const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt); + units::second_t dt) { + m_plant = &plant; + + auto contQ = MakeCovMatrix(stateStdDevs); + auto contR = MakeCovMatrix(measurementStdDevs); + + Matrixd discA; + Matrixd discQ; + DiscretizeAQ(plant.A(), contQ, dt, &discA, &discQ); + + auto discR = DiscretizeR(contR, dt); + + const auto& C = plant.C(); + + if (!IsDetectable(discA, C)) { + std::string msg = fmt::format( + "The system passed to the Kalman filter is undetectable!\n\n" + "A =\n{}\nC =\n{}\n", + discA, C); + + wpi::math::MathSharedStore::ReportError(msg); + throw std::invalid_argument(msg); + } + + Matrixd P = + DARE(discA.transpose(), C.transpose(), discQ, discR); + + // S = CPCᵀ + R + Matrixd S = C * P * C.transpose() + discR; + + // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more + // efficiently. + // + // K = PCᵀS⁻¹ + // KS = PCᵀ + // (KS)ᵀ = (PCᵀ)ᵀ + // SᵀKᵀ = CPᵀ + // + // The solution of Ax = b can be found via x = A.solve(b). + // + // Kᵀ = Sᵀ.solve(CPᵀ) + // K = (Sᵀ.solve(CPᵀ))ᵀ + m_K = S.transpose().ldlt().solve(C * P.transpose()).transpose(); + + Reset(); + } SteadyStateKalmanFilter(SteadyStateKalmanFilter&&) = default; SteadyStateKalmanFilter& operator=(SteadyStateKalmanFilter&&) = default; @@ -119,7 +174,9 @@ class SteadyStateKalmanFilter { * @param u New control input from controller. * @param dt Timestep for prediction. */ - void Predict(const InputVector& u, units::second_t dt); + void Predict(const InputVector& u, units::second_t dt) { + m_xHat = m_plant->CalculateX(m_xHat, u, dt); + } /** * Correct the state estimate x-hat using the measurements in y. @@ -127,7 +184,13 @@ class SteadyStateKalmanFilter { * @param u Same control input used in the last predict step. * @param y Measurement vector. */ - void Correct(const InputVector& u, const OutputVector& y); + void Correct(const InputVector& u, const OutputVector& y) { + const auto& C = m_plant->C(); + const auto& D = m_plant->D(); + + // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) + m_xHat += m_K * (y - (C * m_xHat + D * u)); + } private: LinearSystem* m_plant; @@ -149,5 +212,3 @@ extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) SteadyStateKalmanFilter<2, 1, 1>; } // namespace frc - -#include "SteadyStateKalmanFilter.inc" diff --git a/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.inc b/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.inc deleted file mode 100644 index 2f75b241cca..00000000000 --- a/wpimath/src/main/native/include/frc/estimator/SteadyStateKalmanFilter.inc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include -#include - -#include - -#include "frc/DARE.h" -#include "frc/StateSpaceUtil.h" -#include "frc/estimator/SteadyStateKalmanFilter.h" -#include "frc/fmt/Eigen.h" -#include "frc/system/Discretization.h" -#include "wpimath/MathShared.h" - -namespace frc { - -template -SteadyStateKalmanFilter::SteadyStateKalmanFilter( - LinearSystem& plant, - const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt) { - m_plant = &plant; - - auto contQ = MakeCovMatrix(stateStdDevs); - auto contR = MakeCovMatrix(measurementStdDevs); - - Matrixd discA; - Matrixd discQ; - DiscretizeAQ(plant.A(), contQ, dt, &discA, &discQ); - - auto discR = DiscretizeR(contR, dt); - - const auto& C = plant.C(); - - if (!IsDetectable(discA, C)) { - std::string msg = fmt::format( - "The system passed to the Kalman filter is undetectable!\n\n" - "A =\n{}\nC =\n{}\n", - discA, C); - - wpi::math::MathSharedStore::ReportError(msg); - throw std::invalid_argument(msg); - } - - Matrixd P = - DARE(discA.transpose(), C.transpose(), discQ, discR); - - // S = CPCᵀ + R - Matrixd S = C * P * C.transpose() + discR; - - // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more - // efficiently. - // - // K = PCᵀS⁻¹ - // KS = PCᵀ - // (KS)ᵀ = (PCᵀ)ᵀ - // SᵀKᵀ = CPᵀ - // - // The solution of Ax = b can be found via x = A.solve(b). - // - // Kᵀ = Sᵀ.solve(CPᵀ) - // K = (Sᵀ.solve(CPᵀ))ᵀ - m_K = S.transpose().ldlt().solve(C * P.transpose()).transpose(); - - Reset(); -} - -template -void SteadyStateKalmanFilter::Predict( - const InputVector& u, units::second_t dt) { - m_xHat = m_plant->CalculateX(m_xHat, u, dt); -} - -template -void SteadyStateKalmanFilter::Correct( - const InputVector& u, const OutputVector& y) { - const auto& C = m_plant->C(); - const auto& D = m_plant->D(); - - // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) - m_xHat += m_K * (y - (C * m_xHat + D * u)); -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h index df68e151606..87ee5ffd29b 100644 --- a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h @@ -5,12 +5,19 @@ #pragma once #include +#include +#include #include #include #include "frc/EigenCore.h" +#include "frc/StateSpaceUtil.h" #include "frc/estimator/MerweScaledSigmaPoints.h" +#include "frc/estimator/UnscentedTransform.h" +#include "frc/system/Discretization.h" +#include "frc/system/NumericalIntegration.h" +#include "frc/system/NumericalJacobian.h" #include "units/time.h" namespace frc { @@ -74,7 +81,31 @@ class UnscentedKalmanFilter { std::function f, std::function h, const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt); + units::second_t dt) + : m_f(std::move(f)), m_h(std::move(h)) { + m_contQ = MakeCovMatrix(stateStdDevs); + m_contR = MakeCovMatrix(measurementStdDevs); + m_meanFuncX = [](const Matrixd& sigmas, + const Vectord<2 * States + 1>& Wm) -> StateVector { + return sigmas * Wm; + }; + m_meanFuncY = [](const Matrixd& sigmas, + const Vectord<2 * States + 1>& Wc) -> OutputVector { + return sigmas * Wc; + }; + m_residualFuncX = [](const StateVector& a, + const StateVector& b) -> StateVector { return a - b; }; + m_residualFuncY = [](const OutputVector& a, + const OutputVector& b) -> OutputVector { + return a - b; + }; + m_addFuncX = [](const StateVector& a, const StateVector& b) -> StateVector { + return a + b; + }; + m_dt = dt; + + Reset(); + } /** * Constructs an unscented Kalman filter with custom mean, residual, and @@ -120,7 +151,20 @@ class UnscentedKalmanFilter { residualFuncY, std::function addFuncX, - units::second_t dt); + units::second_t dt) + : m_f(std::move(f)), + m_h(std::move(h)), + m_meanFuncX(std::move(meanFuncX)), + m_meanFuncY(std::move(meanFuncY)), + m_residualFuncX(std::move(residualFuncX)), + m_residualFuncY(std::move(residualFuncY)), + m_addFuncX(std::move(addFuncX)) { + m_contQ = MakeCovMatrix(stateStdDevs); + m_contR = MakeCovMatrix(measurementStdDevs); + m_dt = dt; + + Reset(); + } /** * Returns the square-root error covariance matrix S. @@ -197,7 +241,31 @@ class UnscentedKalmanFilter { * @param u New control input from controller. * @param dt Timestep for prediction. */ - void Predict(const InputVector& u, units::second_t dt); + void Predict(const InputVector& u, units::second_t dt) { + m_dt = dt; + + // Discretize Q before projecting mean and covariance forward + StateMatrix contA = + NumericalJacobianX(m_f, m_xHat, u); + StateMatrix discA; + StateMatrix discQ; + DiscretizeAQ(contA, m_contQ, m_dt, &discA, &discQ); + Eigen::internal::llt_inplace::blocked(discQ); + + Matrixd sigmas = + m_pts.SquareRootSigmaPoints(m_xHat, m_S); + + for (int i = 0; i < m_pts.NumSigmas(); ++i) { + StateVector x = sigmas.template block(0, i); + m_sigmasF.template block(0, i) = RK4(m_f, x, u, dt); + } + + auto [xHat, S] = SquareRootUnscentedTransform( + m_sigmasF, m_pts.Wm(), m_pts.Wc(), m_meanFuncX, m_residualFuncX, + discQ.template triangularView()); + m_xHat = xHat; + m_S = S; + } /** * Correct the state estimate x-hat using the measurements in y. @@ -242,7 +310,25 @@ class UnscentedKalmanFilter { void Correct( const InputVector& u, const Vectord& y, std::function(const StateVector&, const InputVector&)> h, - const Matrixd& R); + const Matrixd& R) { + auto meanFuncY = [](const Matrixd& sigmas, + const Vectord<2 * States + 1>& Wc) -> Vectord { + return sigmas * Wc; + }; + auto residualFuncX = [](const StateVector& a, + const StateVector& b) -> StateVector { + return a - b; + }; + auto residualFuncY = [](const Vectord& a, + const Vectord& b) -> Vectord { + return a - b; + }; + auto addFuncX = [](const StateVector& a, + const StateVector& b) -> StateVector { return a + b; }; + Correct(u, y, std::move(h), R, std::move(meanFuncY), + std::move(residualFuncY), std::move(residualFuncX), + std::move(addFuncX)); + } /** * Correct the state estimate x-hat using the measurements in y. @@ -277,7 +363,54 @@ class UnscentedKalmanFilter { std::function residualFuncX, std::function - addFuncX); + addFuncX) { + Matrixd discR = DiscretizeR(R, m_dt); + Eigen::internal::llt_inplace::blocked(discR); + + // Transform sigma points into measurement space + Matrixd sigmasH; + Matrixd sigmas = + m_pts.SquareRootSigmaPoints(m_xHat, m_S); + for (int i = 0; i < m_pts.NumSigmas(); ++i) { + sigmasH.template block(0, i) = + h(sigmas.template block(0, i), u); + } + + // Mean and covariance of prediction passed through UT + auto [yHat, Sy] = SquareRootUnscentedTransform( + sigmasH, m_pts.Wm(), m_pts.Wc(), meanFuncY, residualFuncY, + discR.template triangularView()); + + // Compute cross covariance of the state and the measurements + Matrixd Pxy; + Pxy.setZero(); + for (int i = 0; i < m_pts.NumSigmas(); ++i) { + // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i] + Pxy += + m_pts.Wc(i) * + (residualFuncX(m_sigmasF.template block(0, i), m_xHat)) * + (residualFuncY(sigmasH.template block(0, i), yHat)) + .transpose(); + } + + // K = (P_{xy} / S_yᵀ) / S_y + // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y + // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ + Matrixd K = + Sy.transpose() + .fullPivHouseholderQr() + .solve(Sy.fullPivHouseholderQr().solve(Pxy.transpose())) + .transpose(); + + // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ) + m_xHat = addFuncX(m_xHat, K * residualFuncY(y, yHat)); + + Matrixd U = K * Sy; + for (int i = 0; i < Rows; i++) { + Eigen::internal::llt_inplace::rankUpdate( + m_S, U.template block(0, i), -1); + } + } private: std::function m_f; @@ -309,5 +442,3 @@ extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) UnscentedKalmanFilter<5, 3, 3>; } // namespace frc - -#include "UnscentedKalmanFilter.inc" diff --git a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.inc b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.inc deleted file mode 100644 index 03cfd192b4c..00000000000 --- a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.inc +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include - -#include - -#include "frc/StateSpaceUtil.h" -#include "frc/estimator/UnscentedKalmanFilter.h" -#include "frc/estimator/UnscentedTransform.h" -#include "frc/system/Discretization.h" -#include "frc/system/NumericalIntegration.h" -#include "frc/system/NumericalJacobian.h" - -namespace frc { - -template -UnscentedKalmanFilter::UnscentedKalmanFilter( - std::function f, - std::function h, - const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - units::second_t dt) - : m_f(std::move(f)), m_h(std::move(h)) { - m_contQ = MakeCovMatrix(stateStdDevs); - m_contR = MakeCovMatrix(measurementStdDevs); - m_meanFuncX = [](const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wm) -> StateVector { - return sigmas * Wm; - }; - m_meanFuncY = [](const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wc) -> OutputVector { - return sigmas * Wc; - }; - m_residualFuncX = [](const StateVector& a, - const StateVector& b) -> StateVector { return a - b; }; - m_residualFuncY = [](const OutputVector& a, - const OutputVector& b) -> OutputVector { return a - b; }; - m_addFuncX = [](const StateVector& a, const StateVector& b) -> StateVector { - return a + b; - }; - m_dt = dt; - - Reset(); -} - -template -UnscentedKalmanFilter::UnscentedKalmanFilter( - std::function f, - std::function h, - const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - std::function&, - const Vectord<2 * States + 1>&)> - meanFuncX, - std::function&, - const Vectord<2 * States + 1>&)> - meanFuncY, - std::function - residualFuncX, - std::function - residualFuncY, - std::function addFuncX, - units::second_t dt) - : m_f(std::move(f)), - m_h(std::move(h)), - m_meanFuncX(std::move(meanFuncX)), - m_meanFuncY(std::move(meanFuncY)), - m_residualFuncX(std::move(residualFuncX)), - m_residualFuncY(std::move(residualFuncY)), - m_addFuncX(std::move(addFuncX)) { - m_contQ = MakeCovMatrix(stateStdDevs); - m_contR = MakeCovMatrix(measurementStdDevs); - m_dt = dt; - - Reset(); -} - -template -void UnscentedKalmanFilter::Predict( - const InputVector& u, units::second_t dt) { - m_dt = dt; - - // Discretize Q before projecting mean and covariance forward - StateMatrix contA = - NumericalJacobianX(m_f, m_xHat, u); - StateMatrix discA; - StateMatrix discQ; - DiscretizeAQ(contA, m_contQ, m_dt, &discA, &discQ); - Eigen::internal::llt_inplace::blocked(discQ); - - Matrixd sigmas = - m_pts.SquareRootSigmaPoints(m_xHat, m_S); - - for (int i = 0; i < m_pts.NumSigmas(); ++i) { - StateVector x = sigmas.template block(0, i); - m_sigmasF.template block(0, i) = RK4(m_f, x, u, dt); - } - - auto [xHat, S] = SquareRootUnscentedTransform( - m_sigmasF, m_pts.Wm(), m_pts.Wc(), m_meanFuncX, m_residualFuncX, - discQ.template triangularView()); - m_xHat = xHat; - m_S = S; -} - -template -template -void UnscentedKalmanFilter::Correct( - const InputVector& u, const Vectord& y, - std::function(const StateVector&, const InputVector&)> h, - const Matrixd& R) { - auto meanFuncY = [](const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wc) -> Vectord { - return sigmas * Wc; - }; - auto residualFuncX = [](const StateVector& a, - const StateVector& b) -> StateVector { - return a - b; - }; - auto residualFuncY = [](const Vectord& a, - const Vectord& b) -> Vectord { - return a - b; - }; - auto addFuncX = [](const StateVector& a, - const StateVector& b) -> StateVector { return a + b; }; - Correct(u, y, std::move(h), R, std::move(meanFuncY), - std::move(residualFuncY), std::move(residualFuncX), - std::move(addFuncX)); -} - -template -template -void UnscentedKalmanFilter::Correct( - const InputVector& u, const Vectord& y, - std::function(const StateVector&, const InputVector&)> h, - const Matrixd& R, - std::function(const Matrixd&, - const Vectord<2 * States + 1>&)> - meanFuncY, - std::function(const Vectord&, const Vectord&)> - residualFuncY, - std::function - residualFuncX, - std::function - addFuncX) { - Matrixd discR = DiscretizeR(R, m_dt); - Eigen::internal::llt_inplace::blocked(discR); - - // Transform sigma points into measurement space - Matrixd sigmasH; - Matrixd sigmas = - m_pts.SquareRootSigmaPoints(m_xHat, m_S); - for (int i = 0; i < m_pts.NumSigmas(); ++i) { - sigmasH.template block(0, i) = - h(sigmas.template block(0, i), u); - } - - // Mean and covariance of prediction passed through UT - auto [yHat, Sy] = SquareRootUnscentedTransform( - sigmasH, m_pts.Wm(), m_pts.Wc(), meanFuncY, residualFuncY, - discR.template triangularView()); - - // Compute cross covariance of the state and the measurements - Matrixd Pxy; - Pxy.setZero(); - for (int i = 0; i < m_pts.NumSigmas(); ++i) { - // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i] - Pxy += m_pts.Wc(i) * - (residualFuncX(m_sigmasF.template block(0, i), m_xHat)) * - (residualFuncY(sigmasH.template block(0, i), yHat)) - .transpose(); - } - - // K = (P_{xy} / S_yᵀ) / S_y - // K = (S_y \ P_{xy}ᵀ)ᵀ / S_y - // K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ - Matrixd K = - Sy.transpose() - .fullPivHouseholderQr() - .solve(Sy.fullPivHouseholderQr().solve(Pxy.transpose())) - .transpose(); - - // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ) - m_xHat = addFuncX(m_xHat, K * residualFuncY(y, yHat)); - - Matrixd U = K * Sy; - for (int i = 0; i < Rows; i++) { - Eigen::internal::llt_inplace::rankUpdate( - m_S, U.template block(0, i), -1); - } -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/geometry/Pose2d.h b/wpimath/src/main/native/include/frc/geometry/Pose2d.h index 7f8f47cafe3..3a267e8cf08 100644 --- a/wpimath/src/main/native/include/frc/geometry/Pose2d.h +++ b/wpimath/src/main/native/include/frc/geometry/Pose2d.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -14,6 +15,7 @@ #include "frc/geometry/Transform2d.h" #include "frc/geometry/Translation2d.h" #include "frc/geometry/Twist2d.h" +#include "units/length.h" namespace frc { @@ -33,7 +35,9 @@ class WPILIB_DLLEXPORT Pose2d { * @param translation The translational component of the pose. * @param rotation The rotational component of the pose. */ - constexpr Pose2d(Translation2d translation, Rotation2d rotation); + constexpr Pose2d(Translation2d translation, Rotation2d rotation) + : m_translation{std::move(translation)}, + m_rotation{std::move(rotation)} {} /** * Constructs a pose with x and y translations instead of a separate @@ -43,7 +47,8 @@ class WPILIB_DLLEXPORT Pose2d { * @param y The y component of the translational component of the pose. * @param rotation The rotational component of the pose. */ - constexpr Pose2d(units::meter_t x, units::meter_t y, Rotation2d rotation); + constexpr Pose2d(units::meter_t x, units::meter_t y, Rotation2d rotation) + : m_translation{x, y}, m_rotation{std::move(rotation)} {} /** * Transforms the pose by the given transformation and returns the new @@ -59,7 +64,9 @@ class WPILIB_DLLEXPORT Pose2d { * * @return The transformed pose. */ - constexpr Pose2d operator+(const Transform2d& other) const; + constexpr Pose2d operator+(const Transform2d& other) const { + return TransformBy(other); + } /** * Returns the Transform2d that maps the one pose to another. @@ -109,7 +116,9 @@ class WPILIB_DLLEXPORT Pose2d { * * @return The new scaled Pose2d. */ - constexpr Pose2d operator*(double scalar) const; + constexpr Pose2d operator*(double scalar) const { + return Pose2d{m_translation * scalar, m_rotation * scalar}; + } /** * Divides the current pose by a scalar. @@ -118,7 +127,9 @@ class WPILIB_DLLEXPORT Pose2d { * * @return The new scaled Pose2d. */ - constexpr Pose2d operator/(double scalar) const; + constexpr Pose2d operator/(double scalar) const { + return *this * (1.0 / scalar); + } /** * Rotates the pose around the origin and returns the new pose. @@ -127,7 +138,9 @@ class WPILIB_DLLEXPORT Pose2d { * * @return The rotated pose. */ - constexpr Pose2d RotateBy(const Rotation2d& other) const; + constexpr Pose2d RotateBy(const Rotation2d& other) const { + return {m_translation.RotateBy(other), m_rotation.RotateBy(other)}; + } /** * Transforms the pose by the given transformation and returns the new pose. @@ -137,7 +150,10 @@ class WPILIB_DLLEXPORT Pose2d { * * @return The transformed pose. */ - constexpr Pose2d TransformBy(const Transform2d& other) const; + constexpr Pose2d TransformBy(const Transform2d& other) const { + return {m_translation + (other.Translation().RotateBy(m_rotation)), + other.Rotation() + m_rotation}; + } /** * Returns the current pose relative to the given pose. @@ -217,4 +233,3 @@ void from_json(const wpi::json& json, Pose2d& pose); #include "frc/geometry/proto/Pose2dProto.h" #endif #include "frc/geometry/struct/Pose2dStruct.h" -#include "frc/geometry/Pose2d.inc" diff --git a/wpimath/src/main/native/include/frc/geometry/Pose2d.inc b/wpimath/src/main/native/include/frc/geometry/Pose2d.inc deleted file mode 100644 index 559a0034e80..00000000000 --- a/wpimath/src/main/native/include/frc/geometry/Pose2d.inc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include "frc/geometry/Pose2d.h" -#include "frc/geometry/Rotation2d.h" -#include "units/length.h" - -namespace frc { - -constexpr Pose2d::Pose2d(Translation2d translation, Rotation2d rotation) - : m_translation{std::move(translation)}, m_rotation{std::move(rotation)} {} - -constexpr Pose2d::Pose2d(units::meter_t x, units::meter_t y, - Rotation2d rotation) - : m_translation{x, y}, m_rotation{std::move(rotation)} {} - -constexpr Pose2d Pose2d::operator+(const Transform2d& other) const { - return TransformBy(other); -} - -constexpr Pose2d Pose2d::operator*(double scalar) const { - return Pose2d{m_translation * scalar, m_rotation * scalar}; -} - -constexpr Pose2d Pose2d::operator/(double scalar) const { - return *this * (1.0 / scalar); -} - -constexpr Pose2d Pose2d::RotateBy(const Rotation2d& other) const { - return {m_translation.RotateBy(other), m_rotation.RotateBy(other)}; -} - -constexpr Pose2d Pose2d::TransformBy(const Transform2d& other) const { - return {m_translation + (other.Translation().RotateBy(m_rotation)), - other.Rotation() + m_rotation}; -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/geometry/Rotation2d.h b/wpimath/src/main/native/include/frc/geometry/Rotation2d.h index 4e5357e9bb8..d35dc8330c0 100644 --- a/wpimath/src/main/native/include/frc/geometry/Rotation2d.h +++ b/wpimath/src/main/native/include/frc/geometry/Rotation2d.h @@ -4,10 +4,15 @@ #pragma once +#include + +#include +#include #include #include #include "units/angle.h" +#include "wpimath/MathShared.h" namespace frc { @@ -32,7 +37,10 @@ class WPILIB_DLLEXPORT Rotation2d { * * @param value The value of the angle. */ - constexpr Rotation2d(units::angle_unit auto value); // NOLINT + constexpr Rotation2d(units::angle_unit auto value) // NOLINT + : m_value{value}, + m_cos{gcem::cos(value.template convert().value())}, + m_sin{gcem::sin(value.template convert().value())} {} /** * Constructs a Rotation2d with the given x and y (cosine and sine) @@ -41,7 +49,22 @@ class WPILIB_DLLEXPORT Rotation2d { * @param x The x component or cosine of the rotation. * @param y The y component or sine of the rotation. */ - constexpr Rotation2d(double x, double y); + constexpr Rotation2d(double x, double y) { + double magnitude = gcem::hypot(x, y); + if (magnitude > 1e-6) { + m_sin = y / magnitude; + m_cos = x / magnitude; + } else { + m_sin = 0.0; + m_cos = 1.0; + if (!std::is_constant_evaluated()) { + wpi::math::MathSharedStore::ReportError( + "x and y components of Rotation2d are zero\n{}", + wpi::GetStackTrace(1)); + } + } + m_value = units::radian_t{gcem::atan2(m_sin, m_cos)}; + } /** * Adds two rotations together, with the result being bounded between -pi and @@ -54,7 +77,9 @@ class WPILIB_DLLEXPORT Rotation2d { * * @return The sum of the two rotations. */ - constexpr Rotation2d operator+(const Rotation2d& other) const; + constexpr Rotation2d operator+(const Rotation2d& other) const { + return RotateBy(other); + } /** * Subtracts the new rotation from the current rotation and returns the new @@ -67,7 +92,9 @@ class WPILIB_DLLEXPORT Rotation2d { * * @return The difference between the two rotations. */ - constexpr Rotation2d operator-(const Rotation2d& other) const; + constexpr Rotation2d operator-(const Rotation2d& other) const { + return *this + -other; + } /** * Takes the inverse of the current rotation. This is simply the negative of @@ -75,7 +102,7 @@ class WPILIB_DLLEXPORT Rotation2d { * * @return The inverse of the current rotation. */ - constexpr Rotation2d operator-() const; + constexpr Rotation2d operator-() const { return Rotation2d{-m_value}; } /** * Multiplies the current rotation by a scalar. @@ -84,7 +111,9 @@ class WPILIB_DLLEXPORT Rotation2d { * * @return The new scaled Rotation2d. */ - constexpr Rotation2d operator*(double scalar) const; + constexpr Rotation2d operator*(double scalar) const { + return Rotation2d{m_value * scalar}; + } /** * Divides the current rotation by a scalar. @@ -93,7 +122,9 @@ class WPILIB_DLLEXPORT Rotation2d { * * @return The new scaled Rotation2d. */ - constexpr Rotation2d operator/(double scalar) const; + constexpr Rotation2d operator/(double scalar) const { + return *this * (1.0 / scalar); + } /** * Checks equality between this Rotation2d and another object. @@ -101,7 +132,9 @@ class WPILIB_DLLEXPORT Rotation2d { * @param other The other object. * @return Whether the two objects are equal. */ - constexpr bool operator==(const Rotation2d& other) const; + constexpr bool operator==(const Rotation2d& other) const { + return gcem::hypot(Cos() - other.Cos(), Sin() - other.Sin()) < 1E-9; + } /** * Adds the new rotation to the current rotation using a rotation matrix. @@ -116,7 +149,10 @@ class WPILIB_DLLEXPORT Rotation2d { * * @return The new rotated Rotation2d. */ - constexpr Rotation2d RotateBy(const Rotation2d& other) const; + constexpr Rotation2d RotateBy(const Rotation2d& other) const { + return {Cos() * other.Cos() - Sin() * other.Sin(), + Cos() * other.Sin() + Sin() * other.Cos()}; + } /** * Returns the radian value of the rotation. @@ -173,4 +209,3 @@ void from_json(const wpi::json& json, Rotation2d& rotation); #include "frc/geometry/proto/Rotation2dProto.h" #endif #include "frc/geometry/struct/Rotation2dStruct.h" -#include "frc/geometry/Rotation2d.inc" diff --git a/wpimath/src/main/native/include/frc/geometry/Rotation2d.inc b/wpimath/src/main/native/include/frc/geometry/Rotation2d.inc deleted file mode 100644 index 740aaf0090d..00000000000 --- a/wpimath/src/main/native/include/frc/geometry/Rotation2d.inc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include - -#include -#include - -#include "frc/geometry/Rotation2d.h" -#include "units/angle.h" -#include "wpimath/MathShared.h" - -namespace frc { - -constexpr Rotation2d::Rotation2d(units::angle_unit auto value) - : m_value{value}, - m_cos{gcem::cos(value.template convert().value())}, - m_sin{gcem::sin(value.template convert().value())} {} - -constexpr Rotation2d::Rotation2d(double x, double y) { - double magnitude = gcem::hypot(x, y); - if (magnitude > 1e-6) { - m_sin = y / magnitude; - m_cos = x / magnitude; - } else { - m_sin = 0.0; - m_cos = 1.0; - if (!std::is_constant_evaluated()) { - wpi::math::MathSharedStore::ReportError( - "x and y components of Rotation2d are zero\n{}", - wpi::GetStackTrace(1)); - } - } - m_value = units::radian_t{gcem::atan2(m_sin, m_cos)}; -} - -constexpr Rotation2d Rotation2d::operator-() const { - return Rotation2d{-m_value}; -} - -constexpr Rotation2d Rotation2d::operator*(double scalar) const { - return Rotation2d{m_value * scalar}; -} - -constexpr Rotation2d Rotation2d::operator+(const Rotation2d& other) const { - return RotateBy(other); -} - -constexpr Rotation2d Rotation2d::operator-(const Rotation2d& other) const { - return *this + -other; -} - -constexpr Rotation2d Rotation2d::operator/(double scalar) const { - return *this * (1.0 / scalar); -} - -constexpr bool Rotation2d::operator==(const Rotation2d& other) const { - return gcem::hypot(Cos() - other.Cos(), Sin() - other.Sin()) < 1E-9; -} - -constexpr Rotation2d Rotation2d::RotateBy(const Rotation2d& other) const { - return {Cos() * other.Cos() - Sin() * other.Sin(), - Cos() * other.Sin() + Sin() * other.Cos()}; -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/geometry/Transform2d.h b/wpimath/src/main/native/include/frc/geometry/Transform2d.h index af984eadb4a..35271056edb 100644 --- a/wpimath/src/main/native/include/frc/geometry/Transform2d.h +++ b/wpimath/src/main/native/include/frc/geometry/Transform2d.h @@ -4,8 +4,11 @@ #pragma once +#include + #include +#include "frc/geometry/Rotation2d.h" #include "frc/geometry/Translation2d.h" namespace frc { @@ -31,7 +34,9 @@ class WPILIB_DLLEXPORT Transform2d { * @param translation Translational component of the transform. * @param rotation Rotational component of the transform. */ - constexpr Transform2d(Translation2d translation, Rotation2d rotation); + constexpr Transform2d(Translation2d translation, Rotation2d rotation) + : m_translation{std::move(translation)}, + m_rotation{std::move(rotation)} {} /** * Constructs a transform with x and y translations instead of a separate @@ -41,8 +46,8 @@ class WPILIB_DLLEXPORT Transform2d { * @param y The y component of the translational component of the transform. * @param rotation The rotational component of the transform. */ - constexpr Transform2d(units::meter_t x, units::meter_t y, - Rotation2d rotation); + constexpr Transform2d(units::meter_t x, units::meter_t y, Rotation2d rotation) + : m_translation{x, y}, m_rotation{std::move(rotation)} {} /** * Constructs the identity transform -- maps an initial pose to itself. @@ -82,7 +87,12 @@ class WPILIB_DLLEXPORT Transform2d { * * @return The inverted transformation. */ - constexpr Transform2d Inverse() const; + constexpr Transform2d Inverse() const { + // We are rotating the difference between the translations + // using a clockwise rotation matrix. This transforms the global + // delta into a local delta (relative to the initial pose). + return Transform2d{(-Translation()).RotateBy(-Rotation()), -Rotation()}; + } /** * Multiplies the transform by the scalar. @@ -128,4 +138,3 @@ class WPILIB_DLLEXPORT Transform2d { #include "frc/geometry/proto/Transform2dProto.h" #endif #include "frc/geometry/struct/Transform2dStruct.h" -#include "frc/geometry/Transform2d.inc" diff --git a/wpimath/src/main/native/include/frc/geometry/Transform2d.inc b/wpimath/src/main/native/include/frc/geometry/Transform2d.inc deleted file mode 100644 index cc925148d45..00000000000 --- a/wpimath/src/main/native/include/frc/geometry/Transform2d.inc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include "frc/geometry/Rotation2d.h" -#include "frc/geometry/Transform2d.h" -#include "frc/geometry/Translation2d.h" - -namespace frc { - -constexpr Transform2d::Transform2d(Translation2d translation, - Rotation2d rotation) - : m_translation{std::move(translation)}, m_rotation{std::move(rotation)} {} - -constexpr Transform2d::Transform2d(units::meter_t x, units::meter_t y, - Rotation2d rotation) - : m_translation{x, y}, m_rotation{std::move(rotation)} {} - -constexpr Transform2d Transform2d::Inverse() const { - // We are rotating the difference between the translations - // using a clockwise rotation matrix. This transforms the global - // delta into a local delta (relative to the initial pose). - return Transform2d{(-Translation()).RotateBy(-Rotation()), -Rotation()}; -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/geometry/Translation2d.h b/wpimath/src/main/native/include/frc/geometry/Translation2d.h index 5c91e9d8fde..2a82a69c57f 100644 --- a/wpimath/src/main/native/include/frc/geometry/Translation2d.h +++ b/wpimath/src/main/native/include/frc/geometry/Translation2d.h @@ -39,7 +39,8 @@ class WPILIB_DLLEXPORT Translation2d { * @param x The x component of the translation. * @param y The y component of the translation. */ - constexpr Translation2d(units::meter_t x, units::meter_t y); + constexpr Translation2d(units::meter_t x, units::meter_t y) + : m_x{x}, m_y{y} {} /** * Constructs a Translation2d with the provided distance and angle. This is @@ -48,7 +49,8 @@ class WPILIB_DLLEXPORT Translation2d { * @param distance The distance from the origin to the end of the translation. * @param angle The angle between the x-axis and the translation vector. */ - constexpr Translation2d(units::meter_t distance, const Rotation2d& angle); + constexpr Translation2d(units::meter_t distance, const Rotation2d& angle) + : m_x{distance * angle.Cos()}, m_y{distance * angle.Sin()} {} /** * Constructs a Translation2d from the provided translation vector's X and Y @@ -90,7 +92,9 @@ class WPILIB_DLLEXPORT Translation2d { * * @return A Vector representation of this translation. */ - constexpr Eigen::Vector2d ToVector() const; + constexpr Eigen::Vector2d ToVector() const { + return Eigen::Vector2d{{m_x.value(), m_y.value()}}; + } /** * Returns the norm, or distance from the origin to the translation. @@ -104,7 +108,9 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The angle of the translation */ - constexpr Rotation2d Angle() const; + constexpr Rotation2d Angle() const { + return Rotation2d{m_x.value(), m_y.value()}; + } /** * Applies a rotation to the translation in 2D space. @@ -124,7 +130,10 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The new rotated translation. */ - constexpr Translation2d RotateBy(const Rotation2d& other) const; + constexpr Translation2d RotateBy(const Rotation2d& other) const { + return {m_x * other.Cos() - m_y * other.Sin(), + m_x * other.Sin() + m_y * other.Cos()}; + } /** * Rotates this translation around another translation in 2D space. @@ -139,7 +148,12 @@ class WPILIB_DLLEXPORT Translation2d { * @return The new rotated translation. */ constexpr Translation2d RotateAround(const Translation2d& other, - const Rotation2d& rot) const; + const Rotation2d& rot) const { + return {(m_x - other.X()) * rot.Cos() - (m_y - other.Y()) * rot.Sin() + + other.X(), + (m_x - other.X()) * rot.Sin() + (m_y - other.Y()) * rot.Cos() + + other.Y()}; + } /** * Returns the sum of two translations in 2D space. @@ -151,7 +165,9 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The sum of the translations. */ - constexpr Translation2d operator+(const Translation2d& other) const; + constexpr Translation2d operator+(const Translation2d& other) const { + return {X() + other.X(), Y() + other.Y()}; + } /** * Returns the difference between two translations. @@ -163,7 +179,9 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The difference between the two translations. */ - constexpr Translation2d operator-(const Translation2d& other) const; + constexpr Translation2d operator-(const Translation2d& other) const { + return *this + -other; + } /** * Returns the inverse of the current translation. This is equivalent to @@ -172,7 +190,7 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The inverse of the current translation. */ - constexpr Translation2d operator-() const; + constexpr Translation2d operator-() const { return {-m_x, -m_y}; } /** * Returns the translation multiplied by a scalar. @@ -183,7 +201,9 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The scaled translation. */ - constexpr Translation2d operator*(double scalar) const; + constexpr Translation2d operator*(double scalar) const { + return {scalar * m_x, scalar * m_y}; + } /** * Returns the translation divided by a scalar. @@ -194,7 +214,9 @@ class WPILIB_DLLEXPORT Translation2d { * * @return The scaled translation. */ - constexpr Translation2d operator/(double scalar) const; + constexpr Translation2d operator/(double scalar) const { + return operator*(1.0 / scalar); + } /** * Checks equality between this Translation2d and another object. @@ -202,7 +224,10 @@ class WPILIB_DLLEXPORT Translation2d { * @param other The other object. * @return Whether the two objects are equal. */ - constexpr bool operator==(const Translation2d& other) const; + constexpr bool operator==(const Translation2d& other) const { + return units::math::abs(m_x - other.m_x) < 1E-9_m && + units::math::abs(m_y - other.m_y) < 1E-9_m; + } /** * Returns the nearest Translation2d from a collection of translations @@ -236,4 +261,3 @@ void from_json(const wpi::json& json, Translation2d& state); #include "frc/geometry/proto/Translation2dProto.h" #endif #include "frc/geometry/struct/Translation2dStruct.h" -#include "frc/geometry/Translation2d.inc" diff --git a/wpimath/src/main/native/include/frc/geometry/Translation2d.inc b/wpimath/src/main/native/include/frc/geometry/Translation2d.inc deleted file mode 100644 index 2d1160abb4b..00000000000 --- a/wpimath/src/main/native/include/frc/geometry/Translation2d.inc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/geometry/Translation2d.h" -#include "units/length.h" -#include "units/math.h" - -namespace frc { - -constexpr Translation2d::Translation2d(units::meter_t x, units::meter_t y) - : m_x{x}, m_y{y} {} - -constexpr Translation2d::Translation2d(units::meter_t distance, - const Rotation2d& angle) - : m_x{distance * angle.Cos()}, m_y{distance * angle.Sin()} {} - -constexpr Eigen::Vector2d Translation2d::ToVector() const { - return Eigen::Vector2d{{m_x.value(), m_y.value()}}; -} - -constexpr Rotation2d Translation2d::Angle() const { - return Rotation2d{m_x.value(), m_y.value()}; -} - -constexpr Translation2d Translation2d::RotateBy(const Rotation2d& other) const { - return {m_x * other.Cos() - m_y * other.Sin(), - m_x * other.Sin() + m_y * other.Cos()}; -} - -constexpr Translation2d Translation2d::RotateAround( - const Translation2d& other, const Rotation2d& rot) const { - return { - (m_x - other.X()) * rot.Cos() - (m_y - other.Y()) * rot.Sin() + other.X(), - (m_x - other.X()) * rot.Sin() + (m_y - other.Y()) * rot.Cos() + - other.Y()}; -} - -constexpr Translation2d Translation2d::operator+( - const Translation2d& other) const { - return {X() + other.X(), Y() + other.Y()}; -} - -constexpr Translation2d Translation2d::operator-( - const Translation2d& other) const { - return *this + -other; -} - -constexpr Translation2d Translation2d::operator-() const { - return {-m_x, -m_y}; -} - -constexpr Translation2d Translation2d::operator*(double scalar) const { - return {scalar * m_x, scalar * m_y}; -} - -constexpr Translation2d Translation2d::operator/(double scalar) const { - return operator*(1.0 / scalar); -} - -constexpr bool Translation2d::operator==(const Translation2d& other) const { - return units::math::abs(m_x - other.m_x) < 1E-9_m && - units::math::abs(m_y - other.m_y) < 1E-9_m; -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/geometry/Translation3d.h b/wpimath/src/main/native/include/frc/geometry/Translation3d.h index 75e5bf3dcf3..9aa3f61c31e 100644 --- a/wpimath/src/main/native/include/frc/geometry/Translation3d.h +++ b/wpimath/src/main/native/include/frc/geometry/Translation3d.h @@ -11,6 +11,7 @@ #include "frc/geometry/Rotation3d.h" #include "frc/geometry/Translation2d.h" #include "units/length.h" +#include "units/math.h" namespace frc { @@ -37,7 +38,8 @@ class WPILIB_DLLEXPORT Translation3d { * @param y The y component of the translation. * @param z The z component of the translation. */ - constexpr Translation3d(units::meter_t x, units::meter_t y, units::meter_t z); + constexpr Translation3d(units::meter_t x, units::meter_t y, units::meter_t z) + : m_x{x}, m_y{y}, m_z{z} {} /** * Constructs a Translation3d with the provided distance and angle. This is @@ -46,7 +48,12 @@ class WPILIB_DLLEXPORT Translation3d { * @param distance The distance from the origin to the end of the translation. * @param angle The angle between the x-axis and the translation vector. */ - Translation3d(units::meter_t distance, const Rotation3d& angle); + Translation3d(units::meter_t distance, const Rotation3d& angle) { + auto rectangular = Translation3d{distance, 0_m, 0_m}.RotateBy(angle); + m_x = rectangular.X(); + m_y = rectangular.Y(); + m_z = rectangular.Z(); + } /** * Constructs a Translation3d from the provided translation vector's X, Y, and @@ -54,7 +61,10 @@ class WPILIB_DLLEXPORT Translation3d { * * @param vector The translation vector to represent. */ - explicit Translation3d(const Eigen::Vector3d& vector); + explicit Translation3d(const Eigen::Vector3d& vector) + : m_x{units::meter_t{vector.x()}}, + m_y{units::meter_t{vector.y()}}, + m_z{units::meter_t{vector.z()}} {} /** * Calculates the distance between two translations in 3D space. @@ -66,7 +76,11 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The distance between the two translations. */ - units::meter_t Distance(const Translation3d& other) const; + units::meter_t Distance(const Translation3d& other) const { + return units::math::sqrt(units::math::pow<2>(other.m_x - m_x) + + units::math::pow<2>(other.m_y - m_y) + + units::math::pow<2>(other.m_z - m_z)); + } /** * Returns the X component of the translation. @@ -94,14 +108,18 @@ class WPILIB_DLLEXPORT Translation3d { * * @return A Vector representation of this translation. */ - constexpr Eigen::Vector3d ToVector() const; + constexpr Eigen::Vector3d ToVector() const { + return Eigen::Vector3d{{m_x.value(), m_y.value(), m_z.value()}}; + } /** * Returns the norm, or distance from the origin to the translation. * * @return The norm of the translation. */ - units::meter_t Norm() const; + units::meter_t Norm() const { + return units::math::sqrt(m_x * m_x + m_y * m_y + m_z * m_z); + } /** * Applies a rotation to the translation in 3D space. @@ -113,13 +131,20 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The new rotated translation. */ - Translation3d RotateBy(const Rotation3d& other) const; + Translation3d RotateBy(const Rotation3d& other) const { + Quaternion p{0.0, m_x.value(), m_y.value(), m_z.value()}; + auto qprime = other.GetQuaternion() * p * other.GetQuaternion().Inverse(); + return Translation3d{units::meter_t{qprime.X()}, units::meter_t{qprime.Y()}, + units::meter_t{qprime.Z()}}; + } /** * Returns a Translation2d representing this Translation3d projected into the * X-Y plane. */ - constexpr Translation2d ToTranslation2d() const; + constexpr Translation2d ToTranslation2d() const { + return Translation2d{m_x, m_y}; + } /** * Returns the sum of two translations in 3D space. @@ -131,7 +156,9 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The sum of the translations. */ - constexpr Translation3d operator+(const Translation3d& other) const; + constexpr Translation3d operator+(const Translation3d& other) const { + return {X() + other.X(), Y() + other.Y(), Z() + other.Z()}; + } /** * Returns the difference between two translations. @@ -143,7 +170,9 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The difference between the two translations. */ - constexpr Translation3d operator-(const Translation3d& other) const; + constexpr Translation3d operator-(const Translation3d& other) const { + return operator+(-other); + } /** * Returns the inverse of the current translation. This is equivalent to @@ -151,7 +180,7 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The inverse of the current translation. */ - constexpr Translation3d operator-() const; + constexpr Translation3d operator-() const { return {-m_x, -m_y, -m_z}; } /** * Returns the translation multiplied by a scalar. @@ -163,7 +192,9 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The scaled translation. */ - constexpr Translation3d operator*(double scalar) const; + constexpr Translation3d operator*(double scalar) const { + return {scalar * m_x, scalar * m_y, scalar * m_z}; + } /** * Returns the translation divided by a scalar. @@ -175,7 +206,9 @@ class WPILIB_DLLEXPORT Translation3d { * * @return The scaled translation. */ - constexpr Translation3d operator/(double scalar) const; + constexpr Translation3d operator/(double scalar) const { + return operator*(1.0 / scalar); + } /** * Checks equality between this Translation3d and another object. @@ -183,7 +216,11 @@ class WPILIB_DLLEXPORT Translation3d { * @param other The other object. * @return Whether the two objects are equal. */ - bool operator==(const Translation3d& other) const; + constexpr bool operator==(const Translation3d& other) const { + return units::math::abs(m_x - other.m_x) < 1E-9_m && + units::math::abs(m_y - other.m_y) < 1E-9_m && + units::math::abs(m_z - other.m_z) < 1E-9_m; + } private: units::meter_t m_x = 0_m; @@ -203,4 +240,3 @@ void from_json(const wpi::json& json, Translation3d& state); #include "frc/geometry/proto/Translation3dProto.h" #endif #include "frc/geometry/struct/Translation3dStruct.h" -#include "frc/geometry/Translation3d.inc" diff --git a/wpimath/src/main/native/include/frc/geometry/Translation3d.inc b/wpimath/src/main/native/include/frc/geometry/Translation3d.inc deleted file mode 100644 index 19268e6cee2..00000000000 --- a/wpimath/src/main/native/include/frc/geometry/Translation3d.inc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/geometry/Translation2d.h" -#include "frc/geometry/Translation3d.h" -#include "units/length.h" -#include "units/math.h" - -namespace frc { - -constexpr Translation3d::Translation3d(units::meter_t x, units::meter_t y, - units::meter_t z) - : m_x{x}, m_y{y}, m_z{z} {} - -constexpr Translation2d Translation3d::ToTranslation2d() const { - return Translation2d{m_x, m_y}; -} - -constexpr Eigen::Vector3d Translation3d::ToVector() const { - return Eigen::Vector3d{{m_x.value(), m_y.value(), m_z.value()}}; -} - -constexpr Translation3d Translation3d::operator+( - const Translation3d& other) const { - return {X() + other.X(), Y() + other.Y(), Z() + other.Z()}; -} - -constexpr Translation3d Translation3d::operator-( - const Translation3d& other) const { - return operator+(-other); -} - -constexpr Translation3d Translation3d::operator-() const { - return {-m_x, -m_y, -m_z}; -} - -constexpr Translation3d Translation3d::operator*(double scalar) const { - return {scalar * m_x, scalar * m_y, scalar * m_z}; -} - -constexpr Translation3d Translation3d::operator/(double scalar) const { - return operator*(1.0 / scalar); -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/kinematics/Odometry.h b/wpimath/src/main/native/include/frc/kinematics/Odometry.h index e5a8e0f618e..c73eed22482 100644 --- a/wpimath/src/main/native/include/frc/kinematics/Odometry.h +++ b/wpimath/src/main/native/include/frc/kinematics/Odometry.h @@ -12,6 +12,7 @@ #include "frc/kinematics/Kinematics.h" namespace frc { + /** * Class for odometry. Robot code should not use this directly- Instead, use the * particular type for your drivetrain (e.g., DifferentialDriveOdometry). @@ -39,7 +40,13 @@ class WPILIB_DLLEXPORT Odometry { explicit Odometry(const Kinematics& kinematics, const Rotation2d& gyroAngle, const WheelPositions& wheelPositions, - const Pose2d& initialPose = Pose2d{}); + const Pose2d& initialPose = Pose2d{}) + : m_kinematics(kinematics), + m_pose(initialPose), + m_previousWheelPositions(wheelPositions) { + m_previousAngle = m_pose.Rotation(); + m_gyroOffset = m_pose.Rotation() - gyroAngle; + } /** * Resets the robot's position on the field. @@ -108,7 +115,21 @@ class WPILIB_DLLEXPORT Odometry { * @return The new pose of the robot. */ const Pose2d& Update(const Rotation2d& gyroAngle, - const WheelPositions& wheelPositions); + const WheelPositions& wheelPositions) { + auto angle = gyroAngle + m_gyroOffset; + + auto twist = + m_kinematics.ToTwist2d(m_previousWheelPositions, wheelPositions); + twist.dtheta = (angle - m_previousAngle).Radians(); + + auto newPose = m_pose.Exp(twist); + + m_previousAngle = angle; + m_previousWheelPositions = wheelPositions; + m_pose = {newPose.Translation(), angle}; + + return m_pose; + } private: const Kinematics& m_kinematics; @@ -118,6 +139,5 @@ class WPILIB_DLLEXPORT Odometry { Rotation2d m_previousAngle; Rotation2d m_gyroOffset; }; -} // namespace frc -#include "Odometry.inc" +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/kinematics/Odometry.inc b/wpimath/src/main/native/include/frc/kinematics/Odometry.inc deleted file mode 100644 index 384a2ddebc6..00000000000 --- a/wpimath/src/main/native/include/frc/kinematics/Odometry.inc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/kinematics/Odometry.h" - -namespace frc { -template -Odometry::Odometry( - const Kinematics& kinematics, - const Rotation2d& gyroAngle, const WheelPositions& wheelPositions, - const Pose2d& initialPose) - : m_kinematics(kinematics), - m_pose(initialPose), - m_previousWheelPositions(wheelPositions) { - m_previousAngle = m_pose.Rotation(); - m_gyroOffset = m_pose.Rotation() - gyroAngle; -} - -template -const Pose2d& Odometry::Update( - const Rotation2d& gyroAngle, const WheelPositions& wheelPositions) { - auto angle = gyroAngle + m_gyroOffset; - - auto twist = m_kinematics.ToTwist2d(m_previousWheelPositions, wheelPositions); - twist.dtheta = (angle - m_previousAngle).Radians(); - - auto newPose = m_pose.Exp(twist); - - m_previousAngle = angle; - m_previousWheelPositions = wheelPositions; - m_pose = {newPose.Translation(), angle}; - - return m_pose; -} -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h index 98eea68e306..aecc3306d55 100644 --- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h +++ b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include @@ -19,6 +20,7 @@ #include "frc/kinematics/Kinematics.h" #include "frc/kinematics/SwerveModulePosition.h" #include "frc/kinematics/SwerveModuleState.h" +#include "units/math.h" #include "units/velocity.h" #include "wpimath/MathShared.h" @@ -116,7 +118,11 @@ class SwerveDriveKinematics * @param moduleHeadings The swerve module headings. The order of the module * headings should be same as passed into the constructor of this class. */ - void ResetHeadings(wpi::array moduleHeadings); + void ResetHeadings(wpi::array moduleHeadings) { + for (size_t i = 0; i < NumModules; i++) { + m_moduleHeadings[i] = moduleHeadings[i]; + } + } /** * Performs inverse kinematics to return the module states from a desired @@ -151,7 +157,52 @@ class SwerveDriveKinematics */ wpi::array ToSwerveModuleStates( const ChassisSpeeds& chassisSpeeds, - const Translation2d& centerOfRotation = Translation2d{}) const; + const Translation2d& centerOfRotation = Translation2d{}) const { + wpi::array moduleStates(wpi::empty_array); + + if (chassisSpeeds.vx == 0_mps && chassisSpeeds.vy == 0_mps && + chassisSpeeds.omega == 0_rad_per_s) { + for (size_t i = 0; i < NumModules; i++) { + moduleStates[i] = {0_mps, m_moduleHeadings[i]}; + } + + return moduleStates; + } + + // We have a new center of rotation. We need to compute the matrix again. + if (centerOfRotation != m_previousCoR) { + for (size_t i = 0; i < NumModules; i++) { + // clang-format off + m_inverseKinematics.template block<2, 3>(i * 2, 0) = + Matrixd<2, 3>{ + {1, 0, (-m_modules[i].Y() + centerOfRotation.Y()).value()}, + {0, 1, (+m_modules[i].X() - centerOfRotation.X()).value()}}; + // clang-format on + } + m_previousCoR = centerOfRotation; + } + + Eigen::Vector3d chassisSpeedsVector{chassisSpeeds.vx.value(), + chassisSpeeds.vy.value(), + chassisSpeeds.omega.value()}; + + Matrixd moduleStateMatrix = + m_inverseKinematics * chassisSpeedsVector; + + for (size_t i = 0; i < NumModules; i++) { + units::meters_per_second_t x{moduleStateMatrix(i * 2, 0)}; + units::meters_per_second_t y{moduleStateMatrix(i * 2 + 1, 0)}; + + auto speed = units::math::hypot(x, y); + auto rotation = speed > 1e-6_mps ? Rotation2d{x.value(), y.value()} + : m_moduleHeadings[i]; + + moduleStates[i] = {speed, rotation}; + m_moduleHeadings[i] = rotation; + } + + return moduleStates; + } wpi::array ToWheelSpeeds( const ChassisSpeeds& chassisSpeeds) const override { @@ -191,7 +242,23 @@ class SwerveDriveKinematics * @return The resulting chassis speed. */ ChassisSpeeds ToChassisSpeeds(const wpi::array& - moduleStates) const override; + moduleStates) const override { + Matrixd moduleStateMatrix; + + for (size_t i = 0; i < NumModules; ++i) { + SwerveModuleState module = moduleStates[i]; + moduleStateMatrix(i * 2, 0) = module.speed.value() * module.angle.Cos(); + moduleStateMatrix(i * 2 + 1, 0) = + module.speed.value() * module.angle.Sin(); + } + + Eigen::Vector3d chassisSpeedsVector = + m_forwardKinematics.solve(moduleStateMatrix); + + return {units::meters_per_second_t{chassisSpeedsVector(0)}, + units::meters_per_second_t{chassisSpeedsVector(1)}, + units::radians_per_second_t{chassisSpeedsVector(2)}}; + } /** * Performs forward kinematics to return the resulting Twist2d from the @@ -227,7 +294,24 @@ class SwerveDriveKinematics * @return The resulting Twist2d. */ Twist2d ToTwist2d( - wpi::array moduleDeltas) const; + wpi::array moduleDeltas) const { + Matrixd moduleDeltaMatrix; + + for (size_t i = 0; i < NumModules; ++i) { + SwerveModulePosition module = moduleDeltas[i]; + moduleDeltaMatrix(i * 2, 0) = + module.distance.value() * module.angle.Cos(); + moduleDeltaMatrix(i * 2 + 1, 0) = + module.distance.value() * module.angle.Sin(); + } + + Eigen::Vector3d chassisDeltaVector = + m_forwardKinematics.solve(moduleDeltaMatrix); + + return {units::meter_t{chassisDeltaVector(0)}, + units::meter_t{chassisDeltaVector(1)}, + units::radian_t{chassisDeltaVector(2)}}; + } Twist2d ToTwist2d( const wpi::array& start, @@ -257,7 +341,22 @@ class SwerveDriveKinematics */ static void DesaturateWheelSpeeds( wpi::array* moduleStates, - units::meters_per_second_t attainableMaxSpeed); + units::meters_per_second_t attainableMaxSpeed) { + auto& states = *moduleStates; + auto realMaxSpeed = + units::math::abs(std::max_element(states.begin(), states.end(), + [](const auto& a, const auto& b) { + return units::math::abs(a.speed) < + units::math::abs(b.speed); + }) + ->speed); + + if (realMaxSpeed > attainableMaxSpeed) { + for (auto& module : states) { + module.speed = module.speed / realMaxSpeed * attainableMaxSpeed; + } + } + } /** * Renormalizes the wheel speeds if any individual speed is above the @@ -285,7 +384,38 @@ class SwerveDriveKinematics ChassisSpeeds desiredChassisSpeed, units::meters_per_second_t attainableMaxModuleSpeed, units::meters_per_second_t attainableMaxRobotTranslationSpeed, - units::radians_per_second_t attainableMaxRobotRotationSpeed); + units::radians_per_second_t attainableMaxRobotRotationSpeed) { + auto& states = *moduleStates; + + auto realMaxSpeed = + units::math::abs(std::max_element(states.begin(), states.end(), + [](const auto& a, const auto& b) { + return units::math::abs(a.speed) < + units::math::abs(b.speed); + }) + ->speed); + + if (attainableMaxRobotTranslationSpeed == 0_mps || + attainableMaxRobotRotationSpeed == 0_rad_per_s || + realMaxSpeed == 0_mps) { + return; + } + + auto translationalK = + units::math::hypot(desiredChassisSpeed.vx, desiredChassisSpeed.vy) / + attainableMaxRobotTranslationSpeed; + + auto rotationalK = units::math::abs(desiredChassisSpeed.omega) / + attainableMaxRobotRotationSpeed; + + auto k = units::math::max(translationalK, rotationalK); + + auto scale = units::math::min(k * attainableMaxModuleSpeed / realMaxSpeed, + units::scalar_t{1}); + for (auto& module : states) { + module.speed = module.speed * scale; + } + } wpi::array Interpolate( const wpi::array& start, @@ -312,9 +442,11 @@ class SwerveDriveKinematics mutable Translation2d m_previousCoR; }; +template +SwerveDriveKinematics(ModuleTranslation, ModuleTranslations...) + -> SwerveDriveKinematics<1 + sizeof...(ModuleTranslations)>; + extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) SwerveDriveKinematics<4>; } // namespace frc - -#include "SwerveDriveKinematics.inc" diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc deleted file mode 100644 index 8c5e18b3592..00000000000 --- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include -#include - -#include "frc/kinematics/ChassisSpeeds.h" -#include "frc/kinematics/SwerveDriveKinematics.h" -#include "units/math.h" - -namespace frc { - -template -SwerveDriveKinematics(ModuleTranslation, ModuleTranslations...) - -> SwerveDriveKinematics<1 + sizeof...(ModuleTranslations)>; - -template -void SwerveDriveKinematics::ResetHeadings( - wpi::array moduleHeadings) { - for (size_t i = 0; i < NumModules; i++) { - m_moduleHeadings[i] = moduleHeadings[i]; - } -} - -template -wpi::array -SwerveDriveKinematics::ToSwerveModuleStates( - const ChassisSpeeds& chassisSpeeds, - const Translation2d& centerOfRotation) const { - wpi::array moduleStates(wpi::empty_array); - - if (chassisSpeeds.vx == 0_mps && chassisSpeeds.vy == 0_mps && - chassisSpeeds.omega == 0_rad_per_s) { - for (size_t i = 0; i < NumModules; i++) { - moduleStates[i] = {0_mps, m_moduleHeadings[i]}; - } - - return moduleStates; - } - - // We have a new center of rotation. We need to compute the matrix again. - if (centerOfRotation != m_previousCoR) { - for (size_t i = 0; i < NumModules; i++) { - // clang-format off - m_inverseKinematics.template block<2, 3>(i * 2, 0) = - Matrixd<2, 3>{ - {1, 0, (-m_modules[i].Y() + centerOfRotation.Y()).value()}, - {0, 1, (+m_modules[i].X() - centerOfRotation.X()).value()}}; - // clang-format on - } - m_previousCoR = centerOfRotation; - } - - Eigen::Vector3d chassisSpeedsVector{chassisSpeeds.vx.value(), - chassisSpeeds.vy.value(), - chassisSpeeds.omega.value()}; - - Matrixd moduleStateMatrix = - m_inverseKinematics * chassisSpeedsVector; - - for (size_t i = 0; i < NumModules; i++) { - units::meters_per_second_t x{moduleStateMatrix(i * 2, 0)}; - units::meters_per_second_t y{moduleStateMatrix(i * 2 + 1, 0)}; - - auto speed = units::math::hypot(x, y); - auto rotation = speed > 1e-6_mps ? Rotation2d{x.value(), y.value()} - : m_moduleHeadings[i]; - - moduleStates[i] = {speed, rotation}; - m_moduleHeadings[i] = rotation; - } - - return moduleStates; -} - -template -ChassisSpeeds SwerveDriveKinematics::ToChassisSpeeds( - const wpi::array& moduleStates) const { - Matrixd moduleStateMatrix; - - for (size_t i = 0; i < NumModules; ++i) { - SwerveModuleState module = moduleStates[i]; - moduleStateMatrix(i * 2, 0) = module.speed.value() * module.angle.Cos(); - moduleStateMatrix(i * 2 + 1, 0) = module.speed.value() * module.angle.Sin(); - } - - Eigen::Vector3d chassisSpeedsVector = - m_forwardKinematics.solve(moduleStateMatrix); - - return {units::meters_per_second_t{chassisSpeedsVector(0)}, - units::meters_per_second_t{chassisSpeedsVector(1)}, - units::radians_per_second_t{chassisSpeedsVector(2)}}; -} - -template -Twist2d SwerveDriveKinematics::ToTwist2d( - wpi::array moduleDeltas) const { - Matrixd moduleDeltaMatrix; - - for (size_t i = 0; i < NumModules; ++i) { - SwerveModulePosition module = moduleDeltas[i]; - moduleDeltaMatrix(i * 2, 0) = module.distance.value() * module.angle.Cos(); - moduleDeltaMatrix(i * 2 + 1, 0) = - module.distance.value() * module.angle.Sin(); - } - - Eigen::Vector3d chassisDeltaVector = - m_forwardKinematics.solve(moduleDeltaMatrix); - - return {units::meter_t{chassisDeltaVector(0)}, - units::meter_t{chassisDeltaVector(1)}, - units::radian_t{chassisDeltaVector(2)}}; -} - -template -void SwerveDriveKinematics::DesaturateWheelSpeeds( - wpi::array* moduleStates, - units::meters_per_second_t attainableMaxSpeed) { - auto& states = *moduleStates; - auto realMaxSpeed = - units::math::abs(std::max_element(states.begin(), states.end(), - [](const auto& a, const auto& b) { - return units::math::abs(a.speed) < - units::math::abs(b.speed); - }) - ->speed); - - if (realMaxSpeed > attainableMaxSpeed) { - for (auto& module : states) { - module.speed = module.speed / realMaxSpeed * attainableMaxSpeed; - } - } -} - -template -void SwerveDriveKinematics::DesaturateWheelSpeeds( - wpi::array* moduleStates, - ChassisSpeeds desiredChassisSpeed, - units::meters_per_second_t attainableMaxModuleSpeed, - units::meters_per_second_t attainableMaxRobotTranslationSpeed, - units::radians_per_second_t attainableMaxRobotRotationSpeed) { - auto& states = *moduleStates; - - auto realMaxSpeed = - units::math::abs(std::max_element(states.begin(), states.end(), - [](const auto& a, const auto& b) { - return units::math::abs(a.speed) < - units::math::abs(b.speed); - }) - ->speed); - - if (attainableMaxRobotTranslationSpeed == 0_mps || - attainableMaxRobotRotationSpeed == 0_rad_per_s || realMaxSpeed == 0_mps) { - return; - } - - auto translationalK = - units::math::hypot(desiredChassisSpeed.vx, desiredChassisSpeed.vy) / - attainableMaxRobotTranslationSpeed; - - auto rotationalK = units::math::abs(desiredChassisSpeed.omega) / - attainableMaxRobotRotationSpeed; - - auto k = units::math::max(translationalK, rotationalK); - - auto scale = units::math::min(k * attainableMaxModuleSpeed / realMaxSpeed, - units::scalar_t{1}); - for (auto& module : states) { - module.speed = module.speed * scale; - } -} - -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h index 7c66977bb98..c00acd2e0da 100644 --- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h +++ b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h @@ -4,7 +4,6 @@ #pragma once -#include #include #include @@ -16,7 +15,7 @@ #include "SwerveModulePosition.h" #include "SwerveModuleState.h" #include "frc/geometry/Pose2d.h" -#include "units/time.h" +#include "wpimath/MathShared.h" namespace frc { @@ -45,7 +44,14 @@ class SwerveDriveOdometry SwerveDriveOdometry( SwerveDriveKinematics kinematics, const Rotation2d& gyroAngle, const wpi::array& modulePositions, - const Pose2d& initialPose = Pose2d{}); + const Pose2d& initialPose = Pose2d{}) + : Odometry, + wpi::array>( + m_kinematicsImpl, gyroAngle, modulePositions, initialPose), + m_kinematicsImpl(kinematics) { + wpi::math::MathSharedStore::ReportUsage( + wpi::math::MathUsageId::kOdometry_SwerveDrive, 1); + } private: SwerveDriveKinematics m_kinematicsImpl; @@ -55,5 +61,3 @@ extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) SwerveDriveOdometry<4>; } // namespace frc - -#include "SwerveDriveOdometry.inc" diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc deleted file mode 100644 index 2b047c21944..00000000000 --- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/kinematics/SwerveDriveOdometry.h" -#include "wpimath/MathShared.h" - -namespace frc { -template -SwerveDriveOdometry::SwerveDriveOdometry( - SwerveDriveKinematics kinematics, const Rotation2d& gyroAngle, - const wpi::array& modulePositions, - const Pose2d& initialPose) - : Odometry, - wpi::array>( - m_kinematicsImpl, gyroAngle, modulePositions, initialPose), - m_kinematicsImpl(kinematics) { - wpi::math::MathSharedStore::ReportUsage( - wpi::math::MathUsageId::kOdometry_SwerveDrive, 1); -} -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.h b/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.h index b84343459d9..d9fc19a2004 100644 --- a/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.h +++ b/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.h @@ -4,17 +4,39 @@ #pragma once +#include + +#include +#include #include #include "frc/kinematics/SwerveDriveKinematics.h" +#include "kinematics.pb.h" template struct wpi::Protobuf> { - static google::protobuf::Message* New(google::protobuf::Arena* arena); + static google::protobuf::Message* New(google::protobuf::Arena* arena) { + return wpi::CreateMessage(arena); + } + static frc::SwerveDriveKinematics Unpack( - const google::protobuf::Message& msg); + const google::protobuf::Message& msg) { + auto m = + static_cast(&msg); + if (m->modules_size() != NumModules) { + throw std::invalid_argument(fmt::format( + "Tried to unpack message with {} elements in modules into " + "SwerveDriveKinematics with {} modules", + m->modules_size(), NumModules)); + } + return frc::SwerveDriveKinematics{ + wpi::UnpackProtobufArray(m->modules())}; + } + static void Pack(google::protobuf::Message* msg, - const frc::SwerveDriveKinematics& value); + const frc::SwerveDriveKinematics& value) { + auto m = static_cast(msg); + wpi::PackProtobufArray(m->mutable_modules(), value.GetModules()); + } }; - -#include "frc/kinematics/proto/SwerveDriveKinematicsProto.inc" diff --git a/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.inc b/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.inc deleted file mode 100644 index 0a9925481d9..00000000000 --- a/wpimath/src/main/native/include/frc/kinematics/proto/SwerveDriveKinematicsProto.inc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include -#include - -#include "frc/kinematics/proto/SwerveDriveKinematicsProto.h" -#include "kinematics.pb.h" - -template -google::protobuf::Message* -wpi::Protobuf>::New( - google::protobuf::Arena* arena) { - return wpi::CreateMessage(arena); -} - -template -frc::SwerveDriveKinematics -wpi::Protobuf>::Unpack( - const google::protobuf::Message& msg) { - auto m = static_cast(&msg); - if (m->modules_size() != NumModules) { - throw std::invalid_argument( - fmt::format("Tried to unpack message with {} elements in modules into " - "SwerveDriveKinematics with {} modules", - m->modules_size(), NumModules)); - } - return frc::SwerveDriveKinematics{ - wpi::UnpackProtobufArray(m->modules())}; -} - -template -void wpi::Protobuf>::Pack( - google::protobuf::Message* msg, - const frc::SwerveDriveKinematics& value) { - auto m = static_cast(msg); - wpi::PackProtobufArray(m->mutable_modules(), value.GetModules()); -} diff --git a/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.h b/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.h index 139b78dbdc0..b351fab26c7 100644 --- a/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.h +++ b/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.h @@ -24,9 +24,19 @@ struct wpi::Struct> { static constexpr std::string_view GetSchema() { return kSchema; } static frc::SwerveDriveKinematics Unpack( - std::span data); + std::span data) { + constexpr size_t kModulesOff = 0; + return frc::SwerveDriveKinematics{ + wpi::UnpackStructArray( + data)}; + } + static void Pack(std::span data, - const frc::SwerveDriveKinematics& value); + const frc::SwerveDriveKinematics& value) { + constexpr size_t kModulesOff = 0; + wpi::PackStructArray(data, value.GetModules()); + } + static void ForEachNested( std::invocable auto fn) { wpi::ForEachStructSchema(fn); @@ -37,5 +47,3 @@ static_assert(wpi::StructSerializable>); static_assert(wpi::HasNestedStruct>); static_assert(wpi::StructSerializable>); static_assert(wpi::HasNestedStruct>); - -#include "frc/kinematics/struct/SwerveDriveKinematicsStruct.inc" diff --git a/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.inc b/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.inc deleted file mode 100644 index 5e4dee13eeb..00000000000 --- a/wpimath/src/main/native/include/frc/kinematics/struct/SwerveDriveKinematicsStruct.inc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/kinematics/struct/SwerveDriveKinematicsStruct.h" - -template -frc::SwerveDriveKinematics -wpi::Struct>::Unpack( - std::span data) { - constexpr size_t kModulesOff = 0; - return frc::SwerveDriveKinematics{ - wpi::UnpackStructArray( - data)}; -} - -template -void wpi::Struct>::Pack( - std::span data, - const frc::SwerveDriveKinematics& value) { - constexpr size_t kModulesOff = 0; - wpi::PackStructArray(data, value.GetModules()); -} diff --git a/wpimath/src/main/native/include/frc/proto/MatrixProto.h b/wpimath/src/main/native/include/frc/proto/MatrixProto.h index 1440c513c6e..8dc9285b3cb 100644 --- a/wpimath/src/main/native/include/frc/proto/MatrixProto.h +++ b/wpimath/src/main/native/include/frc/proto/MatrixProto.h @@ -4,19 +4,53 @@ #pragma once +#include + +#include +#include #include #include "frc/EigenCore.h" +#include "wpimath.pb.h" template requires(Cols != 1) struct wpi::Protobuf> { - static google::protobuf::Message* New(google::protobuf::Arena* arena); + static google::protobuf::Message* New(google::protobuf::Arena* arena) { + return wpi::CreateMessage(arena); + } + static frc::Matrixd Unpack( - const google::protobuf::Message& msg); + const google::protobuf::Message& msg) { + auto m = static_cast(&msg); + if (m->num_rows() != Rows || m->num_cols() != Cols) { + throw std::invalid_argument(fmt::format( + "Tried to unpack message with {} rows and {} columns into " + "Matrix with {} rows and {} columns", + m->num_rows(), m->num_cols(), Rows, Cols)); + } + if (m->data_size() != Rows * Cols) { + throw std::invalid_argument( + fmt::format("Tried to unpack message with {} elements in data into " + "Matrix with {} elements", + m->data_size(), Rows * Cols)); + } + frc::Matrixd mat; + for (int i = 0; i < Rows * Cols; i++) { + mat(i) = m->data(i); + } + return mat; + } + static void Pack( google::protobuf::Message* msg, - const frc::Matrixd& value); + const frc::Matrixd& value) { + auto m = static_cast(msg); + m->set_num_rows(Rows); + m->set_num_cols(Cols); + m->clear_data(); + for (int i = 0; i < Rows * Cols; i++) { + m->add_data(value(i)); + } + } }; - -#include "frc/proto/MatrixProto.inc" diff --git a/wpimath/src/main/native/include/frc/proto/MatrixProto.inc b/wpimath/src/main/native/include/frc/proto/MatrixProto.inc deleted file mode 100644 index 75370b64217..00000000000 --- a/wpimath/src/main/native/include/frc/proto/MatrixProto.inc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include -#include - -#include "frc/proto/MatrixProto.h" -#include "wpimath.pb.h" - -template - requires(Cols != 1) -google::protobuf::Message* -wpi::Protobuf>::New( - google::protobuf::Arena* arena) { - return wpi::CreateMessage(arena); -} - -template - requires(Cols != 1) -frc::Matrixd -wpi::Protobuf>::Unpack( - const google::protobuf::Message& msg) { - auto m = static_cast(&msg); - if (m->num_rows() != Rows || m->num_cols() != Cols) { - throw std::invalid_argument( - fmt::format("Tried to unpack message with {} rows and {} columns into " - "Matrix with {} rows and {} columns", - m->num_rows(), m->num_cols(), Rows, Cols)); - } - if (m->data_size() != Rows * Cols) { - throw std::invalid_argument( - fmt::format("Tried to unpack message with {} elements in data into " - "Matrix with {} elements", - m->data_size(), Rows * Cols)); - } - frc::Matrixd mat; - for (int i = 0; i < Rows * Cols; i++) { - mat(i) = m->data(i); - } - return mat; -} - -template - requires(Cols != 1) -void wpi::Protobuf>::Pack( - google::protobuf::Message* msg, - const frc::Matrixd& value) { - auto m = static_cast(msg); - m->set_num_rows(Rows); - m->set_num_cols(Cols); - m->clear_data(); - for (int i = 0; i < Rows * Cols; i++) { - m->add_data(value(i)); - } -} diff --git a/wpimath/src/main/native/include/frc/proto/VectorProto.h b/wpimath/src/main/native/include/frc/proto/VectorProto.h index 654f31c0882..a982e19a70a 100644 --- a/wpimath/src/main/native/include/frc/proto/VectorProto.h +++ b/wpimath/src/main/native/include/frc/proto/VectorProto.h @@ -4,18 +4,44 @@ #pragma once +#include + +#include +#include #include #include "frc/EigenCore.h" +#include "wpimath.pb.h" template struct wpi::Protobuf> { - static google::protobuf::Message* New(google::protobuf::Arena* arena); + static google::protobuf::Message* New(google::protobuf::Arena* arena) { + return wpi::CreateMessage(arena); + } + static frc::Matrixd Unpack( - const google::protobuf::Message& msg); + const google::protobuf::Message& msg) { + auto m = static_cast(&msg); + if (m->rows_size() != Size) { + throw std::invalid_argument( + fmt::format("Tried to unpack message with {} elements in rows into " + "Vector with {} rows", + m->rows_size(), Size)); + } + frc::Matrixd vec; + for (int i = 0; i < Size; i++) { + vec(i) = m->rows(i); + } + return vec; + } + static void Pack( google::protobuf::Message* msg, - const frc::Matrixd& value); + const frc::Matrixd& value) { + auto m = static_cast(msg); + m->clear_rows(); + for (int i = 0; i < Size; i++) { + m->add_rows(value(i)); + } + } }; - -#include "frc/proto/VectorProto.inc" diff --git a/wpimath/src/main/native/include/frc/proto/VectorProto.inc b/wpimath/src/main/native/include/frc/proto/VectorProto.inc deleted file mode 100644 index f07d7dfa77d..00000000000 --- a/wpimath/src/main/native/include/frc/proto/VectorProto.inc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include -#include - -#include "frc/proto/VectorProto.h" -#include "wpimath.pb.h" - -template -google::protobuf::Message* -wpi::Protobuf>::New( - google::protobuf::Arena* arena) { - return wpi::CreateMessage(arena); -} - -template -frc::Matrixd -wpi::Protobuf>::Unpack( - const google::protobuf::Message& msg) { - auto m = static_cast(&msg); - if (m->rows_size() != Size) { - throw std::invalid_argument( - fmt::format("Tried to unpack message with {} elements in rows into " - "Vector with {} rows", - m->rows_size(), Size)); - } - frc::Matrixd vec; - for (int i = 0; i < Size; i++) { - vec(i) = m->rows(i); - } - return vec; -} - -template -void wpi::Protobuf>::Pack( - google::protobuf::Message* msg, - const frc::Matrixd& value) { - auto m = static_cast(msg); - m->clear_rows(); - for (int i = 0; i < Size; i++) { - m->add_rows(value(i)); - } -} diff --git a/wpimath/src/main/native/include/frc/struct/MatrixStruct.h b/wpimath/src/main/native/include/frc/struct/MatrixStruct.h index 926109c852e..59bc4b57121 100644 --- a/wpimath/src/main/native/include/frc/struct/MatrixStruct.h +++ b/wpimath/src/main/native/include/frc/struct/MatrixStruct.h @@ -24,13 +24,28 @@ struct wpi::Struct> { static constexpr std::string_view GetSchema() { return kSchema; } static frc::Matrixd Unpack( - std::span data); + std::span data) { + constexpr size_t kDataOff = 0; + wpi::array mat_data = + wpi::UnpackStructArray(data); + frc::Matrixd mat; + for (int i = 0; i < Rows * Cols; i++) { + mat(i) = mat_data[i]; + } + return mat; + } + static void Pack( std::span data, - const frc::Matrixd& value); + const frc::Matrixd& value) { + constexpr size_t kDataOff = 0; + wpi::array mat_data(wpi::empty_array); + for (int i = 0; i < Rows * Cols; i++) { + mat_data[i] = value(i); + } + wpi::PackStructArray(data, mat_data); + } }; static_assert(wpi::StructSerializable>); static_assert(wpi::StructSerializable>); - -#include "frc/struct/MatrixStruct.inc" diff --git a/wpimath/src/main/native/include/frc/struct/MatrixStruct.inc b/wpimath/src/main/native/include/frc/struct/MatrixStruct.inc deleted file mode 100644 index 4efda8a41de..00000000000 --- a/wpimath/src/main/native/include/frc/struct/MatrixStruct.inc +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/struct/MatrixStruct.h" - -template - requires(Cols != 1) -frc::Matrixd -wpi::Struct>::Unpack( - std::span data) { - constexpr size_t kDataOff = 0; - wpi::array mat_data = - wpi::UnpackStructArray(data); - frc::Matrixd mat; - for (int i = 0; i < Rows * Cols; i++) { - mat(i) = mat_data[i]; - } - return mat; -} - -template - requires(Cols != 1) -void wpi::Struct>::Pack( - std::span data, - const frc::Matrixd& value) { - constexpr size_t kDataOff = 0; - wpi::array mat_data(wpi::empty_array); - for (int i = 0; i < Rows * Cols; i++) { - mat_data[i] = value(i); - } - wpi::PackStructArray(data, mat_data); -} diff --git a/wpimath/src/main/native/include/frc/struct/VectorStruct.h b/wpimath/src/main/native/include/frc/struct/VectorStruct.h index d10d48ebe52..ec480240c34 100644 --- a/wpimath/src/main/native/include/frc/struct/VectorStruct.h +++ b/wpimath/src/main/native/include/frc/struct/VectorStruct.h @@ -21,13 +21,28 @@ struct wpi::Struct> { static constexpr std::string_view GetSchema() { return kSchema; } static frc::Matrixd Unpack( - std::span data); + std::span data) { + constexpr size_t kDataOff = 0; + wpi::array vec_data = + wpi::UnpackStructArray(data); + frc::Matrixd vec; + for (int i = 0; i < Size; i++) { + vec(i) = vec_data[i]; + } + return vec; + } + static void Pack( std::span data, - const frc::Matrixd& value); + const frc::Matrixd& value) { + constexpr size_t kDataOff = 0; + wpi::array vec_data(wpi::empty_array); + for (int i = 0; i < Size; i++) { + vec_data[i] = value(i); + } + wpi::PackStructArray(data, vec_data); + } }; static_assert(wpi::StructSerializable>); static_assert(wpi::StructSerializable>); - -#include "frc/struct/VectorStruct.inc" diff --git a/wpimath/src/main/native/include/frc/struct/VectorStruct.inc b/wpimath/src/main/native/include/frc/struct/VectorStruct.inc deleted file mode 100644 index 4fcfbb009dd..00000000000 --- a/wpimath/src/main/native/include/frc/struct/VectorStruct.inc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/struct/VectorStruct.h" - -template -frc::Matrixd -wpi::Struct>::Unpack( - std::span data) { - constexpr size_t kDataOff = 0; - wpi::array vec_data = - wpi::UnpackStructArray(data); - frc::Matrixd vec; - for (int i = 0; i < Size; i++) { - vec(i) = vec_data[i]; - } - return vec; -} - -template -void wpi::Struct>::Pack( - std::span data, - const frc::Matrixd& value) { - constexpr size_t kDataOff = 0; - wpi::array vec_data(wpi::empty_array); - for (int i = 0; i < Size; i++) { - vec_data[i] = value(i); - } - wpi::PackStructArray(data, vec_data); -} diff --git a/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.h b/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.h index f47f63e67c0..36eea0802e8 100644 --- a/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.h +++ b/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.h @@ -4,17 +4,49 @@ #pragma once +#include + +#include +#include #include +#include "frc/proto/MatrixProto.h" #include "frc/system/LinearSystem.h" +#include "system.pb.h" template struct wpi::Protobuf> { - static google::protobuf::Message* New(google::protobuf::Arena* arena); + static google::protobuf::Message* New(google::protobuf::Arena* arena) { + return wpi::CreateMessage(arena); + } + static frc::LinearSystem Unpack( - const google::protobuf::Message& msg); + const google::protobuf::Message& msg) { + auto m = static_cast(&msg); + if (m->num_states() != States || m->num_inputs() != Inputs || + m->num_outputs() != Outputs) { + throw std::invalid_argument(fmt::format( + "Tried to unpack message with {} states and {} inputs and {} outputs " + "into LinearSystem with {} states and {} inputs and {} outputs", + m->num_states(), m->num_inputs(), m->num_outputs(), States, Inputs, + Outputs)); + } + return frc::LinearSystem{ + wpi::UnpackProtobuf>(m->wpi_a()), + wpi::UnpackProtobuf>(m->wpi_b()), + wpi::UnpackProtobuf>(m->wpi_c()), + wpi::UnpackProtobuf>(m->wpi_d())}; + } + static void Pack(google::protobuf::Message* msg, - const frc::LinearSystem& value); + const frc::LinearSystem& value) { + auto m = static_cast(msg); + m->set_num_states(States); + m->set_num_inputs(Inputs); + m->set_num_outputs(Outputs); + wpi::PackProtobuf(m->mutable_a(), value.A()); + wpi::PackProtobuf(m->mutable_b(), value.B()); + wpi::PackProtobuf(m->mutable_c(), value.C()); + wpi::PackProtobuf(m->mutable_d(), value.D()); + } }; - -#include "frc/system/proto/LinearSystemProto.inc" diff --git a/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.inc b/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.inc deleted file mode 100644 index 5151f4e17b8..00000000000 --- a/wpimath/src/main/native/include/frc/system/proto/LinearSystemProto.inc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include -#include - -#include "frc/proto/MatrixProto.h" -#include "frc/system/proto/LinearSystemProto.h" -#include "system.pb.h" - -template -google::protobuf::Message* -wpi::Protobuf>::New( - google::protobuf::Arena* arena) { - return wpi::CreateMessage(arena); -} - -template -frc::LinearSystem -wpi::Protobuf>::Unpack( - const google::protobuf::Message& msg) { - auto m = static_cast(&msg); - if (m->num_states() != States || m->num_inputs() != Inputs || - m->num_outputs() != Outputs) { - throw std::invalid_argument(fmt::format( - "Tried to unpack message with {} states and {} inputs and {} outputs " - "into LinearSystem with {} states and {} inputs and {} outputs", - m->num_states(), m->num_inputs(), m->num_outputs(), States, Inputs, - Outputs)); - } - return frc::LinearSystem{ - wpi::UnpackProtobuf>(m->wpi_a()), - wpi::UnpackProtobuf>(m->wpi_b()), - wpi::UnpackProtobuf>(m->wpi_c()), - wpi::UnpackProtobuf>(m->wpi_d())}; -} - -template -void wpi::Protobuf>::Pack( - google::protobuf::Message* msg, - const frc::LinearSystem& value) { - auto m = static_cast(msg); - m->set_num_states(States); - m->set_num_inputs(Inputs); - m->set_num_outputs(Outputs); - wpi::PackProtobuf(m->mutable_a(), value.A()); - wpi::PackProtobuf(m->mutable_b(), value.B()); - wpi::PackProtobuf(m->mutable_c(), value.C()); - wpi::PackProtobuf(m->mutable_d(), value.D()); -} diff --git a/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.h b/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.h index e8731a966bc..74a76bce0b6 100644 --- a/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.h +++ b/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.h @@ -32,9 +32,36 @@ struct wpi::Struct> { static constexpr std::string_view GetSchema() { return kSchema; } static frc::LinearSystem Unpack( - std::span data); + std::span data) { + constexpr size_t kAOff = 0; + constexpr size_t kBOff = + kAOff + wpi::GetStructSize>(); + constexpr size_t kCOff = + kBOff + wpi::GetStructSize>(); + constexpr size_t kDOff = + kCOff + wpi::GetStructSize>(); + return frc::LinearSystem{ + wpi::UnpackStruct, kAOff>(data), + wpi::UnpackStruct, kBOff>(data), + wpi::UnpackStruct, kCOff>(data), + wpi::UnpackStruct, kDOff>(data)}; + } + static void Pack(std::span data, - const frc::LinearSystem& value); + const frc::LinearSystem& value) { + constexpr size_t kAOff = 0; + constexpr size_t kBOff = + kAOff + wpi::GetStructSize>(); + constexpr size_t kCOff = + kBOff + wpi::GetStructSize>(); + constexpr size_t kDOff = + kCOff + wpi::GetStructSize>(); + wpi::PackStruct(data, value.A()); + wpi::PackStruct(data, value.B()); + wpi::PackStruct(data, value.C()); + wpi::PackStruct(data, value.D()); + } + static void ForEachNested( std::invocable auto fn) { wpi::ForEachStructSchema>(fn); @@ -48,5 +75,3 @@ static_assert(wpi::StructSerializable>); static_assert(wpi::HasNestedStruct>); static_assert(wpi::StructSerializable>); static_assert(wpi::HasNestedStruct>); - -#include "frc/system/struct/LinearSystemStruct.inc" diff --git a/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.inc b/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.inc deleted file mode 100644 index fbac4b97bcd..00000000000 --- a/wpimath/src/main/native/include/frc/system/struct/LinearSystemStruct.inc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/system/struct/LinearSystemStruct.h" - -template -frc::LinearSystem -wpi::Struct>::Unpack( - std::span data) { - constexpr size_t kAOff = 0; - constexpr size_t kBOff = - kAOff + wpi::GetStructSize>(); - constexpr size_t kCOff = - kBOff + wpi::GetStructSize>(); - constexpr size_t kDOff = - kCOff + wpi::GetStructSize>(); - return frc::LinearSystem{ - wpi::UnpackStruct, kAOff>(data), - wpi::UnpackStruct, kBOff>(data), - wpi::UnpackStruct, kCOff>(data), - wpi::UnpackStruct, kDOff>(data)}; -} - -template -void wpi::Struct>::Pack( - std::span data, - const frc::LinearSystem& value) { - constexpr size_t kAOff = 0; - constexpr size_t kBOff = - kAOff + wpi::GetStructSize>(); - constexpr size_t kCOff = - kBOff + wpi::GetStructSize>(); - constexpr size_t kDOff = - kCOff + wpi::GetStructSize>(); - wpi::PackStruct(data, value.A()); - wpi::PackStruct(data, value.B()); - wpi::PackStruct(data, value.C()); - wpi::PackStruct(data, value.D()); -} diff --git a/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.h b/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.h index b5e8952b689..49ebc649ca9 100644 --- a/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.h +++ b/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.h @@ -4,8 +4,8 @@ #pragma once +#include "units/math.h" #include "units/time.h" -#include "wpimath/MathShared.h" namespace frc { @@ -134,7 +134,8 @@ class ExponentialProfile { * * @param constraints The constraints on the profile, like maximum input. */ - explicit ExponentialProfile(Constraints constraints); + explicit ExponentialProfile(Constraints constraints) + : m_constraints(constraints) {} ExponentialProfile(const ExponentialProfile&) = default; ExponentialProfile& operator=(const ExponentialProfile&) = default; @@ -152,7 +153,25 @@ class ExponentialProfile { * @return The position and velocity of the profile at time t. */ State Calculate(const units::second_t& t, const State& current, - const State& goal) const; + const State& goal) const { + auto direction = ShouldFlipInput(current, goal) ? -1 : 1; + auto u = direction * m_constraints.maxInput; + + auto inflectionPoint = CalculateInflectionPoint(current, goal, u); + auto timing = CalculateProfileTiming(current, inflectionPoint, goal, u); + + if (t < 0_s) { + return current; + } else if (t < timing.inflectionTime) { + return {ComputeDistanceFromTime(t, u, current), + ComputeVelocityFromTime(t, u, current)}; + } else if (t < timing.totalTime) { + return {ComputeDistanceFromTime(t - timing.totalTime, -u, goal), + ComputeVelocityFromTime(t - timing.totalTime, -u, goal)}; + } else { + return goal; + } + } /** * Calculates the point after which the fastest way to reach the goal state is @@ -162,7 +181,13 @@ class ExponentialProfile { * @param goal The desired state when the profile is complete. * @return The position and velocity of the profile at the inflection point. */ - State CalculateInflectionPoint(const State& current, const State& goal) const; + State CalculateInflectionPoint(const State& current, + const State& goal) const { + auto direction = ShouldFlipInput(current, goal) ? -1 : 1; + auto u = direction * m_constraints.maxInput; + + return CalculateInflectionPoint(current, goal, u); + } /** * Calculates the time it will take for this profile to reach the goal state. @@ -171,7 +196,11 @@ class ExponentialProfile { * @param goal The desired state when the profile is complete. * @return The total duration of this profile. */ - units::second_t TimeLeftUntil(const State& current, const State& goal) const; + units::second_t TimeLeftUntil(const State& current, const State& goal) const { + auto timing = CalculateProfileTiming(current, goal); + + return timing.totalTime; + } /** * Calculates the time it will take for this profile to reach the inflection @@ -182,7 +211,13 @@ class ExponentialProfile { * @return The timing information for this profile. */ ProfileTiming CalculateProfileTiming(const State& current, - const State& goal) const; + const State& goal) const { + auto direction = ShouldFlipInput(current, goal) ? -1 : 1; + auto u = direction * m_constraints.maxInput; + + auto inflectionPoint = CalculateInflectionPoint(current, goal, u); + return CalculateProfileTiming(current, inflectionPoint, goal, u); + } private: /** @@ -196,7 +231,19 @@ class ExponentialProfile { * @return The position and velocity of the profile at the inflection point. */ State CalculateInflectionPoint(const State& current, const State& goal, - const Input_t& input) const; + const Input_t& input) const { + auto u = input; + + if (current == goal) { + return current; + } + + auto inflectionVelocity = SolveForInflectionVelocity(u, current, goal); + auto inflectionPosition = + ComputeDistanceFromVelocity(inflectionVelocity, -u, goal); + + return {inflectionPosition, inflectionVelocity}; + } /** * Calculates the time it will take for this profile to reach the inflection @@ -212,7 +259,59 @@ class ExponentialProfile { ProfileTiming CalculateProfileTiming(const State& current, const State& inflectionPoint, const State& goal, - const Input_t& input) const; + const Input_t& input) const { + auto u = input; + auto u_dir = units::math::abs(u) / u; + + units::second_t inflectionT_forward; + + // We need to handle 5 cases here: + // + // - Approaching -maxVelocity from below + // - Approaching -maxVelocity from above + // - Approaching maxVelocity from below + // - Approaching maxVelocity from above + // - At +-maxVelocity + // + // For cases 1 and 3, we want to subtract epsilon from the inflection point + // velocity For cases 2 and 4, we want to add epsilon to the inflection + // point velocity. For case 5, we have reached inflection point velocity. + auto epsilon = Velocity_t(1e-9); + if (units::math::abs(u_dir * m_constraints.MaxVelocity() - + inflectionPoint.velocity) < epsilon) { + auto solvableV = inflectionPoint.velocity; + units::second_t t_to_solvable_v; + Distance_t x_at_solvable_v; + if (units::math::abs(current.velocity - inflectionPoint.velocity) < + epsilon) { + t_to_solvable_v = 0_s; + x_at_solvable_v = current.position; + } else { + if (units::math::abs(current.velocity) > m_constraints.MaxVelocity()) { + solvableV += u_dir * epsilon; + } else { + solvableV -= u_dir * epsilon; + } + + t_to_solvable_v = + ComputeTimeFromVelocity(solvableV, u, current.velocity); + x_at_solvable_v = ComputeDistanceFromVelocity(solvableV, u, current); + } + + inflectionT_forward = + t_to_solvable_v + u_dir * + (inflectionPoint.position - x_at_solvable_v) / + m_constraints.MaxVelocity(); + } else { + inflectionT_forward = ComputeTimeFromVelocity(inflectionPoint.velocity, u, + current.velocity); + } + + auto inflectionT_backward = + ComputeTimeFromVelocity(inflectionPoint.velocity, -u, goal.velocity); + + return {inflectionT_forward, inflectionT_forward - inflectionT_backward}; + } /** * Calculates the position reached after t seconds when applying an input from @@ -226,7 +325,16 @@ class ExponentialProfile { */ Distance_t ComputeDistanceFromTime(const units::second_t& time, const Input_t& input, - const State& initial) const; + const State& initial) const { + auto A = m_constraints.A; + auto B = m_constraints.B; + auto u = input; + + return initial.position + + (-B * u * time + + (initial.velocity + B * u / A) * (units::math::exp(A * time) - 1)) / + A; + } /** * Calculates the velocity reached after t seconds when applying an input from @@ -240,7 +348,14 @@ class ExponentialProfile { */ Velocity_t ComputeVelocityFromTime(const units::second_t& time, const Input_t& input, - const State& initial) const; + const State& initial) const { + auto A = m_constraints.A; + auto B = m_constraints.B; + auto u = input; + + return (initial.velocity + B * u / A) * units::math::exp(A * time) - + B * u / A; + } /** * Calculates the time required to reach a specified velocity given the @@ -254,7 +369,13 @@ class ExponentialProfile { */ units::second_t ComputeTimeFromVelocity(const Velocity_t& velocity, const Input_t& input, - const Velocity_t& initial) const; + const Velocity_t& initial) const { + auto A = m_constraints.A; + auto B = m_constraints.B; + auto u = input; + + return units::math::log((A * velocity + B * u) / (A * initial + B * u)) / A; + } /** * Calculates the distance reached at the same time as the given velocity when @@ -268,7 +389,16 @@ class ExponentialProfile { */ Distance_t ComputeDistanceFromVelocity(const Velocity_t& velocity, const Input_t& input, - const State& initial) const; + const State& initial) const { + auto A = m_constraints.A; + auto B = m_constraints.B; + auto u = input; + + return initial.position + (velocity - initial.velocity) / A - + B * u / (A * A) * + units::math::log((A * velocity + B * u) / + (A * initial.velocity + B * u)); + } /** * Calculates the velocity at which input should be reversed in order to reach @@ -282,7 +412,30 @@ class ExponentialProfile { */ Velocity_t SolveForInflectionVelocity(const Input_t& input, const State& current, - const State& goal) const; + const State& goal) const { + auto A = m_constraints.A; + auto B = m_constraints.B; + auto u = input; + + auto u_dir = u / units::math::abs(u); + + auto position_delta = goal.position - current.position; + auto velocity_delta = goal.velocity - current.velocity; + + auto scalar = (A * current.velocity + B * u) * (A * goal.velocity - B * u); + auto power = -A / B / u * (A * position_delta - velocity_delta); + + auto a = -A * A; + auto c = B * B * u * u + scalar * units::math::exp(power); + + if (-1e-9 < c.value() && c.value() < 0) { + // numeric instability - the heuristic gets it right but c is around + // -1e-13 + return Velocity_t(0); + } + + return u_dir * units::math::sqrt(-c / a); + } /** * Returns true if the profile should be inverted. @@ -293,10 +446,33 @@ class ExponentialProfile { * @param current The initial state (usually the current state). * @param goal The desired state when the profile is complete. */ - bool ShouldFlipInput(const State& current, const State& goal) const; + bool ShouldFlipInput(const State& current, const State& goal) const { + auto u = m_constraints.maxInput; + + auto v0 = current.velocity; + auto xf = goal.position; + auto vf = goal.velocity; + + auto x_forward = ComputeDistanceFromVelocity(vf, u, current); + auto x_reverse = ComputeDistanceFromVelocity(vf, -u, current); + + if (v0 >= m_constraints.MaxVelocity()) { + return xf < x_reverse; + } + + if (v0 <= -m_constraints.MaxVelocity()) { + return xf < x_forward; + } + + auto a = v0 >= Velocity_t(0); + auto b = vf >= Velocity_t(0); + auto c = xf >= x_forward; + auto d = xf >= x_reverse; + + return (a && !d) || (b && !c) || (!c && !d); + } Constraints m_constraints; }; -} // namespace frc -#include "ExponentialProfile.inc" +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.inc b/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.inc deleted file mode 100644 index 82e33313f7a..00000000000 --- a/wpimath/src/main/native/include/frc/trajectory/ExponentialProfile.inc +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include "frc/trajectory/ExponentialProfile.h" -#include "units/math.h" - -namespace frc { -template -ExponentialProfile::ExponentialProfile(Constraints constraints) - : m_constraints(constraints) {} - -template -typename ExponentialProfile::State -ExponentialProfile::Calculate(const units::second_t& t, - const State& current, - const State& goal) const { - auto direction = ShouldFlipInput(current, goal) ? -1 : 1; - auto u = direction * m_constraints.maxInput; - - auto inflectionPoint = CalculateInflectionPoint(current, goal, u); - auto timing = CalculateProfileTiming(current, inflectionPoint, goal, u); - - if (t < 0_s) { - return current; - } else if (t < timing.inflectionTime) { - return {ComputeDistanceFromTime(t, u, current), - ComputeVelocityFromTime(t, u, current)}; - } else if (t < timing.totalTime) { - return {ComputeDistanceFromTime(t - timing.totalTime, -u, goal), - ComputeVelocityFromTime(t - timing.totalTime, -u, goal)}; - } else { - return goal; - } -} - -template -typename ExponentialProfile::State -ExponentialProfile::CalculateInflectionPoint( - const State& current, const State& goal) const { - auto direction = ShouldFlipInput(current, goal) ? -1 : 1; - auto u = direction * m_constraints.maxInput; - - return CalculateInflectionPoint(current, goal, u); -} - -template -units::second_t ExponentialProfile::TimeLeftUntil( - const State& current, const State& goal) const { - auto timing = CalculateProfileTiming(current, goal); - - return timing.totalTime; -} - -template -typename ExponentialProfile::ProfileTiming -ExponentialProfile::CalculateProfileTiming( - const State& current, const State& goal) const { - auto direction = ShouldFlipInput(current, goal) ? -1 : 1; - auto u = direction * m_constraints.maxInput; - - auto inflectionPoint = CalculateInflectionPoint(current, goal, u); - return CalculateProfileTiming(current, inflectionPoint, goal, u); -} - -template -typename ExponentialProfile::State -ExponentialProfile::CalculateInflectionPoint( - const State& current, const State& goal, const Input_t& input) const { - auto u = input; - - if (current == goal) { - return current; - } - - auto inflectionVelocity = SolveForInflectionVelocity(u, current, goal); - auto inflectionPosition = - ComputeDistanceFromVelocity(inflectionVelocity, -u, goal); - - return {inflectionPosition, inflectionVelocity}; -} - -template -typename ExponentialProfile::ProfileTiming -ExponentialProfile::CalculateProfileTiming( - const State& current, const State& inflectionPoint, const State& goal, - const Input_t& input) const { - auto u = input; - auto u_dir = units::math::abs(u) / u; - - units::second_t inflectionT_forward; - - // We need to handle 5 cases here: - // - // - Approaching -maxVelocity from below - // - Approaching -maxVelocity from above - // - Approaching maxVelocity from below - // - Approaching maxVelocity from above - // - At +-maxVelocity - // - // For cases 1 and 3, we want to subtract epsilon from the inflection point - // velocity For cases 2 and 4, we want to add epsilon to the inflection point - // velocity. For case 5, we have reached inflection point velocity. - auto epsilon = Velocity_t(1e-9); - if (units::math::abs(u_dir * m_constraints.MaxVelocity() - - inflectionPoint.velocity) < epsilon) { - auto solvableV = inflectionPoint.velocity; - units::second_t t_to_solvable_v; - Distance_t x_at_solvable_v; - if (units::math::abs(current.velocity - inflectionPoint.velocity) < - epsilon) { - t_to_solvable_v = 0_s; - x_at_solvable_v = current.position; - } else { - if (units::math::abs(current.velocity) > m_constraints.MaxVelocity()) { - solvableV += u_dir * epsilon; - } else { - solvableV -= u_dir * epsilon; - } - - t_to_solvable_v = ComputeTimeFromVelocity(solvableV, u, current.velocity); - x_at_solvable_v = ComputeDistanceFromVelocity(solvableV, u, current); - } - - inflectionT_forward = - t_to_solvable_v + u_dir * (inflectionPoint.position - x_at_solvable_v) / - m_constraints.MaxVelocity(); - } else { - inflectionT_forward = - ComputeTimeFromVelocity(inflectionPoint.velocity, u, current.velocity); - } - - auto inflectionT_backward = - ComputeTimeFromVelocity(inflectionPoint.velocity, -u, goal.velocity); - - return {inflectionT_forward, inflectionT_forward - inflectionT_backward}; -} - -template -typename ExponentialProfile::Distance_t -ExponentialProfile::ComputeDistanceFromTime( - const units::second_t& time, const Input_t& input, - const State& initial) const { - auto A = m_constraints.A; - auto B = m_constraints.B; - auto u = input; - - return initial.position + - (-B * u * time + - (initial.velocity + B * u / A) * (units::math::exp(A * time) - 1)) / - A; -} - -template -typename ExponentialProfile::Velocity_t -ExponentialProfile::ComputeVelocityFromTime( - const units::second_t& time, const Input_t& input, - const State& initial) const { - auto A = m_constraints.A; - auto B = m_constraints.B; - auto u = input; - - return (initial.velocity + B * u / A) * units::math::exp(A * time) - - B * u / A; -} - -template -units::second_t ExponentialProfile::ComputeTimeFromVelocity( - const Velocity_t& velocity, const Input_t& input, - const Velocity_t& initial) const { - auto A = m_constraints.A; - auto B = m_constraints.B; - auto u = input; - - return units::math::log((A * velocity + B * u) / (A * initial + B * u)) / A; -} - -template -typename ExponentialProfile::Distance_t -ExponentialProfile::ComputeDistanceFromVelocity( - const Velocity_t& velocity, const Input_t& input, - const State& initial) const { - auto A = m_constraints.A; - auto B = m_constraints.B; - auto u = input; - - return initial.position + (velocity - initial.velocity) / A - - B * u / (A * A) * - units::math::log((A * velocity + B * u) / - (A * initial.velocity + B * u)); -} - -template -typename ExponentialProfile::Velocity_t -ExponentialProfile::SolveForInflectionVelocity( - const Input_t& input, const State& current, const State& goal) const { - auto A = m_constraints.A; - auto B = m_constraints.B; - auto u = input; - - auto u_dir = u / units::math::abs(u); - - auto position_delta = goal.position - current.position; - auto velocity_delta = goal.velocity - current.velocity; - - auto scalar = (A * current.velocity + B * u) * (A * goal.velocity - B * u); - auto power = -A / B / u * (A * position_delta - velocity_delta); - - auto a = -A * A; - auto c = B * B * u * u + scalar * units::math::exp(power); - - if (-1e-9 < c.value() && c.value() < 0) { - // numeric instability - the heuristic gets it right but c is around -1e-13 - return Velocity_t(0); - } - - return u_dir * units::math::sqrt(-c / a); -} - -template -bool ExponentialProfile::ShouldFlipInput( - const State& current, const State& goal) const { - auto u = m_constraints.maxInput; - - auto v0 = current.velocity; - auto xf = goal.position; - auto vf = goal.velocity; - - auto x_forward = ComputeDistanceFromVelocity(vf, u, current); - auto x_reverse = ComputeDistanceFromVelocity(vf, -u, current); - - if (v0 >= m_constraints.MaxVelocity()) { - return xf < x_reverse; - } - - if (v0 <= -m_constraints.MaxVelocity()) { - return xf < x_forward; - } - - auto a = v0 >= Velocity_t(0); - auto b = vf >= Velocity_t(0); - auto c = xf >= x_forward; - auto d = xf >= x_reverse; - - return (a && !d) || (b && !c) || (!c && !d); -} -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h index 08df10c8d94..20a88176a87 100644 --- a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h +++ b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h @@ -4,6 +4,7 @@ #pragma once +#include "units/math.h" #include "units/time.h" #include "wpimath/MathShared.h" @@ -101,7 +102,9 @@ class TrapezoidProfile { * * @param constraints The constraints on the profile, like maximum velocity. */ - TrapezoidProfile(Constraints constraints); // NOLINT + TrapezoidProfile(Constraints constraints) // NOLINT + : m_constraints(constraints) {} + TrapezoidProfile(const TrapezoidProfile&) = default; TrapezoidProfile& operator=(const TrapezoidProfile&) = default; TrapezoidProfile(TrapezoidProfile&&) = default; @@ -117,7 +120,74 @@ class TrapezoidProfile { * @param goal The desired state when the profile is complete. * @return The position and velocity of the profile at time t. */ - State Calculate(units::second_t t, State current, State goal); + State Calculate(units::second_t t, State current, State goal) { + m_direction = ShouldFlipAcceleration(current, goal) ? -1 : 1; + m_current = Direct(current); + goal = Direct(goal); + if (m_current.velocity > m_constraints.maxVelocity) { + m_current.velocity = m_constraints.maxVelocity; + } + + // Deal with a possibly truncated motion profile (with nonzero initial or + // final velocity) by calculating the parameters as if the profile began and + // ended at zero velocity + units::second_t cutoffBegin = + m_current.velocity / m_constraints.maxAcceleration; + Distance_t cutoffDistBegin = + cutoffBegin * cutoffBegin * m_constraints.maxAcceleration / 2.0; + + units::second_t cutoffEnd = goal.velocity / m_constraints.maxAcceleration; + Distance_t cutoffDistEnd = + cutoffEnd * cutoffEnd * m_constraints.maxAcceleration / 2.0; + + // Now we can calculate the parameters as if it was a full trapezoid instead + // of a truncated one + + Distance_t fullTrapezoidDist = + cutoffDistBegin + (goal.position - m_current.position) + cutoffDistEnd; + units::second_t accelerationTime = + m_constraints.maxVelocity / m_constraints.maxAcceleration; + + Distance_t fullSpeedDist = + fullTrapezoidDist - + accelerationTime * accelerationTime * m_constraints.maxAcceleration; + + // Handle the case where the profile never reaches full speed + if (fullSpeedDist < Distance_t{0}) { + accelerationTime = + units::math::sqrt(fullTrapezoidDist / m_constraints.maxAcceleration); + fullSpeedDist = Distance_t{0}; + } + + m_endAccel = accelerationTime - cutoffBegin; + m_endFullSpeed = m_endAccel + fullSpeedDist / m_constraints.maxVelocity; + m_endDecel = m_endFullSpeed + accelerationTime - cutoffEnd; + State result = m_current; + + if (t < m_endAccel) { + result.velocity += t * m_constraints.maxAcceleration; + result.position += + (m_current.velocity + t * m_constraints.maxAcceleration / 2.0) * t; + } else if (t < m_endFullSpeed) { + result.velocity = m_constraints.maxVelocity; + result.position += (m_current.velocity + + m_endAccel * m_constraints.maxAcceleration / 2.0) * + m_endAccel + + m_constraints.maxVelocity * (t - m_endAccel); + } else if (t <= m_endDecel) { + result.velocity = + goal.velocity + (m_endDecel - t) * m_constraints.maxAcceleration; + units::second_t timeLeft = m_endDecel - t; + result.position = + goal.position - + (goal.velocity + timeLeft * m_constraints.maxAcceleration / 2.0) * + timeLeft; + } else { + result = goal; + } + + return Direct(result); + } /** * Returns the time left until a target distance in the profile is reached. @@ -125,7 +195,71 @@ class TrapezoidProfile { * @param target The target distance. * @return The time left until a target distance in the profile is reached. */ - units::second_t TimeLeftUntil(Distance_t target) const; + units::second_t TimeLeftUntil(Distance_t target) const { + Distance_t position = m_current.position * m_direction; + Velocity_t velocity = m_current.velocity * m_direction; + + units::second_t endAccel = m_endAccel * m_direction; + units::second_t endFullSpeed = m_endFullSpeed * m_direction - endAccel; + + if (target < position) { + endAccel *= -1.0; + endFullSpeed *= -1.0; + velocity *= -1.0; + } + + endAccel = units::math::max(endAccel, 0_s); + endFullSpeed = units::math::max(endFullSpeed, 0_s); + + const Acceleration_t acceleration = m_constraints.maxAcceleration; + const Acceleration_t deceleration = -m_constraints.maxAcceleration; + + Distance_t distToTarget = units::math::abs(target - position); + + if (distToTarget < Distance_t{1e-6}) { + return 0_s; + } + + Distance_t accelDist = + velocity * endAccel + 0.5 * acceleration * endAccel * endAccel; + + Velocity_t decelVelocity; + if (endAccel > 0_s) { + decelVelocity = units::math::sqrt( + units::math::abs(velocity * velocity + 2 * acceleration * accelDist)); + } else { + decelVelocity = velocity; + } + + Distance_t fullSpeedDist = m_constraints.maxVelocity * endFullSpeed; + Distance_t decelDist; + + if (accelDist > distToTarget) { + accelDist = distToTarget; + fullSpeedDist = Distance_t{0}; + decelDist = Distance_t{0}; + } else if (accelDist + fullSpeedDist > distToTarget) { + fullSpeedDist = distToTarget - accelDist; + decelDist = Distance_t{0}; + } else { + decelDist = distToTarget - fullSpeedDist - accelDist; + } + + units::second_t accelTime = + (-velocity + units::math::sqrt(units::math::abs( + velocity * velocity + 2 * acceleration * accelDist))) / + acceleration; + + units::second_t decelTime = + (-decelVelocity + + units::math::sqrt(units::math::abs(decelVelocity * decelVelocity + + 2 * deceleration * decelDist))) / + deceleration; + + units::second_t fullSpeedTime = fullSpeedDist / m_constraints.maxVelocity; + + return accelTime + fullSpeedTime + decelTime; + } /** * Returns the total time the profile takes to reach the goal. @@ -176,6 +310,5 @@ class TrapezoidProfile { units::second_t m_endFullSpeed; units::second_t m_endDecel; }; -} // namespace frc -#include "TrapezoidProfile.inc" +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc deleted file mode 100644 index 4b1eff72b69..00000000000 --- a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include - -#include "frc/trajectory/TrapezoidProfile.h" -#include "units/math.h" - -namespace frc { -template -TrapezoidProfile::TrapezoidProfile(Constraints constraints) - : m_constraints(constraints) {} - -template -typename TrapezoidProfile::State -TrapezoidProfile::Calculate(units::second_t t, State current, - State goal) { - m_direction = ShouldFlipAcceleration(current, goal) ? -1 : 1; - m_current = Direct(current); - goal = Direct(goal); - if (m_current.velocity > m_constraints.maxVelocity) { - m_current.velocity = m_constraints.maxVelocity; - } - - // Deal with a possibly truncated motion profile (with nonzero initial or - // final velocity) by calculating the parameters as if the profile began and - // ended at zero velocity - units::second_t cutoffBegin = - m_current.velocity / m_constraints.maxAcceleration; - Distance_t cutoffDistBegin = - cutoffBegin * cutoffBegin * m_constraints.maxAcceleration / 2.0; - - units::second_t cutoffEnd = goal.velocity / m_constraints.maxAcceleration; - Distance_t cutoffDistEnd = - cutoffEnd * cutoffEnd * m_constraints.maxAcceleration / 2.0; - - // Now we can calculate the parameters as if it was a full trapezoid instead - // of a truncated one - - Distance_t fullTrapezoidDist = - cutoffDistBegin + (goal.position - m_current.position) + cutoffDistEnd; - units::second_t accelerationTime = - m_constraints.maxVelocity / m_constraints.maxAcceleration; - - Distance_t fullSpeedDist = - fullTrapezoidDist - - accelerationTime * accelerationTime * m_constraints.maxAcceleration; - - // Handle the case where the profile never reaches full speed - if (fullSpeedDist < Distance_t{0}) { - accelerationTime = - units::math::sqrt(fullTrapezoidDist / m_constraints.maxAcceleration); - fullSpeedDist = Distance_t{0}; - } - - m_endAccel = accelerationTime - cutoffBegin; - m_endFullSpeed = m_endAccel + fullSpeedDist / m_constraints.maxVelocity; - m_endDecel = m_endFullSpeed + accelerationTime - cutoffEnd; - State result = m_current; - - if (t < m_endAccel) { - result.velocity += t * m_constraints.maxAcceleration; - result.position += - (m_current.velocity + t * m_constraints.maxAcceleration / 2.0) * t; - } else if (t < m_endFullSpeed) { - result.velocity = m_constraints.maxVelocity; - result.position += (m_current.velocity + - m_endAccel * m_constraints.maxAcceleration / 2.0) * - m_endAccel + - m_constraints.maxVelocity * (t - m_endAccel); - } else if (t <= m_endDecel) { - result.velocity = - goal.velocity + (m_endDecel - t) * m_constraints.maxAcceleration; - units::second_t timeLeft = m_endDecel - t; - result.position = - goal.position - - (goal.velocity + timeLeft * m_constraints.maxAcceleration / 2.0) * - timeLeft; - } else { - result = goal; - } - - return Direct(result); -} - -template -units::second_t TrapezoidProfile::TimeLeftUntil( - Distance_t target) const { - Distance_t position = m_current.position * m_direction; - Velocity_t velocity = m_current.velocity * m_direction; - - units::second_t endAccel = m_endAccel * m_direction; - units::second_t endFullSpeed = m_endFullSpeed * m_direction - endAccel; - - if (target < position) { - endAccel *= -1.0; - endFullSpeed *= -1.0; - velocity *= -1.0; - } - - endAccel = units::math::max(endAccel, 0_s); - endFullSpeed = units::math::max(endFullSpeed, 0_s); - - const Acceleration_t acceleration = m_constraints.maxAcceleration; - const Acceleration_t deceleration = -m_constraints.maxAcceleration; - - Distance_t distToTarget = units::math::abs(target - position); - - if (distToTarget < Distance_t{1e-6}) { - return 0_s; - } - - Distance_t accelDist = - velocity * endAccel + 0.5 * acceleration * endAccel * endAccel; - - Velocity_t decelVelocity; - if (endAccel > 0_s) { - decelVelocity = units::math::sqrt( - units::math::abs(velocity * velocity + 2 * acceleration * accelDist)); - } else { - decelVelocity = velocity; - } - - Distance_t fullSpeedDist = m_constraints.maxVelocity * endFullSpeed; - Distance_t decelDist; - - if (accelDist > distToTarget) { - accelDist = distToTarget; - fullSpeedDist = Distance_t{0}; - decelDist = Distance_t{0}; - } else if (accelDist + fullSpeedDist > distToTarget) { - fullSpeedDist = distToTarget - accelDist; - decelDist = Distance_t{0}; - } else { - decelDist = distToTarget - fullSpeedDist - accelDist; - } - - units::second_t accelTime = - (-velocity + units::math::sqrt(units::math::abs( - velocity * velocity + 2 * acceleration * accelDist))) / - acceleration; - - units::second_t decelTime = - (-decelVelocity + - units::math::sqrt(units::math::abs(decelVelocity * decelVelocity + - 2 * deceleration * decelDist))) / - deceleration; - - units::second_t fullSpeedTime = fullSpeedDist / m_constraints.maxVelocity; - - return accelTime + fullSpeedTime + decelTime; -} -} // namespace frc diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h index d26497735d9..ac05bcfeb2f 100644 --- a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h +++ b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h @@ -8,6 +8,7 @@ #include "frc/kinematics/SwerveDriveKinematics.h" #include "frc/trajectory/constraint/TrajectoryConstraint.h" +#include "units/math.h" #include "units/velocity.h" namespace frc { @@ -22,19 +23,31 @@ class SwerveDriveKinematicsConstraint : public TrajectoryConstraint { public: SwerveDriveKinematicsConstraint( const frc::SwerveDriveKinematics& kinematics, - units::meters_per_second_t maxSpeed); + units::meters_per_second_t maxSpeed) + : m_kinematics(kinematics), m_maxSpeed(maxSpeed) {} units::meters_per_second_t MaxVelocity( const Pose2d& pose, units::curvature_t curvature, - units::meters_per_second_t velocity) const override; + units::meters_per_second_t velocity) const override { + auto xVelocity = velocity * pose.Rotation().Cos(); + auto yVelocity = velocity * pose.Rotation().Sin(); + auto wheelSpeeds = m_kinematics.ToSwerveModuleStates( + {xVelocity, yVelocity, velocity * curvature}); + m_kinematics.DesaturateWheelSpeeds(&wheelSpeeds, m_maxSpeed); + + auto normSpeeds = m_kinematics.ToChassisSpeeds(wheelSpeeds); + + return units::math::hypot(normSpeeds.vx, normSpeeds.vy); + } MinMax MinMaxAcceleration(const Pose2d& pose, units::curvature_t curvature, - units::meters_per_second_t speed) const override; + units::meters_per_second_t speed) const override { + return {}; + } private: frc::SwerveDriveKinematics m_kinematics; units::meters_per_second_t m_maxSpeed; }; -} // namespace frc -#include "SwerveDriveKinematicsConstraint.inc" +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc deleted file mode 100644 index 5c4e5ace9d5..00000000000 --- a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#pragma once - -#include "frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h" -#include "units/math.h" - -namespace frc { - -template -SwerveDriveKinematicsConstraint::SwerveDriveKinematicsConstraint( - const frc::SwerveDriveKinematics& kinematics, - units::meters_per_second_t maxSpeed) - : m_kinematics(kinematics), m_maxSpeed(maxSpeed) {} - -template -units::meters_per_second_t -SwerveDriveKinematicsConstraint::MaxVelocity( - const Pose2d& pose, units::curvature_t curvature, - units::meters_per_second_t velocity) const { - auto xVelocity = velocity * pose.Rotation().Cos(); - auto yVelocity = velocity * pose.Rotation().Sin(); - auto wheelSpeeds = m_kinematics.ToSwerveModuleStates( - {xVelocity, yVelocity, velocity * curvature}); - m_kinematics.DesaturateWheelSpeeds(&wheelSpeeds, m_maxSpeed); - - auto normSpeeds = m_kinematics.ToChassisSpeeds(wheelSpeeds); - - return units::math::hypot(normSpeeds.vx, normSpeeds.vy); -} - -template -TrajectoryConstraint::MinMax -SwerveDriveKinematicsConstraint::MinMaxAcceleration( - const Pose2d& pose, units::curvature_t curvature, - units::meters_per_second_t speed) const { - return {}; -} - -} // namespace frc