From e4da940de41defa8f87276357637f4e8531ab93a Mon Sep 17 00:00:00 2001 From: Yaacov Zamir Date: Tue, 9 Jan 2024 17:12:15 +0200 Subject: [PATCH 1/2] rename files Signed-off-by: Yaacov Zamir --- .coveragerc | 2 + .flake8 | 8 + .github/workflows/ci.yml | 31 ++++ .gitignore | 170 ++---------------- Dockerfile | 31 +++- Makefile | 46 ++++- conftest.py | 4 + model.py | 3 +- driver.py => mydriver.py | 6 +- pytest.ini | 2 + requirements-dev.txt | 10 ++ ...driver.yaml => rose-game-ai-reference.yaml | 16 +- {client => rose}/__init__.py | 0 {client/game => rose/ai}/__init__.py | 0 {client/game => rose/ai}/car.py | 0 {client/game => rose/ai}/server.py | 2 +- rose/ai/test_car.py | 21 +++ rose/ai/test_server.py | 57 ++++++ rose/ai/test_track.py | 46 +++++ {client/game => rose/ai}/track.py | 0 {client/game => rose/ai}/world.py | 4 +- rose/common/__init__.py | 0 {client/game => rose/common}/actions.py | 0 {client/game => rose/common}/obstacles.py | 0 rose/common/test_actions.py | 14 ++ rose/common/test_obstacles.py | 31 ++++ {client => rose}/main.py | 24 ++- train.py | 60 ++++--- 28 files changed, 369 insertions(+), 219 deletions(-) create mode 100644 .coveragerc create mode 100644 .flake8 create mode 100644 .github/workflows/ci.yml create mode 100644 conftest.py rename driver.py => mydriver.py (98%) create mode 100644 pytest.ini create mode 100644 requirements-dev.txt rename rose-ml-driver.yaml => rose-game-ai-reference.yaml (56%) rename {client => rose}/__init__.py (100%) rename {client/game => rose/ai}/__init__.py (100%) rename {client/game => rose/ai}/car.py (100%) rename {client/game => rose/ai}/server.py (99%) create mode 100644 rose/ai/test_car.py create mode 100644 rose/ai/test_server.py create mode 100644 rose/ai/test_track.py rename {client/game => rose/ai}/track.py (100%) rename {client/game => rose/ai}/world.py (94%) create mode 100644 rose/common/__init__.py rename {client/game => rose/common}/actions.py (100%) rename {client/game => rose/common}/obstacles.py (100%) create mode 100644 rose/common/test_actions.py create mode 100644 rose/common/test_obstacles.py rename {client => rose}/main.py (80%) diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..75ee292 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,2 @@ +[run] +omit = test_*.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..46b6524 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +show_source = True +statistics = True + +# E501: line to long. +# E203: whitespace before ':' to accept black code style +# W503: line break before binary operator +ignore = E501,E203,W503 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ca3c1ba --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,31 @@ +name: CI +on: +- push +- pull_request +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: | + 3.9 + 3.12 + + - name: Install dependencies + run: pip install --no-cache-dir -r requirements.txt + + - name: Install dev dependencies + run: pip install --no-cache-dir -r requirements-dev.txt + + - name: Lint + run: make lint + + - name: Run tests + run: make test + + - name: Run code quality test + run: make code-quality diff --git a/.gitignore b/.gitignore index 68bc17f..3a76c10 100644 --- a/.gitignore +++ b/.gitignore @@ -1,160 +1,20 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class +# Python +__pycache__ +*.pyc +.pytest_cache -# C extensions -*.so +# Editors +.idea +.vscode -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ +# Coverage .coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ +htmlcov/ -# Cython debug symbols -cython_debug/ +rose_project.egg-info +dist -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +# Classroom setup +classroom/credentials.json +classroom/*.csv +classroom/token.pickle diff --git a/Dockerfile b/Dockerfile index 38ffbdf..0510cee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,8 +6,37 @@ COPY . /ml # Install the Python dependencies RUN pip install --upgrade pip -RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ENTRYPOINT ["python", "main.py", "--listen", "0.0.0.0", "--driver", "/ml/driver.py"] CMD ["--port", "8081"] +# --- Build Image --- +FROM registry.access.redhat.com/ubi9/python-39 AS build + +WORKDIR /build + +# Copy only the requirements file and install the Python dependencies +COPY requirements.txt . +RUN pip install --upgrade pip +RUN pip install --no-cache-dir -r requirements.txt +RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + +# --- Runtime Image --- +FROM build + +WORKDIR /app +COPY . /app + +# Add the rose client package to the python path +ENV PYTHONPATH "${PYTHONPATH}:/app" + +# Default values for environment variables +ENV DRIVER ./mydriver.py +ENV PORT 8081 + +# Inform Docker that the container listens on port 3000 +EXPOSE 8081 + +# Define the command to run your app using CMD which defines your runtime +CMD ["sh", "-c", "python rose/main.py --listen 0.0.0.0 --driver ${DRIVER} --port ${PORT}"] diff --git a/Makefile b/Makefile index 71cb19c..811eabe 100644 --- a/Makefile +++ b/Makefile @@ -1,15 +1,47 @@ -# Project variables +.PHONY: lint test lint-fix code-quality run build-image run-image clean + +SRC_DIR = . + IMAGE_NAME ?= quay.io/rose/rose-game-ai-reference +DRIVER_PATH ?= mydriver.py PORT ?= 8081 +# By default, run both linting and tests +all: lint test + +lint: + @echo "Running flake8 linting..." + flake8 --show-source --statistics . + black --check --diff . + +lint-fix: + @echo "Running lint fixing..." + @black --verbose --color . + +code-quality: + @echo "Running static code quality checks..." + radon cc . + radon mi . + +test: + @echo "Running unittests..." + pytest + +run: + @echo "Running driver logic server ..." + PYTHONPATH=$(SRC_DIR):$$PYTHONPATH python rose/main.py --port $(PORT) --driver $(DRIVER_PATH) + build-image: - @echo "Building Docker image..." + @echo "Building container image ..." podman build -t $(IMAGE_NAME) . run-image: @echo "Running container image ..." - podman run --rm \ - --network host \ - -it $(IMAGE_NAME) \ - --port $(PORT) \ - --driver /ml/driver.py + podman run --rm -it --network host -e PORT=$(PORT) $(IMAGE_NAME) + +clean: + -rm -rf .coverage + -rm -rf htmlcov + -rm -rf .pytest_cache + -find . -name '*.pyc' -exec rm {} \; + -find . -name '__pycache__' -exec rmdir {} \; diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..89b5464 --- /dev/null +++ b/conftest.py @@ -0,0 +1,4 @@ +# conftest.py +import sys + +sys.path.append(".") diff --git a/model.py b/model.py index 6fc26a5..32f831c 100644 --- a/model.py +++ b/model.py @@ -9,8 +9,6 @@ Trained models are expected to be in the checkpoints directory. """ -import os - try: import torch except ImportError: @@ -75,6 +73,7 @@ class DriverModel(nn.Module): Note: - The model expects a flattened version of the 3x4x7 input tensor, which should be reshaped to (batch_size, 84) before being passed to the model. """ + def __init__(self): super(DriverModel, self).__init__() diff --git a/driver.py b/mydriver.py similarity index 98% rename from driver.py rename to mydriver.py index a1e3059..0a9b6d4 100644 --- a/driver.py +++ b/mydriver.py @@ -10,14 +10,10 @@ """ import os -import sys # Get the directory of the current script script_directory = os.path.dirname(os.path.abspath(__file__)) -# Add the script directory to the system path -sys.path.append(script_directory) - # Try to import pytorch try: import torch @@ -26,7 +22,7 @@ print(" see: https://pytorch.org/get-started/locally/") exit() -from model import DriverModel, outputs_to_action, view_to_inputs +from model import DriverModel, outputs_to_action, view_to_inputs # noqa: E402 """ Torch Car diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..e6da193 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -vv -rxs --timeout 10 --cov . diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..df92181 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +# requirements.txt + +flake8>=3.9.0 +coverage>=7.3.0 +radon>=6.0.0 +black>=23.7.0 +pytest +pytest-check-links +pytest-coverage +pytest-timeout \ No newline at end of file diff --git a/rose-ml-driver.yaml b/rose-game-ai-reference.yaml similarity index 56% rename from rose-ml-driver.yaml rename to rose-game-ai-reference.yaml index c21f05e..829adeb 100644 --- a/rose-ml-driver.yaml +++ b/rose-game-ai-reference.yaml @@ -1,22 +1,22 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: rose-game-ai + name: rose-game-ai-reference labels: - app: rose-game-ai + app: rose-game-ai-reference spec: replicas: 1 selector: matchLabels: - app: rose-game-ai + app: rose-game-ai-reference template: metadata: labels: - app: rose-game-ai + app: rose-game-ai-reference spec: containers: - - name: rose-game-ai - image: quay.io/rose/rose-game-ai:latest # Modify with your Docker image name and tag. + - name: rose-game-ai-reference + image: quay.io/rose/rose-game-ai-reference:latest # Modify with your Docker image name and tag. ports: - containerPort: 8081 @@ -25,10 +25,10 @@ spec: apiVersion: v1 kind: Service metadata: - name: rose-game-ai + name: rose-game-ai-reference spec: selector: - app: rose-game-ai + app: rose-game-ai-reference ports: - protocol: TCP port: 8081 diff --git a/client/__init__.py b/rose/__init__.py similarity index 100% rename from client/__init__.py rename to rose/__init__.py diff --git a/client/game/__init__.py b/rose/ai/__init__.py similarity index 100% rename from client/game/__init__.py rename to rose/ai/__init__.py diff --git a/client/game/car.py b/rose/ai/car.py similarity index 100% rename from client/game/car.py rename to rose/ai/car.py diff --git a/client/game/server.py b/rose/ai/server.py similarity index 99% rename from client/game/server.py rename to rose/ai/server.py index d0d47a0..fa9923c 100644 --- a/client/game/server.py +++ b/rose/ai/server.py @@ -4,7 +4,7 @@ import socket import socketserver -from game import world +from rose.ai import world log = logging.getLogger("driver") diff --git a/rose/ai/test_car.py b/rose/ai/test_car.py new file mode 100644 index 0000000..c200b17 --- /dev/null +++ b/rose/ai/test_car.py @@ -0,0 +1,21 @@ +import pytest +from rose.ai.car import Car + + +def test_car_initialization(): + info = {"x": 5, "y": 10} + car = Car(info) + + assert car.x == 5 + assert car.y == 10 + + +def test_car_initialization_missing_key(): + info_missing_x = {"y": 10} + info_missing_y = {"x": 5} + + with pytest.raises(KeyError): + Car(info_missing_x) + + with pytest.raises(KeyError): + Car(info_missing_y) diff --git a/rose/ai/test_server.py b/rose/ai/test_server.py new file mode 100644 index 0000000..7016797 --- /dev/null +++ b/rose/ai/test_server.py @@ -0,0 +1,57 @@ +import pytest +import requests +import threading +from rose.ai.server import MyTCPServer, MyHTTPRequestHandler + + +def drive(world): + return "" + + +# Start the server in a separate thread for testing +@pytest.fixture(scope="module") +def start_server(): + server_address = ("", 8081) + MyHTTPRequestHandler.drive = drive + httpd = MyTCPServer(server_address, MyHTTPRequestHandler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + yield + httpd.shutdown() + thread.join() + + +def test_get_driver_name(start_server): + response = requests.get("http://localhost:8081/") + data = response.json() + assert data["info"]["name"] == "Unknown" # Default driver name + + +def test_post_valid_data(start_server): + payload = { + "info": {"car": {"x": 3, "y": 8}}, + "track": [ + ["", "", "bike"], + ["", "", ""], + ["", "", ""], + ["", "", ""], + ["", "", ""], + ["", "", ""], + ["", "", ""], + ["", "", ""], + ], + } + response = requests.post("http://localhost:8081/", json=payload) + data = response.json() + assert "action" in data["info"] + + +def test_post_invalid_json(start_server): + response = requests.post("http://localhost:8081/", data="not a valid json") + assert response.status_code == 400 + + +def test_post_unexpected_data_structure(start_server): + payload = {"unexpected": "data"} + response = requests.post("http://localhost:8081/", json=payload) + assert response.status_code == 500 diff --git a/rose/ai/test_track.py b/rose/ai/test_track.py new file mode 100644 index 0000000..0d3a1f2 --- /dev/null +++ b/rose/ai/test_track.py @@ -0,0 +1,46 @@ +import pytest +from rose.ai.track import Track + + +def test_track_initialization(): + t = Track() + assert t.max_x == 0 + assert t.max_y == 0 + + t2 = Track([["a", "b"], ["c", "d"]]) + assert t2.max_x == 2 + assert t2.max_y == 2 + + +def test_track_get(): + t = Track([["a", "b"], ["c", "d"]]) + assert t.get(0, 0) == "a" + assert t.get(1, 0) == "b" + assert t.get(0, 1) == "c" + assert t.get(1, 1) == "d" + + +def test_track_get_out_of_bounds(): + t = Track([["a", "b"], ["c", "d"]]) + + with pytest.raises(IndexError, match="x out of range: 0-1"): + t.get(2, 0) + + with pytest.raises(IndexError, match="y out of range: 0-1"): + t.get(0, 2) + + +def test_track_validate_pos(): + t = Track([["a", "b"], ["c", "d"]]) + + # These should not raise any errors + t._validate_pos(0, 0) + t._validate_pos(1, 0) + t._validate_pos(0, 1) + t._validate_pos(1, 1) + + with pytest.raises(IndexError, match="x out of range: 0-1"): + t._validate_pos(2, 0) + + with pytest.raises(IndexError, match="y out of range: 0-1"): + t._validate_pos(0, 2) diff --git a/client/game/track.py b/rose/ai/track.py similarity index 100% rename from client/game/track.py rename to rose/ai/track.py diff --git a/client/game/world.py b/rose/ai/world.py similarity index 94% rename from client/game/world.py rename to rose/ai/world.py index e2e10a0..85568f5 100644 --- a/client/game/world.py +++ b/rose/ai/world.py @@ -1,5 +1,5 @@ -from game.car import Car -from game.track import Track +from rose.ai.car import Car +from rose.ai.track import Track def create(game_data): diff --git a/rose/common/__init__.py b/rose/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/game/actions.py b/rose/common/actions.py similarity index 100% rename from client/game/actions.py rename to rose/common/actions.py diff --git a/client/game/obstacles.py b/rose/common/obstacles.py similarity index 100% rename from client/game/obstacles.py rename to rose/common/obstacles.py diff --git a/rose/common/test_actions.py b/rose/common/test_actions.py new file mode 100644 index 0000000..dbbaeb5 --- /dev/null +++ b/rose/common/test_actions.py @@ -0,0 +1,14 @@ +from rose.common.actions import NONE, RIGHT, LEFT, PICKUP, JUMP, BRAKE, ALL + + +def test_constants(): + assert NONE == "none" + assert RIGHT == "right" + assert LEFT == "left" + assert PICKUP == "pickup" + assert JUMP == "jump" + assert BRAKE == "brake" + + +def test_all_constant(): + assert ALL == (NONE, RIGHT, LEFT, PICKUP, JUMP, BRAKE) diff --git a/rose/common/test_obstacles.py b/rose/common/test_obstacles.py new file mode 100644 index 0000000..1b5d688 --- /dev/null +++ b/rose/common/test_obstacles.py @@ -0,0 +1,31 @@ +from rose.common.obstacles import ( + NONE, + CRACK, + TRASH, + PENGUIN, + BIKE, + WATER, + BARRIER, + ALL, + get_random_obstacle, +) + + +def test_constants(): + assert NONE == "" + assert CRACK == "crack" + assert TRASH == "trash" + assert PENGUIN == "penguin" + assert BIKE == "bike" + assert WATER == "water" + assert BARRIER == "barrier" + + +def test_all_constant(): + assert ALL == (NONE, CRACK, TRASH, PENGUIN, BIKE, WATER, BARRIER) + + +def test_get_random_obstacle(): + # This test checks if the function returns a valid obstacle + obstacle = get_random_obstacle() + assert obstacle in ALL diff --git a/client/main.py b/rose/main.py similarity index 80% rename from client/main.py rename to rose/main.py index 15f20c7..5e2fe86 100644 --- a/client/main.py +++ b/rose/main.py @@ -2,26 +2,24 @@ import importlib.util import logging -from game import server +from rose.ai import server def load_driver_module(driver_path): """ Load the driver module from the specified path. - :param driver_path: Path to the driver module. - :return: The loaded module. - :raises ImportError: If there's an issue loading the module. + Arguments: + file_path (str): The path to the driver module + Returns: + Driver module (module) + Raises: + Exception if the module cannot be loaded """ - try: - spec = importlib.util.spec_from_file_location("driver_module", driver_path) - driver_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(driver_module) - return driver_module - except Exception as e: - raise ImportError( - f"Error loading driver module from path {driver_path}: {str(e)}" - ) + spec = importlib.util.spec_from_file_location("driver_module", driver_path) + driver_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(driver_module) + return driver_module def main(): diff --git a/train.py b/train.py index 7a37ebf..61bd1c7 100644 --- a/train.py +++ b/train.py @@ -3,10 +3,10 @@ This script is used to train a deep learning model for a driving simulator. The model is trained using PyTorch. -The script generates a 4x3 2D array with random obstacles and simulates the driver's decision based on the obstacle in front of the car. -The driver's decision and the 2D array are used to generate a batch of samples for training. +The script generates a 4x3 2D array with random obstacles and simulates the driver's decision based on the obstacle in front of the car. +The driver's decision and the 2D array are used to generate a batch of samples for training. -The model is trained for a specified number of epochs. In each epoch, the model is trained over a number of batches. +The model is trained for a specified number of epochs. In each epoch, the model is trained over a number of batches. For each batch, the model's parameters are updated based on the computed loss between the model's predictions and the actual targets. The script requires PyTorch to be installed. See: https://pytorch.org/get-started/locally/ @@ -109,7 +109,7 @@ def generate_batch(batch_size): inputs = [] targets = [] for _ in range(batch_size): - car_x = random.choice([0,1,2]) + car_x = random.choice([0, 1, 2]) array = generate_obstacle_array() correct_output = driver_simulator(array, car_x) @@ -126,43 +126,54 @@ def main(): for epoch in range(num_epochs): # Initialize running loss to 0.0 at the start of each epoch running_loss = 0.0 - + # Assuming you have a dataset size, calculate the number of batches num_batches = 100 - + # Loop over each batch - for i in range(num_batches): + for _i in range(num_batches): # Get a batch of training data inputs, targets = generate_batch(batch_size) - + # Reset the gradients in the optimizer (i.e., make it forget the gradients computed in the previous iteration) optimizer.zero_grad() - + # Forward pass: compute predicted outputs by passing inputs to the model outputs = model(inputs) - + # Compute loss: calculate the batch loss based on the difference between the predicted outputs and the actual targets loss = criterion(outputs, targets) - + # Backward pass: compute gradient of the loss with respect to model parameters loss.backward() - + # Perform a single optimization step (parameter update) optimizer.step() - + # Update running loss running_loss += loss.item() - + # Print average loss for the epoch - print(f"Epoch {epoch+1}, Loss: {running_loss / num_batches}") - + print(f"Epoch {epoch + 1}, Loss: {running_loss / num_batches}") + + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Train the model.') - parser.add_argument('--checkpoint-in', default="", help='Path to the input checkpoint file.') - parser.add_argument('--checkpoint-out', default="", help='Path to the output checkpoint file.') - parser.add_argument('--num-epochs', type=int, default=10, help='Number of epochs for training.') - parser.add_argument('--batch-size', type=int, default=200, help='Batch size for training.') - parser.add_argument('--learning-rate', type=float, default=0.001, help='Learning rate for training.') + parser = argparse.ArgumentParser(description="Train the model.") + parser.add_argument( + "--checkpoint-in", default="", help="Path to the input checkpoint file." + ) + parser.add_argument( + "--checkpoint-out", default="", help="Path to the output checkpoint file." + ) + parser.add_argument( + "--num-epochs", type=int, default=10, help="Number of epochs for training." + ) + parser.add_argument( + "--batch-size", type=int, default=200, help="Batch size for training." + ) + parser.add_argument( + "--learning-rate", type=float, default=0.001, help="Learning rate for training." + ) args = parser.parse_args() # Training parameters @@ -178,10 +189,9 @@ def main(): if args.checkpoint_in != "": model.load_state_dict(torch.load(args.checkpoint_in)) model.eval() - + # Run training main() print("Finished Training") - torch.save(model.state_dict(), args.checkpoint_out or f"driver.pth") - \ No newline at end of file + torch.save(model.state_dict(), args.checkpoint_out or "driver.pth") From 218458574278e9313b6f5862e43fb07eeb2985a0 Mon Sep 17 00:00:00 2001 From: Yaacov Zamir Date: Tue, 9 Jan 2024 17:15:52 +0200 Subject: [PATCH 2/2] try load module Signed-off-by: Yaacov Zamir --- rose/main.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/rose/main.py b/rose/main.py index 5e2fe86..72aaff1 100644 --- a/rose/main.py +++ b/rose/main.py @@ -16,10 +16,15 @@ def load_driver_module(driver_path): Raises: Exception if the module cannot be loaded """ - spec = importlib.util.spec_from_file_location("driver_module", driver_path) - driver_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(driver_module) - return driver_module + try: + spec = importlib.util.spec_from_file_location("driver_module", driver_path) + driver_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(driver_module) + return driver_module + except Exception as e: + raise ImportError( + f"Error loading driver module from path {driver_path}: {str(e)}" + ) def main():