Skip to content

Commit

Permalink
Merge pull request #4 from yaacov/simplify-model
Browse files Browse the repository at this point in the history
Simplify model
  • Loading branch information
yaacov authored Jan 24, 2024
2 parents e25a02a + 59ed824 commit 210730b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 42 deletions.
Binary file modified checkpoints/driver.pth
Binary file not shown.
43 changes: 14 additions & 29 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class DriverModel(nn.Module):
- Hidden Layer 3: 128 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 4: 64 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 5: 32 neurons, followed by batch normalization and 50% dropout.
- Output Layer: 3 neurons, representing the possible driving decisions (e.g., left, forward, right).
- Output Layer: 7 neurons, representing the possible driving actions (e.g., obstacles.ALL).
Activation Function:
- ReLU activation function is used for all hidden layers.
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self):
self.bn5 = nn.BatchNorm1d(32)
self.dropout5 = nn.Dropout(0.5)

self.fc6 = nn.Linear(32, 3)
self.fc6 = nn.Linear(32, 7)

def forward(self, x):
x = F.relu(self.bn1(self.fc1(x)))
Expand Down Expand Up @@ -156,40 +156,25 @@ def view_to_inputs(array, car_lane):
return torch.cat((world_tensor, car_lane_tensor))


def outputs_to_action(output, world):
def outputs_to_action(output):
"""
Convert the model's output tensor into a driving action based on the current state of the world.
The function first determines the car's intended position (left, forward, right) based on the model's output.
If the position is forward, the function checks the obstacle in front of the car and maps it to an appropriate action.
The function determines the car's intended action (left, forward, right) based on the model's output.
Args:
output (torch.Tensor): The model's output tensor, typically representing probabilities for each position.
world (World): An instance of the World class providing read-only access to the current game state.
Returns:
str: The determined action for the car to take. Possible actions include those defined in the `actions` class.
Notes:
The function uses a predefined mapping of obstacles to actions to determine the appropriate action when moving forward.
"""
positions = ["left", "forward", "right"]
obstacle_action_map = {
obstacles.PENGUIN: actions.PICKUP,
obstacles.CRACK: actions.JUMP,
obstacles.WATER: actions.BRAKE,
}

obstacle = world.get((world.car.x, world.car.y - 1))
position_index = torch.argmax(output).item()
position = positions[position_index]
action = actions.ALL[position_index]

if position == "left":
return actions.LEFT
elif position == "right":
return actions.RIGHT
else:
return obstacle_action_map.get(obstacle, actions.NONE)
return action


def action_to_outputs(action):
Expand All @@ -206,13 +191,13 @@ def action_to_outputs(action):
Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(3)

if action == actions.LEFT:
target[0] = 1
elif action == actions.RIGHT:
target[2] = 1
else:
target[1] = 1
target = torch.zeros(7)

try:
action_index = actions.ALL.index(action)
except ValueError:
action_index = 0

target[action_index] = 1

return target
2 changes: 1 addition & 1 deletion mydriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,6 @@ def drive(world):
output = model(input_tensor)

# Convert the output tensor into a real world response
action = outputs_to_action(output, world)
action = outputs_to_action(output)

return action
30 changes: 18 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,22 @@ def driver_simulator(array, car_x):
"""
obstacle = array[3][car_x]

if obstacle == obstacles.PENGUIN:
return actions.PICKUP
elif obstacle == obstacles.WATER:
return actions.BRAKE
elif obstacle == obstacles.CRACK:
return actions.JUMP
elif obstacle == obstacles.NONE:
return actions.NONE
else:
return actions.RIGHT if (car_x % 3) == 0 else actions.LEFT
# Define a dictionary to map obstacles to actions
action_map = {
obstacles.PENGUIN: actions.PICKUP,
obstacles.WATER: actions.BRAKE,
obstacles.CRACK: actions.JUMP,
obstacles.NONE: actions.NONE,
}

# Determine the action based on the obstacle
action = action_map.get(obstacle)

# If the obstacle is not in the dictionary, determine the action based on the car's x position
if action is None:
action = actions.RIGHT if (car_x % 3) == 0 else actions.LEFT

return action


def generate_batch(batch_size):
Expand Down Expand Up @@ -166,10 +172,10 @@ def main():
"--checkpoint-out", default="", help="Path to the output checkpoint file."
)
parser.add_argument(
"--num-epochs", type=int, default=10, help="Number of epochs for training."
"--num-epochs", type=int, default=25, help="Number of epochs for training."
)
parser.add_argument(
"--batch-size", type=int, default=200, help="Batch size for training."
"--batch-size", type=int, default=250, help="Batch size for training."
)
parser.add_argument(
"--learning-rate", type=float, default=0.001, help="Learning rate for training."
Expand Down

0 comments on commit 210730b

Please sign in to comment.