Skip to content

Commit

Permalink
Add json_str parsing function to simplify specifying json strings in …
Browse files Browse the repository at this point in the history
…overrides (facebookresearch#2930)
  • Loading branch information
jesszzzz committed Aug 5, 2024
1 parent de47f41 commit 9deaa6c
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 48 deletions.
38 changes: 38 additions & 0 deletions hydra/_internal/grammar/grammar_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import builtins
import json
import random
from copy import copy
from typing import Any, Callable, Dict, List, Optional, Union
Expand Down Expand Up @@ -138,6 +139,43 @@ def cast_str(*args: CastType, value: Optional[CastType] = None) -> Any:
return str(value)


def extract_text(*args: Any, value: Optional[Any] = None) -> Any:
value = _normalize_cast_value(*args, value=value)
if isinstance(value, QuotedString):
return value.text
if isinstance(value, dict):
return apply_to_dict_values(value, extract_text)
elif isinstance(value, list):
return list(map(extract_text, value))
elif isinstance(value, ChoiceSweep):
return cast_choice(value, extract_text)
elif isinstance(value, RangeSweep):
return cast_range(value, extract_text)
else:
return value


def cast_json_str(*args: Any, value: Optional[Any] = None) -> Any:
value = _normalize_cast_value(*args, value=value)
json_val = value
if isinstance(value, QuotedString):
json_val = value.text
if isinstance(value, dict):
json_val = apply_to_dict_values(value, extract_text)
elif isinstance(value, list):
json_val = list(map(extract_text, value))
elif isinstance(value, ChoiceSweep):
json_choices = cast_choice(value, extract_text)
return cast_choice(json_choices, json.dumps)
elif isinstance(value, RangeSweep):
json_range = cast_range(value, extract_text)
return cast_range(json_range, json.dumps)
elif isinstance(value, IntervalSweep):
raise ValueError("Intervals cannot be cast to json_str")

return json.dumps(json_val)


def cast_bool(*args: CastType, value: Optional[CastType] = None) -> Any:
value = _normalize_cast_value(*args, value=value)
if isinstance(value, QuotedString):
Expand Down
1 change: 1 addition & 0 deletions hydra/core/override_parser/overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def create_functions() -> Functions:
functions.register(name="str", func=grammar_functions.cast_str)
functions.register(name="bool", func=grammar_functions.cast_bool)
functions.register(name="float", func=grammar_functions.cast_float)
functions.register(name="json_str", func=grammar_functions.cast_json_str)
# sweeps
functions.register(name="choice", func=grammar_functions.choice)
functions.register(name="range", func=grammar_functions.range)
Expand Down
1 change: 1 addition & 0 deletions news/2929.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add json_str function to override syntax
117 changes: 105 additions & 12 deletions tests/test_overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,11 @@ def test_sweep_shuffle(value: str, expected: str) -> None:

@dataclass
class CastResults:
json_str: Union[
str,
Sweep,
RaisesContext[HydraException],
]
int: Union[
int,
List[Union[int, List[int]]],
Expand Down Expand Up @@ -1446,11 +1451,25 @@ def error(msg: builtins.str) -> Any:
"value,expected_value",
[
# int
param(10, CastResults(int=10, float=10.0, str="10", bool=True), id="10"),
param(0, CastResults(int=0, float=0.0, str="0", bool=False), id="0"),
param(
10,
CastResults(int=10, float=10.0, str="10", bool=True, json_str="10"),
id="10",
),
param(
0, CastResults(int=0, float=0.0, str="0", bool=False, json_str="0"), id="0"
),
# float
param(10.0, CastResults(int=10, float=10.0, str="10.0", bool=True), id="10.0"),
param(0.0, CastResults(int=0, float=0.0, str="0.0", bool=False), id="0.0"),
param(
10.0,
CastResults(int=10, float=10.0, str="10.0", bool=True, json_str="10.0"),
id="10.0",
),
param(
0.0,
CastResults(int=0, float=0.0, str="0.0", bool=False, json_str="0.0"),
id="0.0",
),
param(
"inf",
CastResults(
Expand All @@ -1460,6 +1479,7 @@ def error(msg: builtins.str) -> Any:
float=math.inf,
str="inf",
bool=True,
json_str="Infinity",
),
id="inf",
),
Expand All @@ -1472,12 +1492,15 @@ def error(msg: builtins.str) -> Any:
float=math.nan,
str="nan",
bool=True,
json_str="NaN",
),
id="nan",
),
param(
"1e6",
CastResults(int=1000000, float=1e6, str="1000000.0", bool=True),
CastResults(
int=1000000, float=1e6, str="1000000.0", bool=True, json_str="1000000.0"
),
id="1e6",
),
param(
Expand All @@ -1493,6 +1516,7 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool('')': Cannot cast '' to bool"
),
json_str='""',
),
id="''",
),
Expand All @@ -1506,6 +1530,7 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool('10')': Cannot cast '10' to bool"
),
json_str='"10"',
),
id="'10'",
),
Expand All @@ -1520,6 +1545,7 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool('10.0')': Cannot cast '10.0' to bool"
),
json_str='"10.0"',
),
id="'10.0'",
),
Expand All @@ -1534,6 +1560,7 @@ def error(msg: builtins.str) -> Any:
),
str="true",
bool=True,
json_str='"true"',
),
id="'true'",
),
Expand All @@ -1548,6 +1575,7 @@ def error(msg: builtins.str) -> Any:
),
str="false",
bool=False,
json_str='"false"',
),
id="'false'",
),
Expand All @@ -1564,6 +1592,7 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool('[1,2,3]')': Cannot cast '[1,2,3]' to bool"
),
json_str='"[1,2,3]"',
),
id="'[1,2,3]'",
),
Expand All @@ -1580,30 +1609,44 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool('{a:10}')': Cannot cast '{a:10}' to bool"
),
json_str='"{a:10}"',
),
id="'{a:10}'",
),
# bool
param("true", CastResults(int=1, float=1.0, str="true", bool=True), id="true"),
param(
"false", CastResults(int=0, float=0.0, str="false", bool=False), id="false"
"true",
CastResults(int=1, float=1.0, str="true", bool=True, json_str="true"),
id="true",
),
param(
"false",
CastResults(int=0, float=0.0, str="false", bool=False, json_str="false"),
id="false",
),
# list
param("[]", CastResults(int=[], float=[], str=[], bool=[]), id="[]"),
param(
"[]", CastResults(int=[], float=[], str=[], bool=[], json_str="[]"), id="[]"
),
param(
"[0,1,2]",
CastResults(
int=[0, 1, 2],
float=[0.0, 1.0, 2.0],
str=["0", "1", "2"],
bool=[False, True, True],
json_str="[0, 1, 2]",
),
id="[1,2,3]",
),
param(
"[1,[2]]",
CastResults(
int=[1, [2]], float=[1.0, [2.0]], str=["1", ["2"]], bool=[True, [True]]
int=[1, [2]],
float=[1.0, [2.0]],
str=["1", ["2"]],
bool=[True, [True]],
json_str="[1, [2]]",
),
id="[1,[2]]",
),
Expand All @@ -1620,15 +1663,22 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool([a,1])': Cannot cast 'a' to bool"
),
json_str='["a", 1]',
),
id="[a,1]",
),
# dicts
param("{}", CastResults(int={}, float={}, str={}, bool={}), id="{}"),
param(
"{}", CastResults(int={}, float={}, str={}, bool={}, json_str="{}"), id="{}"
),
param(
"{a:10}",
CastResults(
int={"a": 10}, float={"a": 10.0}, str={"a": "10"}, bool={"a": True}
int={"a": 10},
float={"a": 10.0},
str={"a": "10"},
bool={"a": True},
json_str='{"a": 10}',
),
id="{a:10}",
),
Expand All @@ -1639,6 +1689,7 @@ def error(msg: builtins.str) -> Any:
float={"a": [0.0, 1.0, 2.0]},
str={"a": ["0", "1", "2"]},
bool={"a": [False, True, True]},
json_str='{"a": [0, 1, 2]}',
),
id="{a:[0,1,2]}",
),
Expand All @@ -1655,9 +1706,34 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool({a:10,b:xyz})': Cannot cast 'xyz' to bool"
),
json_str='{"a": 10, "b": "xyz"}',
),
id="{a:10,b:xyz}",
),
param(
"{a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}}",
CastResults(
int=CastResults.error(
"ValueError while evaluating 'int({a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}})': "
"invalid literal for int() with base 10: 'xyz'"
),
float=CastResults.error(
"ValueError while evaluating 'float({a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}})': "
"could not convert string to float: 'xyz'"
),
str={
"a": "10",
"b": "xyz",
"c": {"d": "foo", "f": ["1", "2", {"g": "0"}]},
},
bool=CastResults.error(
"ValueError while evaluating 'bool({a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}})': "
"Cannot cast 'xyz' to bool"
),
json_str='{"a": 10, "b": "xyz", "c": {"d": "foo", "f": [1, 2, {"g": 0}]}}',
),
id="{a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}}",
),
# choice
param(
"choice(0,1)",
Expand All @@ -1666,6 +1742,7 @@ def error(msg: builtins.str) -> Any:
float=ChoiceSweep(list=[0.0, 1.0]),
str=ChoiceSweep(list=["0", "1"]),
bool=ChoiceSweep(list=[False, True]),
json_str=ChoiceSweep(list=["0", "1"]),
),
id="choice(0,1)",
),
Expand All @@ -1676,6 +1753,7 @@ def error(msg: builtins.str) -> Any:
float=ChoiceSweep(list=[2.0, 1.0, 0.0], simple_form=True),
str=ChoiceSweep(list=["2", "1", "0"], simple_form=True),
bool=ChoiceSweep(list=[True, True, False], simple_form=True),
json_str=ChoiceSweep(list=["2", "1", "0"], simple_form=True),
),
id="simple_choice:ints",
),
Expand All @@ -1697,6 +1775,10 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool(a,'b',1,1.0,true,[a,b],{a:10})': Cannot cast 'a' to bool"
),
json_str=ChoiceSweep(
list=['"a"', '"b"', "1", "1.0", "true", '["a", "b"]', '{"a": 10}'],
simple_form=True,
),
),
id="simple_choice:types",
),
Expand All @@ -1713,6 +1795,7 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool(choice(a,b))': Cannot cast 'a' to bool"
),
json_str=ChoiceSweep(list=['"a"', '"b"']),
),
id="choice(a,b)",
),
Expand All @@ -1729,6 +1812,7 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool(choice(1,a))': Cannot cast 'a' to bool"
),
json_str=ChoiceSweep(list=["1", '"a"']),
),
id="choice(1,a)",
),
Expand All @@ -1744,6 +1828,9 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool(interval(1.0, 2.0))': Intervals cannot be cast to bool"
),
json_str=CastResults.error(
"ValueError while evaluating 'json_str(interval(1.0, 2.0))': Intervals cannot be cast to json_str"
),
),
id="interval(1.0, 2.0)",
),
Expand All @@ -1759,6 +1846,9 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool(range(1,10))': Range can only be cast to int or float"
),
json_str=CastResults.error(
"ValueError while evaluating 'json_str(range(1,10))': Range can only be cast to int or float"
),
),
id="range(1,10)",
),
Expand All @@ -1773,13 +1863,16 @@ def error(msg: builtins.str) -> Any:
bool=CastResults.error(
"ValueError while evaluating 'bool(range(1.0,10.0))': Range can only be cast to int or float"
),
json_str=CastResults.error(
"ValueError while evaluating 'json_str(range(1.0,10.0))': Range can only be cast to int or float"
),
),
id="range(1.0,10.0)",
),
],
)
def test_cast_conversions(value: Any, expected_value: Any) -> None:
for field in ("int", "float", "bool", "str"):
for field in ("int", "float", "bool", "str", "json_str"):
cast_str = f"{field}({value})"
expected = getattr(expected_value, field)
if isinstance(expected, RaisesContext):
Expand Down
Loading

0 comments on commit 9deaa6c

Please sign in to comment.