Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
Signed-off-by: yzamir <kobi.zamir@gmail.com>
  • Loading branch information
yaacov committed Jan 25, 2024
1 parent 35b88e6 commit 76d964f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 34 deletions.
Binary file modified checkpoints/driver.pth
Binary file not shown.
16 changes: 7 additions & 9 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class DriverModel(nn.Module):
- Hidden Layer 3: 128 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 4: 64 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 5: 32 neurons, followed by batch normalization and 50% dropout.
- Output Layer: 7 neurons, representing the possible driving actions (e.g., obstacles.ALL).
- Output Layer: 6 neurons, representing the possible driving actions (e.g., obstacles.ALL).
Activation Function:
- ReLU activation function is used for all hidden layers.
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self):
self.bn5 = nn.BatchNorm1d(32)
self.dropout5 = nn.Dropout(0.5)

self.fc6 = nn.Linear(32, 7)
self.fc6 = nn.Linear(32, 6)

def forward(self, x):
x = F.relu(self.bn1(self.fc1(x)))
Expand Down Expand Up @@ -136,21 +136,19 @@ def view_to_inputs(array, car_lane):
Notes:
The function uses a predefined mapping of obstacles to indices for the one-hot encoding.
"""
OBSTACLE_TO_INDEX = {
obstacle: index for index, obstacle in enumerate(obstacles.ALL)
}

height = len(array)
width = len(array[0])
tensor = torch.zeros((height, width, 7))
num_obstacles = len(obstacles.ALL)

tensor = torch.zeros((height, width, num_obstacles))

for i in range(height):
for j in range(width):
obstacle = array[i][j]
tensor[i, j, OBSTACLE_TO_INDEX[obstacle]] = 1
tensor[i, j, obstacles.ALL.index(obstacle)] = 1

world_tensor = tensor.view(-1)
car_lane_tensor = torch.zeros(6)
car_lane_tensor = torch.zeros(width)
car_lane_tensor[car_lane] = 1

return torch.cat((world_tensor, car_lane_tensor))
Expand Down
49 changes: 24 additions & 25 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,48 +44,46 @@
model = DriverModel()


def generate_obstacle_array():
def generate_obstacle_array(width=6, height=4):
"""
Generates a 4x6 2D array with random obstacles.
Generates a 2D array with random obstacles.
Parameters:
width (int): The width of the 2D array. Default is 6.
height (int): The height of the 2D array. Default is 4.
Returns:
list[list[str]]: 4x6 2D array with random obstacles.
list[list[str]]: 2D array with random obstacles.
"""
OBSTACLE_TO_INDEX = {
"": 0,
"crack": 1,
"trash": 2,
"penguin": 3,
"bike": 4,
"water": 5,
"barrier": 6,
}
OBSTACLES = ["", "crack", "trash", "penguin", "bike", "water", "barrier"]

array = [["" for _ in range(6)] for _ in range(4)]
array = [["" for _ in range(width)] for _ in range(height)]

for i in range(4):
obstacle = random.choice(list(OBSTACLE_TO_INDEX.keys()))
position = random.randint(0, 2)
for i in range(height):
obstacle = random.choice(OBSTACLES)
position = random.randint(0, width // 2 - 1)
# lane A
array[i][position] = obstacle
# lane B
array[i][3 + position] = obstacle
array[i][width // 2 + position] = obstacle

return array


def driver_simulator(array, car_x):
def driver_simulator(array, car_x, width=6, height=4):
"""
Simulates the driver's decision based on the obstacle in front of the car.
Args:
array (list[list[str]]): 2D array representation of the world with obstacles as strings.
car_x (int): The car's x position.
width (int): The width of the 2D array. Default is 6.
height (int): The height of the 2D array. Default is 4.
Returns:
str: The determined action for the car to take. Possible actions include those defined in the `actions` class.
"""
obstacle = array[3][car_x]
obstacle = array[height - 1][car_x]

# Define a dictionary to map obstacles to actions
action_map = {
Expand All @@ -100,7 +98,7 @@ def driver_simulator(array, car_x):

# If the obstacle is not in the dictionary, determine the action based on the car's x position
if action is None:
action = actions.RIGHT if (car_x % 3) == 0 else actions.LEFT
action = actions.RIGHT if (car_x % (width // 2)) == 0 else actions.LEFT

return action

Expand All @@ -109,17 +107,16 @@ def action_to_outputs(action):
"""
Converts an action into a target tensor.
This function takes an action (LEFT, RIGHT, or other) and converts it into a target tensor with three elements.
The tensor's elements correspond to the actions LEFT, forward, and RIGHT respectively. The element corresponding
to the given action is set to 1, and the others are set to 0.
This function takes an action (LEFT, RIGHT, or other) and converts it into a target tensor with elements corresponding to the actions.
The element corresponding to the given action is set to 1, and the others are set to 0.
Args:
action (str): The action to convert. Should be one of the actions defined in the `actions` class.
Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
torch.Tensor: A tensor of shape (len(actions.ALL),) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(7)
target = torch.zeros(len(actions.ALL))

try:
action_index = actions.ALL.index(action)
Expand All @@ -143,6 +140,7 @@ def generate_batch(batch_size):
"""
inputs = []
targets = []

for _ in range(batch_size):
car_x = random.choice([0, 1, 2, 3, 4, 5])
array = generate_obstacle_array()
Expand All @@ -153,6 +151,7 @@ def generate_batch(batch_size):

inputs.append(input_tensor)
targets.append(target_tensor)

return torch.stack(inputs), torch.stack(targets)


Expand Down

0 comments on commit 76d964f

Please sign in to comment.