Skip to content

Commit

Permalink
Switch to example.env.yaml if no env.yaml (#5)
Browse files Browse the repository at this point in the history
Config test tweak
  • Loading branch information
mirmozavr authored Nov 12, 2023
1 parent 0fee61b commit d0a879d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class Config:
def get_config(path: Union[str, Path] = None) -> Config:
if path is None:
path = Path(__file__).parent.parent.joinpath("env.yaml")
if not Path(path).exists():
path = Path(__file__).parent.parent.joinpath("example.env.yaml")
with open(path, encoding="utf-8") as f:
config = Config(**safe_load(f))
return config
Expand Down
24 changes: 1 addition & 23 deletions src/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pathlib import Path
from unittest.mock import patch

from src.config import Config, get_config
Expand All @@ -14,33 +13,14 @@


class TestConfigFunctions:
@patch("builtins.open", create=True)
@patch(
"src.config.safe_load",
return_value=conf_dict,
)
def test_get_config_with_custom_path(self, mock_safe_load, mock_open):
def test_get_config_with_custom_path(self, mock_safe_load):
custom_path = "custom_path.yaml"
config = get_config(custom_path)
assert config == Config(
OPENAI_API_KEY="your_key",
OPENAI_PROMPTS_PATH="path",
SOURCE_DIR="source",
PROCESS_DIR="process",
OUTPUT_DIR="output",
YT_PROBA=80,
)
mock_open.assert_called_once_with(custom_path, encoding="utf-8")
mock_safe_load.assert_called_once()

@patch("builtins.open", create=True)
@patch(
"src.config.safe_load",
return_value=conf_dict,
)
def test_get_config_with_default_path(self, mock_safe_load, mock_open):
default_path = Path(__file__).parent.parent.parent.joinpath("env.yaml")
config = get_config()
assert config == Config(
OPENAI_API_KEY="your_key",
OPENAI_PROMPTS_PATH="path",
Expand All @@ -49,5 +29,3 @@ def test_get_config_with_default_path(self, mock_safe_load, mock_open):
OUTPUT_DIR="output",
YT_PROBA=80,
)
mock_open.assert_called_once_with(default_path, encoding="utf-8")
mock_safe_load.assert_called_once()

0 comments on commit d0a879d

Please sign in to comment.