Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify model #4

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified checkpoints/driver.pth
Binary file not shown.
43 changes: 14 additions & 29 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class DriverModel(nn.Module):
- Hidden Layer 3: 128 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 4: 64 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 5: 32 neurons, followed by batch normalization and 50% dropout.
- Output Layer: 3 neurons, representing the possible driving decisions (e.g., left, forward, right).
- Output Layer: 7 neurons, representing the possible driving actions (e.g., obstacles.ALL).

Activation Function:
- ReLU activation function is used for all hidden layers.
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self):
self.bn5 = nn.BatchNorm1d(32)
self.dropout5 = nn.Dropout(0.5)

self.fc6 = nn.Linear(32, 3)
self.fc6 = nn.Linear(32, 7)

def forward(self, x):
x = F.relu(self.bn1(self.fc1(x)))
Expand Down Expand Up @@ -156,40 +156,25 @@ def view_to_inputs(array, car_lane):
return torch.cat((world_tensor, car_lane_tensor))


def outputs_to_action(output, world):
def outputs_to_action(output):
"""
Convert the model's output tensor into a driving action based on the current state of the world.

The function first determines the car's intended position (left, forward, right) based on the model's output.
If the position is forward, the function checks the obstacle in front of the car and maps it to an appropriate action.
The function determines the car's intended action (left, forward, right) based on the model's output.

Args:
output (torch.Tensor): The model's output tensor, typically representing probabilities for each position.
world (World): An instance of the World class providing read-only access to the current game state.

Returns:
str: The determined action for the car to take. Possible actions include those defined in the `actions` class.

Notes:
The function uses a predefined mapping of obstacles to actions to determine the appropriate action when moving forward.
"""
positions = ["left", "forward", "right"]
obstacle_action_map = {
obstacles.PENGUIN: actions.PICKUP,
obstacles.CRACK: actions.JUMP,
obstacles.WATER: actions.BRAKE,
}

obstacle = world.get((world.car.x, world.car.y - 1))
position_index = torch.argmax(output).item()
position = positions[position_index]
action = actions.ALL[position_index]

if position == "left":
return actions.LEFT
elif position == "right":
return actions.RIGHT
else:
return obstacle_action_map.get(obstacle, actions.NONE)
return action


def action_to_outputs(action):
Expand All @@ -206,13 +191,13 @@ def action_to_outputs(action):
Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(3)

if action == actions.LEFT:
target[0] = 1
elif action == actions.RIGHT:
target[2] = 1
else:
target[1] = 1
target = torch.zeros(7)

try:
action_index = actions.ALL.index(action)
except ValueError:
action_index = 0

target[action_index] = 1

return target
2 changes: 1 addition & 1 deletion mydriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,6 @@ def drive(world):
output = model(input_tensor)

# Convert the output tensor into a real world response
action = outputs_to_action(output, world)
action = outputs_to_action(output)

return action
30 changes: 18 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,22 @@ def driver_simulator(array, car_x):
"""
obstacle = array[3][car_x]

if obstacle == obstacles.PENGUIN:
return actions.PICKUP
elif obstacle == obstacles.WATER:
return actions.BRAKE
elif obstacle == obstacles.CRACK:
return actions.JUMP
elif obstacle == obstacles.NONE:
return actions.NONE
else:
return actions.RIGHT if (car_x % 3) == 0 else actions.LEFT
# Define a dictionary to map obstacles to actions
action_map = {
obstacles.PENGUIN: actions.PICKUP,
obstacles.WATER: actions.BRAKE,
obstacles.CRACK: actions.JUMP,
obstacles.NONE: actions.NONE,
}

# Determine the action based on the obstacle
action = action_map.get(obstacle)

# If the obstacle is not in the dictionary, determine the action based on the car's x position
if action is None:
action = actions.RIGHT if (car_x % 3) == 0 else actions.LEFT

return action


def generate_batch(batch_size):
Expand Down Expand Up @@ -166,10 +172,10 @@ def main():
"--checkpoint-out", default="", help="Path to the output checkpoint file."
)
parser.add_argument(
"--num-epochs", type=int, default=10, help="Number of epochs for training."
"--num-epochs", type=int, default=25, help="Number of epochs for training."
)
parser.add_argument(
"--batch-size", type=int, default=200, help="Batch size for training."
"--batch-size", type=int, default=250, help="Batch size for training."
)
parser.add_argument(
"--learning-rate", type=float, default=0.001, help="Learning rate for training."
Expand Down