diff --git a/hydra/_internal/grammar/grammar_functions.py b/hydra/_internal/grammar/grammar_functions.py index 7d08cc9550a..0b458ba478b 100644 --- a/hydra/_internal/grammar/grammar_functions.py +++ b/hydra/_internal/grammar/grammar_functions.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import builtins +import json import random from copy import copy from typing import Any, Callable, Dict, List, Optional, Union @@ -138,6 +139,43 @@ def cast_str(*args: CastType, value: Optional[CastType] = None) -> Any: return str(value) +def extract_text(*args: Any, value: Optional[Any] = None) -> Any: + value = _normalize_cast_value(*args, value=value) + if isinstance(value, QuotedString): + return value.text + if isinstance(value, dict): + return apply_to_dict_values(value, extract_text) + elif isinstance(value, list): + return list(map(extract_text, value)) + elif isinstance(value, ChoiceSweep): + return cast_choice(value, extract_text) + elif isinstance(value, RangeSweep): + return cast_range(value, extract_text) + else: + return value + + +def cast_json_str(*args: Any, value: Optional[Any] = None) -> Any: + value = _normalize_cast_value(*args, value=value) + json_val = value + if isinstance(value, QuotedString): + json_val = value.text + if isinstance(value, dict): + json_val = apply_to_dict_values(value, extract_text) + elif isinstance(value, list): + json_val = list(map(extract_text, value)) + elif isinstance(value, ChoiceSweep): + json_choices = cast_choice(value, extract_text) + return cast_choice(json_choices, json.dumps) + elif isinstance(value, RangeSweep): + json_range = cast_range(value, extract_text) + return cast_range(json_range, json.dumps) + elif isinstance(value, IntervalSweep): + raise ValueError("Intervals cannot be cast to json_str") + + return json.dumps(json_val) + + def cast_bool(*args: CastType, value: Optional[CastType] = None) -> Any: value = _normalize_cast_value(*args, value=value) if isinstance(value, QuotedString): diff --git a/hydra/core/override_parser/overrides_parser.py b/hydra/core/override_parser/overrides_parser.py index b095ccaad06..50192d72a3a 100644 --- a/hydra/core/override_parser/overrides_parser.py +++ b/hydra/core/override_parser/overrides_parser.py @@ -114,6 +114,7 @@ def create_functions() -> Functions: functions.register(name="str", func=grammar_functions.cast_str) functions.register(name="bool", func=grammar_functions.cast_bool) functions.register(name="float", func=grammar_functions.cast_float) + functions.register(name="json_str", func=grammar_functions.cast_json_str) # sweeps functions.register(name="choice", func=grammar_functions.choice) functions.register(name="range", func=grammar_functions.range) diff --git a/news/2929.feature b/news/2929.feature new file mode 100644 index 00000000000..615b987bef6 --- /dev/null +++ b/news/2929.feature @@ -0,0 +1 @@ +Add json_str function to override syntax diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index 62df296f795..e89df9e8159 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -1408,6 +1408,11 @@ def test_sweep_shuffle(value: str, expected: str) -> None: @dataclass class CastResults: + json_str: Union[ + str, + Sweep, + RaisesContext[HydraException], + ] int: Union[ int, List[Union[int, List[int]]], @@ -1446,11 +1451,25 @@ def error(msg: builtins.str) -> Any: "value,expected_value", [ # int - param(10, CastResults(int=10, float=10.0, str="10", bool=True), id="10"), - param(0, CastResults(int=0, float=0.0, str="0", bool=False), id="0"), + param( + 10, + CastResults(int=10, float=10.0, str="10", bool=True, json_str="10"), + id="10", + ), + param( + 0, CastResults(int=0, float=0.0, str="0", bool=False, json_str="0"), id="0" + ), # float - param(10.0, CastResults(int=10, float=10.0, str="10.0", bool=True), id="10.0"), - param(0.0, CastResults(int=0, float=0.0, str="0.0", bool=False), id="0.0"), + param( + 10.0, + CastResults(int=10, float=10.0, str="10.0", bool=True, json_str="10.0"), + id="10.0", + ), + param( + 0.0, + CastResults(int=0, float=0.0, str="0.0", bool=False, json_str="0.0"), + id="0.0", + ), param( "inf", CastResults( @@ -1460,6 +1479,7 @@ def error(msg: builtins.str) -> Any: float=math.inf, str="inf", bool=True, + json_str="Infinity", ), id="inf", ), @@ -1472,12 +1492,15 @@ def error(msg: builtins.str) -> Any: float=math.nan, str="nan", bool=True, + json_str="NaN", ), id="nan", ), param( "1e6", - CastResults(int=1000000, float=1e6, str="1000000.0", bool=True), + CastResults( + int=1000000, float=1e6, str="1000000.0", bool=True, json_str="1000000.0" + ), id="1e6", ), param( @@ -1493,6 +1516,7 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool('')': Cannot cast '' to bool" ), + json_str='""', ), id="''", ), @@ -1506,6 +1530,7 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool('10')': Cannot cast '10' to bool" ), + json_str='"10"', ), id="'10'", ), @@ -1520,6 +1545,7 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool('10.0')': Cannot cast '10.0' to bool" ), + json_str='"10.0"', ), id="'10.0'", ), @@ -1534,6 +1560,7 @@ def error(msg: builtins.str) -> Any: ), str="true", bool=True, + json_str='"true"', ), id="'true'", ), @@ -1548,6 +1575,7 @@ def error(msg: builtins.str) -> Any: ), str="false", bool=False, + json_str='"false"', ), id="'false'", ), @@ -1564,6 +1592,7 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool('[1,2,3]')': Cannot cast '[1,2,3]' to bool" ), + json_str='"[1,2,3]"', ), id="'[1,2,3]'", ), @@ -1580,16 +1609,25 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool('{a:10}')': Cannot cast '{a:10}' to bool" ), + json_str='"{a:10}"', ), id="'{a:10}'", ), # bool - param("true", CastResults(int=1, float=1.0, str="true", bool=True), id="true"), param( - "false", CastResults(int=0, float=0.0, str="false", bool=False), id="false" + "true", + CastResults(int=1, float=1.0, str="true", bool=True, json_str="true"), + id="true", + ), + param( + "false", + CastResults(int=0, float=0.0, str="false", bool=False, json_str="false"), + id="false", ), # list - param("[]", CastResults(int=[], float=[], str=[], bool=[]), id="[]"), + param( + "[]", CastResults(int=[], float=[], str=[], bool=[], json_str="[]"), id="[]" + ), param( "[0,1,2]", CastResults( @@ -1597,13 +1635,18 @@ def error(msg: builtins.str) -> Any: float=[0.0, 1.0, 2.0], str=["0", "1", "2"], bool=[False, True, True], + json_str="[0, 1, 2]", ), id="[1,2,3]", ), param( "[1,[2]]", CastResults( - int=[1, [2]], float=[1.0, [2.0]], str=["1", ["2"]], bool=[True, [True]] + int=[1, [2]], + float=[1.0, [2.0]], + str=["1", ["2"]], + bool=[True, [True]], + json_str="[1, [2]]", ), id="[1,[2]]", ), @@ -1620,15 +1663,22 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool([a,1])': Cannot cast 'a' to bool" ), + json_str='["a", 1]', ), id="[a,1]", ), # dicts - param("{}", CastResults(int={}, float={}, str={}, bool={}), id="{}"), + param( + "{}", CastResults(int={}, float={}, str={}, bool={}, json_str="{}"), id="{}" + ), param( "{a:10}", CastResults( - int={"a": 10}, float={"a": 10.0}, str={"a": "10"}, bool={"a": True} + int={"a": 10}, + float={"a": 10.0}, + str={"a": "10"}, + bool={"a": True}, + json_str='{"a": 10}', ), id="{a:10}", ), @@ -1639,6 +1689,7 @@ def error(msg: builtins.str) -> Any: float={"a": [0.0, 1.0, 2.0]}, str={"a": ["0", "1", "2"]}, bool={"a": [False, True, True]}, + json_str='{"a": [0, 1, 2]}', ), id="{a:[0,1,2]}", ), @@ -1655,9 +1706,34 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool({a:10,b:xyz})': Cannot cast 'xyz' to bool" ), + json_str='{"a": 10, "b": "xyz"}', ), id="{a:10,b:xyz}", ), + param( + "{a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}}", + CastResults( + int=CastResults.error( + "ValueError while evaluating 'int({a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}})': " + "invalid literal for int() with base 10: 'xyz'" + ), + float=CastResults.error( + "ValueError while evaluating 'float({a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}})': " + "could not convert string to float: 'xyz'" + ), + str={ + "a": "10", + "b": "xyz", + "c": {"d": "foo", "f": ["1", "2", {"g": "0"}]}, + }, + bool=CastResults.error( + "ValueError while evaluating 'bool({a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}})': " + "Cannot cast 'xyz' to bool" + ), + json_str='{"a": 10, "b": "xyz", "c": {"d": "foo", "f": [1, 2, {"g": 0}]}}', + ), + id="{a:10,b:xyz,c:{d:foo,f:[1,2,{g:0}]}}", + ), # choice param( "choice(0,1)", @@ -1666,6 +1742,7 @@ def error(msg: builtins.str) -> Any: float=ChoiceSweep(list=[0.0, 1.0]), str=ChoiceSweep(list=["0", "1"]), bool=ChoiceSweep(list=[False, True]), + json_str=ChoiceSweep(list=["0", "1"]), ), id="choice(0,1)", ), @@ -1676,6 +1753,7 @@ def error(msg: builtins.str) -> Any: float=ChoiceSweep(list=[2.0, 1.0, 0.0], simple_form=True), str=ChoiceSweep(list=["2", "1", "0"], simple_form=True), bool=ChoiceSweep(list=[True, True, False], simple_form=True), + json_str=ChoiceSweep(list=["2", "1", "0"], simple_form=True), ), id="simple_choice:ints", ), @@ -1697,6 +1775,10 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool(a,'b',1,1.0,true,[a,b],{a:10})': Cannot cast 'a' to bool" ), + json_str=ChoiceSweep( + list=['"a"', '"b"', "1", "1.0", "true", '["a", "b"]', '{"a": 10}'], + simple_form=True, + ), ), id="simple_choice:types", ), @@ -1713,6 +1795,7 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool(choice(a,b))': Cannot cast 'a' to bool" ), + json_str=ChoiceSweep(list=['"a"', '"b"']), ), id="choice(a,b)", ), @@ -1729,6 +1812,7 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool(choice(1,a))': Cannot cast 'a' to bool" ), + json_str=ChoiceSweep(list=["1", '"a"']), ), id="choice(1,a)", ), @@ -1744,6 +1828,9 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool(interval(1.0, 2.0))': Intervals cannot be cast to bool" ), + json_str=CastResults.error( + "ValueError while evaluating 'json_str(interval(1.0, 2.0))': Intervals cannot be cast to json_str" + ), ), id="interval(1.0, 2.0)", ), @@ -1759,6 +1846,9 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool(range(1,10))': Range can only be cast to int or float" ), + json_str=CastResults.error( + "ValueError while evaluating 'json_str(range(1,10))': Range can only be cast to int or float" + ), ), id="range(1,10)", ), @@ -1773,13 +1863,16 @@ def error(msg: builtins.str) -> Any: bool=CastResults.error( "ValueError while evaluating 'bool(range(1.0,10.0))': Range can only be cast to int or float" ), + json_str=CastResults.error( + "ValueError while evaluating 'json_str(range(1.0,10.0))': Range can only be cast to int or float" + ), ), id="range(1.0,10.0)", ), ], ) def test_cast_conversions(value: Any, expected_value: Any) -> None: - for field in ("int", "float", "bool", "str"): + for field in ("int", "float", "bool", "str", "json_str"): cast_str = f"{field}({value})" expected = getattr(expected_value, field) if isinstance(expected, RaisesContext): diff --git a/website/docs/advanced/override_grammar/extended.md b/website/docs/advanced/override_grammar/extended.md index d0d266c97da..7f740f7e37d 100644 --- a/website/docs/advanced/override_grammar/extended.md +++ b/website/docs/advanced/override_grammar/extended.md @@ -187,7 +187,8 @@ shuffle([a,b,c]), shuffle(list=[a,b,c]) # shuffled list [a,b,c] ``` ## Type casting -You can cast values and sweeps to `int`, `float`, `bool` or `str`. +You can cast values and sweeps to `int`, `float`, `bool`, `str` or `json_str`. Note that unlike the +others, `json_str` will affect the whole container rather than just the values. ```python title="Example" int(3.14) # 3 (int) int(value=3.14) # 3 (int) @@ -197,6 +198,7 @@ bool(1) # true (bool) float(range(1,10)) # range(1.0,10.0) str([1,2,3]) # ['1','2','3'] str({a:10}) # {a:'10'} +json_str({a:10}) # '{"a":10}' ``` Below are pseudo code snippets that illustrates the differences between Python's casting and Hydra's casting. @@ -346,38 +348,37 @@ Input are grouped by type. [//]: # (Conversion matrix source: https://docs.google.com/document/d/1JDZGHKk4PrZHqsTTS6ao-DQOu2eVD4ULR6uAxVUR-WI/edit#) -| | int() | float() | str() | bool() | -|-------------------- |------------- |------------------- |------------------- |----------------------- | -| 10 | 10 | 10.0 | “10” | true | -| 0 | 0 | 0.0 | “0” | false | -| 10.0 | 10 | 10.0 | “10.0” | true | -| 0.0 | 0 | 0.0 | “0.0” | false | -| inf | error | inf | ‘inf’ | true | -| nan | error | nan | ‘nan’ | true | -| 1e6 | 1,000,000 | 1e6 | ‘1000000.0’ | true | -| foo | error | error | foo | error | -| “” (empty string) | error | error | “” | error | -| “10” | 10 | 10.0 | “10” | error | -| “10.0” | error | 10.0 | “10.0” | error | -| “true” | error | error | “true” | true | -| “false” | error | error | “false” | false | -| “[1,2,3]” | error | error | “[1,2,3]” | error | -| “{a:10}” | error | error | “{a:10}” | error | -| true | 1 | 1.0 | “true” | true | -| false | 0 | 0.0 | “false” | false | -| [] | [] | [] | [] | [] | -| [0,1,2] | [0,1,2] | [0.0,1.0,2.0] | [“0”,”1”,”2”] | [false,true,true] | -| [1,[2]] | [1,[2]] | [1.0,[2.0]] | [“1”,[“2”]] | [true,[true]] | -| [a,1] | error | error | [“a”,”1”] | error | -| {} | {} | {} | {} | {} | -| {a:10} | {a:10} | {a:10.0} | {a:”10”} | {a: true} | -| {a:[0,1,2]} | {a:[0,1,2]} | {a:[0.0,1.0,2.-]} | {a:[“0”,”1”,”2”]} | {a:[false,true,true]} | -| {a:10,b:xyz} | error | error | {a:”10”,b:”xyz”} | error | -| choice(0,1) | choice(0,1) | choice(0.0,1.0) | choice(“0”,“1”) | choice(false,true) | -| choice(a,b) | error | error | choice(“a”,”b”) | error | -| choice(1,a) | error | error | choice(“1”,”a”) | error | -| interval(1.0, 2.0) | interval(1, 2)| interval(1.0, 2.0) | error | error | -| interval(1, 2) | interval(1, 2)| interval(1.0, 2.0) | error | error | -| range(1,10) | range(1,10) | range(1.0,10.0) | error | error | -| range(1.0, 10.0) | range(1,10) | range(1.0,10.0) | error | error | - +| | int() | float() | str() | bool() | json() | +|-------------------- |------------- |------------------- |------------------- |----------------------- |-------------------- | +| 10 | 10 | 10.0 | “10” | true | “10” | +| 0 | 0 | 0.0 | “0” | false | “0” | +| 10.0 | 10 | 10.0 | “10.0” | true | “10.0” | +| 0.0 | 0 | 0.0 | “0.0” | false | “0.0” | +| inf | error | inf | ‘inf’ | true | “Infinity” | +| nan | error | nan | ‘nan’ | true | “NaN” | +| 1e6 | 1,000,000 | 1e6 | ‘1000000.0’ | true | “1000000.0” | +| foo | error | error | foo | error | '“foo”' | +| “” (empty string) | error | error | “” | error | '“”' | +| “10” | 10 | 10.0 | “10” | error | '“10”' | +| “10.0” | error | 10.0 | “10.0” | error | '“10.0”' | +| “true” | error | error | “true” | true | '“true”' | +| “false” | error | error | “false” | false | '“false”' | +| “[1,2,3]” | error | error | “[1,2,3]” | error | '“[1,2,3]”' | +| “{a:10}” | error | error | “{a:10}” | error | '“{a:10}”' | +| true | 1 | 1.0 | “true” | true | “true” | +| false | 0 | 0.0 | “false” | false | “false” | +| [] | [] | [] | [] | [] | “[]” | +| [0,1,2] | [0,1,2] | [0.0,1.0,2.0] | [“0”,”1”,”2”] | [false,true,true] | “[1, 2, 3]” | +| [1,[2]] | [1,[2]] | [1.0,[2.0]] | [“1”,[“2”]] | [true,[true]] | “[1, [2]]” | +| [a,1] | error | error | [“a”,”1”] | error | '[“a”, 1]' | +| {} | {} | {} | {} | {} | “{}” | +| {a:10} | {a:10} | {a:10.0} | {a:”10”} | {a: true} | '{“a”: 10}' | +| {a:[0,1,2]} | {a:[0,1,2]} | {a:[0.0,1.0,2.-]} | {a:[“0”,”1”,”2”]} | {a:[false,true,true]} | '{“a”: [1, 2, 3]}' | +| {a:10,b:xyz} | error | error | {a:”10”,b:”xyz”} | error | '{“a”: 10, “b”: “xyz”}'| +| choice(0,1) | choice(0,1) | choice(0.0,1.0) | choice(“0”,“1”) | choice(false,true) | choice(“0”, “1”) | +| choice(a,b) | error | error | choice(“a”,”b”) | error | choice('“a”', '“b”') | +| choice(1,a) | error | error | choice(“1”,”a”) | error | choice(“1”, '“a”') | +| interval(1.0, 2.0) | interval(1, 2)| interval(1.0, 2.0) | error | error | interval(“1.0”, “2.0”) | +| interval(1, 2) | interval(1, 2)| interval(1.0, 2.0) | error | error | interval(“1”, “2”) | +| range(1,10) | range(1,10) | range(1.0,10.0) | error | error | error | +| range(1.0, 10.0) | range(1,10) | range(1.0,10.0) | error | error | error |