Skip to content

Commit

Permalink
Merge pull request #7 from yaacov/docstrings-tweaks
Browse files Browse the repository at this point in the history
use 4 by 6 instead of 3 by N
  • Loading branch information
yaacov authored Jan 25, 2024
2 parents bbb9c38 + 0b8d34a commit 27087a4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 41 deletions.
Binary file modified checkpoints/driver.pth
Binary file not shown.
26 changes: 0 additions & 26 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,29 +175,3 @@ def outputs_to_action(output):
action = actions.ALL[position_index]

return action


def action_to_outputs(action):
"""
Converts an action into a target tensor.
This function takes an action (LEFT, RIGHT, or other) and converts it into a target tensor with three elements.
The tensor's elements correspond to the actions LEFT, forward, and RIGHT respectively. The element corresponding
to the given action is set to 1, and the others are set to 0.
Args:
action (str): The action to convert. Should be one of the actions defined in the `actions` class.
Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(7)

try:
action_index = actions.ALL.index(action)
except ValueError:
action_index = 0

target[action_index] = 1

return target
20 changes: 6 additions & 14 deletions mydriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,28 @@

def build_lane_view(world):
"""
Build a 3xN 2D array representation of the world based on the car's lane and x position.
Build a 6x4 2D array representation of the world based on the car's lane and x position.
Args:
world (World): An instance of the World class providing read-only
access to the current game state.
Returns:
list[list[str]]: 3xN array representation of the world view from the car, where N is the specified height.
list[list[str]]: 6x4 array representation of the world view from the car, where 4 is the specified height.
The bottom line is one line above the car's y position, and the top line is the line height lines above that.
The array provides a view of the world from the car's perspective, with the car's y position excluded.
Notes:
The function uses the car's y position to determine the vertical range of the 2D array.
The starting x-coordinate is determined by the car's lane. If the lane is 0, the starting x is 0. If the lane is 1, the starting x is 3.
The function also provides a wrapper around world.get to handle negative y values, returning an empty string for such cases.
"""
height = 4
width = 6
car_y = world.car.y

# Calculate the starting y-coordinate based on the car's y position and the desired height
start_y = car_y - height

# Wrapper around world.get to handle negative y values
def get_value(j, i):
if i < 0:
return ""
return world.get((j, i))

# Generate the 2D array from start_y up to world.car.y
array = [[get_value(j, i) for j in range(6)] for i in range(start_y, car_y)]
array = [
[world.get((j, i)) for j in range(width)] for i in range(car_y - height, car_y)
]

return array

Expand Down
28 changes: 27 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import torch.nn as nn
import torch.optim as optim

from model import DriverModel, action_to_outputs, actions, obstacles, view_to_inputs
from model import DriverModel, actions, obstacles, view_to_inputs

# Training parameters
num_epochs = 0
Expand Down Expand Up @@ -105,6 +105,32 @@ def driver_simulator(array, car_x):
return action


def action_to_outputs(action):
"""
Converts an action into a target tensor.
This function takes an action (LEFT, RIGHT, or other) and converts it into a target tensor with three elements.
The tensor's elements correspond to the actions LEFT, forward, and RIGHT respectively. The element corresponding
to the given action is set to 1, and the others are set to 0.
Args:
action (str): The action to convert. Should be one of the actions defined in the `actions` class.
Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(7)

try:
action_index = actions.ALL.index(action)
except ValueError:
action_index = 0

target[action_index] = 1

return target


def generate_batch(batch_size):
"""
Generates a batch of samples for training.
Expand Down

0 comments on commit 27087a4

Please sign in to comment.