Skip to content

Commit

Permalink
Add interruptor parameter for onCommandInterrupt callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
KangarooKoala committed Sep 6, 2023
1 parent 1a6df6f commit a660039
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

/**
Expand All @@ -42,6 +44,16 @@
* <p>This class is provided by the NewCommands VendorDep
*/
public final class CommandScheduler implements Sendable, AutoCloseable {
private static class CancelData {
public final Command m_command;
public final Optional<Command> m_interruptor;

CancelData(Command command, Optional<Command> interruptor) {
m_command = command;
m_interruptor = interruptor;
}
}

/** The Singleton Instance. */
private static CommandScheduler instance;

Expand Down Expand Up @@ -79,14 +91,14 @@ public static synchronized CommandScheduler getInstance() {
// Lists of user-supplied actions to be executed on scheduling events for every command.
private final List<Consumer<Command>> m_initActions = new ArrayList<>();
private final List<Consumer<Command>> m_executeActions = new ArrayList<>();
private final List<Consumer<Command>> m_interruptActions = new ArrayList<>();
private final List<BiConsumer<Command, Optional<Command>>> m_interruptActions = new ArrayList<>();
private final List<Consumer<Command>> m_finishActions = new ArrayList<>();

// Flag and queues for avoiding ConcurrentModificationException if commands are
// scheduled/canceled during run
private boolean m_inRunLoop;
private final Set<Command> m_toSchedule = new LinkedHashSet<>();
private final List<Command> m_toCancel = new ArrayList<>();
private final List<CancelData> m_toCancel = new ArrayList<>();

private final Watchdog m_watchdog = new Watchdog(TimedRobot.kDefaultPeriod, () -> {});

Expand Down Expand Up @@ -211,7 +223,7 @@ private void schedule(Command command) {
for (Subsystem requirement : requirements) {
Command requiring = requiring(requirement);
if (requiring != null) {
cancel(requiring);
cancel(requiring, Optional.of(command));
}
}
initCommand(command, requirements);
Expand Down Expand Up @@ -272,8 +284,8 @@ public void run() {

if (!command.runsWhenDisabled() && RobotState.isDisabled()) {
command.end(true);
for (Consumer<Command> action : m_interruptActions) {
action.accept(command);
for (BiConsumer<Command, Optional<Command>> action : m_interruptActions) {
action.accept(command, Optional.empty());
}
m_requirements.keySet().removeAll(command.getRequirements());
iterator.remove();
Expand Down Expand Up @@ -304,8 +316,8 @@ public void run() {
schedule(command);
}

for (Command command : m_toCancel) {
cancel(command);
for (CancelData cancelData : m_toCancel) {
cancel(cancelData.m_command, cancelData.m_interruptor);
}

m_toSchedule.clear();
Expand Down Expand Up @@ -440,28 +452,41 @@ public Command getDefaultCommand(Subsystem subsystem) {
* @param commands the commands to cancel
*/
public void cancel(Command... commands) {
for (Command command : commands) {
cancel(command, Optional.empty());
}
}

/**
* Cancels a command. The scheduler will only call {@link Command#end(boolean)} method of the
* canceled command with {@code true}, indicating they were canceled (as opposed to finishing
* normally).
*
* <p>Commands will be canceled regardless of {@link InterruptionBehavior interruption behavior}.
*
* @param command the command to cancel
* @param interruptor the interrupting command, if any
*/
private void cancel(Command command, Optional<Command> interruptor) {
if (command == null) {
DriverStation.reportWarning("Tried to cancel a null command", true);
return;
}
if (m_inRunLoop) {
m_toCancel.addAll(List.of(commands));
m_toCancel.add(new CancelData(command, interruptor));
return;
}
if (!isScheduled(command)) {
return;
}

for (Command command : commands) {
if (command == null) {
DriverStation.reportWarning("Tried to cancel a null command", true);
continue;
}
if (!isScheduled(command)) {
continue;
}

m_scheduledCommands.remove(command);
m_requirements.keySet().removeAll(command.getRequirements());
command.end(true);
for (Consumer<Command> action : m_interruptActions) {
action.accept(command);
}
m_watchdog.addEpoch(command.getName() + ".end(true)");
m_scheduledCommands.remove(command);
m_requirements.keySet().removeAll(command.getRequirements());
command.end(true);
for (BiConsumer<Command, Optional<Command>> action : m_interruptActions) {
action.accept(command, interruptor);
}
m_watchdog.addEpoch(command.getName() + ".end(true)");
}

/** Cancels all commands that are currently scheduled. */
Expand Down Expand Up @@ -528,6 +553,19 @@ public void onCommandExecute(Consumer<Command> action) {
* @param action the action to perform
*/
public void onCommandInterrupt(Consumer<Command> action) {
requireNonNullParam(action, "action", "onCommandInterrupt");
m_interruptActions.add((command, interruptor) -> action.accept(command));
}

/**
* Adds an action to perform on the interruption of any command by the scheduler. The action
* receives the interrupted command and an Optional containing the interrupting command, or
* Optional.empty() if it was not canceled by a command (e.g., by {@link
* CommandScheduler#cancel}).
*
* @param action the action to perform
*/
public void onCommandInterrupt(BiConsumer<Command, Optional<Command>> action) {
m_interruptActions.add(requireNonNullParam(action, "action", "onCommandInterrupt"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class CommandScheduler::Impl {
// every command.
wpi::SmallVector<Action, 4> initActions;
wpi::SmallVector<Action, 4> executeActions;
wpi::SmallVector<Action, 4> interruptActions;
wpi::SmallVector<InterruptAction, 4> interruptActions;
wpi::SmallVector<Action, 4> finishActions;

// Flag and queues for avoiding concurrent modification if commands are
// scheduled/canceled during run

bool inRunLoop = false;
wpi::SmallVector<Command*, 4> toSchedule;
wpi::SmallVector<Command*, 4> toCancel;
wpi::SmallVector<std::pair<Command*, std::optional<Command*>>, 4> toCancel;
};

template <typename TMap, typename TKey>
Expand Down Expand Up @@ -138,7 +138,7 @@ void CommandScheduler::Schedule(Command* command) {
if (isDisjoint || allInterruptible) {
if (allInterruptible) {
for (auto&& cmdToCancel : intersection) {
Cancel(cmdToCancel);
Cancel(cmdToCancel, std::make_optional(command));
}
}
m_impl->scheduledCommands.insert(command);
Expand Down Expand Up @@ -196,7 +196,7 @@ void CommandScheduler::Run() {
// Run scheduled commands, remove finished commands.
for (Command* command : m_impl->scheduledCommands) {
if (!command->RunsWhenDisabled() && frc::RobotState::IsDisabled()) {
Cancel(command);
Cancel(command, std::nullopt);
continue;
}

Expand Down Expand Up @@ -226,8 +226,8 @@ void CommandScheduler::Run() {
Schedule(command);
}

for (auto&& command : m_impl->toCancel) {
Cancel(command);
for (auto&& cancelData : m_impl->toCancel) {
Cancel(cancelData.first, cancelData.second);
}

m_impl->toSchedule.clear();
Expand Down Expand Up @@ -319,13 +319,14 @@ Command* CommandScheduler::GetDefaultCommand(const Subsystem* subsystem) const {
}
}

void CommandScheduler::Cancel(Command* command) {
void CommandScheduler::Cancel(Command* command,
std::optional<Command*> interruptor) {
if (!m_impl) {
return;
}

if (m_impl->inRunLoop) {
m_impl->toCancel.emplace_back(command);
m_impl->toCancel.emplace_back(command, interruptor);
return;
}

Expand All @@ -341,11 +342,15 @@ void CommandScheduler::Cancel(Command* command) {
}
command->End(true);
for (auto&& action : m_impl->interruptActions) {
action(*command);
action(*command, interruptor);
}
m_watchdog.AddEpoch(command->GetName() + ".End(true)");
}

void CommandScheduler::Cancel(Command* command) {
Cancel(command, std::nullopt);
}

void CommandScheduler::Cancel(const CommandPtr& command) {
Cancel(command.get());
}
Expand Down Expand Up @@ -424,6 +429,14 @@ void CommandScheduler::OnCommandExecute(Action action) {
}

void CommandScheduler::OnCommandInterrupt(Action action) {
m_impl->interruptActions.emplace_back(
[action = std::move(action)](const Command& command,
const std::optional<Command*>& interruptor) {
action(command);
});
}

void CommandScheduler::OnCommandInterrupt(InterruptAction action) {
m_impl->interruptActions.emplace_back(std::move(action));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <functional>
#include <initializer_list>
#include <memory>
#include <optional>
#include <span>
#include <utility>

Expand Down Expand Up @@ -48,6 +49,8 @@ class CommandScheduler final : public wpi::Sendable,
CommandScheduler& operator=(const CommandScheduler&) = delete;

using Action = std::function<void(const Command&)>;
using InterruptAction =
std::function<void(const Command&, const std::optional<Command*>&)>;

/**
* Changes the period of the loop overrun watchdog. This should be kept in
Expand Down Expand Up @@ -353,6 +356,16 @@ class CommandScheduler final : public wpi::Sendable,
*/
void OnCommandInterrupt(Action action);

/**
* Adds an action to perform on the interruption of any command by the
* scheduler. The action receives the interrupted command and an optional
* containing the interrupting command, or nullopt if it was not canceled by a
* command (e.g., by Cancel()).
*
* @param action the action to perform
*/
void OnCommandInterrupt(InterruptAction action);

/**
* Adds an action to perform on the finishing of any command by the scheduler.
*
Expand Down Expand Up @@ -397,6 +410,8 @@ class CommandScheduler final : public wpi::Sendable,
void SetDefaultCommandImpl(Subsystem* subsystem,
std::unique_ptr<Command> command);

void Cancel(Command* command, std::optional<Command*> interruptor);

class Impl;
std::unique_ptr<Impl> m_impl;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -43,6 +46,76 @@ void schedulerInterruptLambdaTest() {
}
}

@Test
void schedulerInterruptNoCauseLambdaTest() {
try (CommandScheduler scheduler = new CommandScheduler()) {
AtomicInteger counter = new AtomicInteger();

scheduler.onCommandInterrupt(
(interrupted, cause) -> {
assertFalse(cause.isPresent());
counter.incrementAndGet();
});

Command command = Commands.run(() -> {});

scheduler.schedule(command);
scheduler.cancel(command);

assertEquals(1, counter.get());
}
}

@Test
void schedulerInterruptCauseLambdaTest() {
try (CommandScheduler scheduler = new CommandScheduler()) {
AtomicInteger counter = new AtomicInteger();

Subsystem subsystem = new Subsystem() {};
Command command = subsystem.run(() -> {});
Command interruptor = subsystem.runOnce(() -> {});

scheduler.onCommandInterrupt(
(interrupted, cause) -> {
assertTrue(cause.isPresent());
assertSame(interruptor, cause.get());
counter.incrementAndGet();
});

scheduler.schedule(command);
scheduler.schedule(interruptor);

assertEquals(1, counter.get());
}
}

@Test
void schedulerInterruptCauseLambdaInRunLoopTest() {
try (CommandScheduler scheduler = new CommandScheduler()) {
AtomicInteger counter = new AtomicInteger();

Subsystem subsystem = new Subsystem() {};
Command command = subsystem.run(() -> {});
Command interruptor = subsystem.runOnce(() -> {});
// This command will schedule interruptor in execute() inside the run loop
Command interruptorScheduler = Commands.runOnce(() -> scheduler.schedule(interruptor));

scheduler.onCommandInterrupt(
(interrupted, cause) -> {
assertTrue(cause.isPresent());
assertSame(interruptor, cause.get());
counter.incrementAndGet();
});

scheduler.schedule(command);
scheduler.schedule(interruptorScheduler);

scheduler.run();

assertEquals(1, counter.get());
}
}

@Test
void registerSubsystemTest() {
try (CommandScheduler scheduler = new CommandScheduler()) {
Expand Down Expand Up @@ -87,6 +160,7 @@ void schedulerCancelAllTest() {
AtomicInteger counter = new AtomicInteger();

scheduler.onCommandInterrupt(command -> counter.incrementAndGet());
scheduler.onCommandInterrupt((command, interruptor) -> assertFalse(interruptor.isPresent()));

Command command = new WaitCommand(10);
Command command2 = new WaitCommand(10);
Expand Down
Loading

0 comments on commit a660039

Please sign in to comment.