From 9d9bcc25277f9d7de54295f85a91294df4fc982c Mon Sep 17 00:00:00 2001 From: Nicolas Dessart Date: Fri, 16 Jun 2023 14:10:41 +0200 Subject: [PATCH] [INTEGRATION] prepare 7.7.0-rc3 release --- src/olympe/arsdkng/cmd_itf.py | 24 ++- src/olympe/arsdkng/controller.py | 40 +++-- src/olympe/arsdkng/expectations.py | 8 +- src/olympe/arsdkng/json.py | 2 +- src/olympe/arsdkng/messages.py | 254 ++++++++++++++++++++--------- src/olympe/arsdkng/proto.py | 85 ++++++---- 6 files changed, 280 insertions(+), 133 deletions(-) diff --git a/src/olympe/arsdkng/cmd_itf.py b/src/olympe/arsdkng/cmd_itf.py index 563026d..c27ae56 100644 --- a/src/olympe/arsdkng/cmd_itf.py +++ b/src/olympe/arsdkng/cmd_itf.py @@ -46,12 +46,13 @@ from olympe.messages import connectivity from olympe.messages import common from olympe.messages import developer +from olympe.messages import devicemanager from olympe.messages import drone_manager +from olympe.messages import microhard from olympe.messages import mission from olympe.messages import network from olympe.messages import pointnfly from olympe.messages import privacy -from olympe.messages import sleepmode from olympe.scheduler import AbstractScheduler, Scheduler from collections import OrderedDict from olympe.utils import py_object_cast, callback_decorator, DEFAULT_FLOAT_TOL @@ -461,20 +462,29 @@ def _on_connection_state_changed(self, message_event, _): all_states_settings_commands = [ common.Common.AllStates, common.Settings.AllSettings, + ] + for all_states_settings_command in all_states_settings_commands: + self._send_command_raw(all_states_settings_command, dict()) + + # Enable airsdk mission support from the drone + self._send_command_raw(mission.custom_msg_enable, dict()) + self._send_command_raw(developer.Command.GetState, dict()) + + get_state_commands = [ antiflicker.Command.GetState, camera2.Command.GetState, connectivity.Command.GetState, - developer.Command.GetState, + devicemanager.Command.GetState, + microhard.Command.GetState, network.Command.GetState, pointnfly.Command.GetState, privacy.Command.GetState, - sleepmode.Command.GetState, ] - for all_states_settings_command in all_states_settings_commands: - self._send_command_raw(all_states_settings_command, dict()) - # Enable airsdk mission support from the drone - self._send_command_raw(mission.custom_msg_enable, dict()) + for get_state_command in get_state_commands: + self._send_command_raw( + get_state_command, dict(include_default_capabilities=True) + ) @callback_decorator() def _dispose_cmd_cb(self, _interface, _user_data): diff --git a/src/olympe/arsdkng/controller.py b/src/olympe/arsdkng/controller.py index a88f053..d58b332 100644 --- a/src/olympe/arsdkng/controller.py +++ b/src/olympe/arsdkng/controller.py @@ -55,7 +55,6 @@ from olympe.messages import pointnfly from olympe.messages import privacy from olympe.messages import skyctrl -from olympe.messages import sleepmode from olympe.video.pdraw import (PDRAW_LOCAL_STREAM_PORT, PDRAW_LOCAL_CONTROL_PORT) from tzlocal import get_localzone from warnings import warn @@ -601,23 +600,22 @@ async def _on_connected(self): not self._ip_addr_str.startswith("127.0")): self._synchronize_clock() # We're connected to the device, get all device states and settings if necessary + get_state_commands = [] if not self._is_skyctrl: all_states_settings_commands = [ common.Common.AllStates(), common.Settings.AllSettings() ] if self._device_type != od.ARSDK_DEVICE_TYPE_ANAFI4K: - all_states_settings_commands.extend( - [ - antiflicker.Command.GetState(), - camera2.Command.GetState(), - connectivity.Command.GetState(), - developer.Command.GetState(), - network.Command.GetState(), - pointnfly.Command.GetState(), - privacy.Command.GetState(), - sleepmode.Command.GetState(), - mission.custom_msg_enable()] - ) + get_state_commands = [ + antiflicker.Command.GetState(include_default_capabilities=True), + camera2.Command.GetState(include_default_capabilities=True), + connectivity.Command.GetState(include_default_capabilities=True), + developer.Command.GetState(), + network.Command.GetState(include_default_capabilities=True), + pointnfly.Command.GetState(include_default_capabilities=True), + privacy.Command.GetState(include_default_capabilities=True), + mission.custom_msg_enable() + ] else: all_states_settings_commands = [ skyctrl.Common.AllStates(), @@ -631,7 +629,7 @@ async def _on_connected(self): od.ARSDK_DEVICE_TYPE_SKYCTRL_3, od.ARSDK_DEVICE_TYPE_SKYCTRL_UA, ]: - all_states_settings_commands.append(controllerNetwork.Command.GetState()) + get_state_commands = [controllerNetwork.Command.GetState()] # Get device specific states and settings for states_settings_command in all_states_settings_commands: timeout = self._connection_deadline - time.time() @@ -645,6 +643,20 @@ async def _on_connected(self): if not res: return False + # Get specific optional states + for state_command in get_state_commands: + timeout = self._connection_deadline - time.time() + try: + res = await self._thread_loop.await_for( + timeout, + self._send_states_settings_cmd, state_command + ) + except FutureTimeoutError: + # Protobuf Command.GetState are optional + return True + if not res: + return False + # Process the ConnectedEvent event = ConnectedEvent() self.logger.info(str(event)) diff --git a/src/olympe/arsdkng/expectations.py b/src/olympe/arsdkng/expectations.py index cd22c41..bcb648a 100644 --- a/src/olympe/arsdkng/expectations.py +++ b/src/olympe/arsdkng/expectations.py @@ -504,8 +504,10 @@ def _schedule(self, scheduler): self._command_future = controller._send_command_raw( self.command_message, self.command_args ) + super()._schedule(scheduler) self._command_future.add_done_callback(lambda _: self.check(None)) - super()._schedule(scheduler) + else: + super()._schedule(scheduler) def no_expect(self, value): self._no_expect = value @@ -615,10 +617,12 @@ def _schedule(self, scheduler): self._command_future = controller._send_protobuf_command( self.command_message, self.command_args ) + super()._schedule(scheduler) self._command_future.add_done_callback(lambda _: self.check(None)) if self.expectation is not None: self.expectation._schedule(scheduler) - super()._schedule(scheduler) + else: + super()._schedule(scheduler) def no_expect(self, value): self._no_expect = value diff --git a/src/olympe/arsdkng/json.py b/src/olympe/arsdkng/json.py index 632cbe9..cba20e6 100644 --- a/src/olympe/arsdkng/json.py +++ b/src/olympe/arsdkng/json.py @@ -47,7 +47,7 @@ def default(self, o): f"{o.__class__.__name__}.{o.to_upper_str()}" ) elif issubclass(o.__class__, ArsdkMessageBase): - return f"olympe.messages.{o.feature_name}.{o.name}" + return f"olympe.messages.{o.fullName}" return super().default(o) diff --git a/src/olympe/arsdkng/messages.py b/src/olympe/arsdkng/messages.py index a5d94c6..5fb4d5e 100644 --- a/src/olympe/arsdkng/messages.py +++ b/src/olympe/arsdkng/messages.py @@ -32,6 +32,7 @@ import arsdkparser import ctypes +import math import textwrap from aenum import OrderedEnum @@ -135,7 +136,6 @@ class ArsdkMessageMeta(type): - _base = None def __new__(mcls, *args, **kwds): @@ -495,9 +495,7 @@ def _py_ar_supported(cls, supported_devices, deprecated): ) ) else: - ret.append( - f" :{device_str}: with an up to date firmware" - ) + ret.append(f" :{device_str}: with an up to date firmware") if not ret: return unsupported_notice @@ -608,9 +606,7 @@ def _py_ar_cmd_expectation_argval_docstring(cls, argname, argval): elif isinstance(argval, ArsdkBitfield): return argval.pretty() elif callable(argval): - command_args = OrderedDict( - (arg, f"this.{arg}") for arg in cls.args_name - ) + command_args = OrderedDict((arg, f"this.{arg}") for arg in cls.args_name) try: return argval(cls, command_args) except KeyError: @@ -913,9 +909,7 @@ def _set_last_event(self, event): if self._last_event is None: self._last_event = OrderedDict() key = event.args[self.key_name] - if not event_list_flags or event_list_flags == [ - list_flags.Last - ]: + if not event_list_flags or event_list_flags == [list_flags.Last]: self._state[key] = event.args if list_flags.First in event_list_flags: self._state = OrderedDict() @@ -929,9 +923,7 @@ def _set_last_event(self, event): else: self._last_event[key] = event elif self.callback_type == ArsdkMessageCallbackType.LIST: - if not event_list_flags or event.args["list_flags"] == [ - list_flags.Last - ]: + if not event_list_flags or event.args["list_flags"] == [list_flags.Last]: # append to the current list insert_pos = next(reversed(self._state), -1) + 1 self._state[insert_pos] = event.args @@ -1030,7 +1022,9 @@ def _expect(cls, *args, **kwds): ) for arg_name in args: if arg_name not in cls.args_name: - raise ValueError(f"'{cls.fullName}' message has no such '{arg_name}' parameter") + raise ValueError( + f"'{cls.fullName}' message has no such '{arg_name}' parameter" + ) if policy != ExpectPolicy.check: if not send_command and cls.message_type == ArsdkMessageType.CMD: expectations = ArsdkWhenAllExpectations( @@ -1116,7 +1110,10 @@ def default_args(cls): args[name] = type_() else: args[name] = None - if cls.callback_type in (ArsdkMessageCallbackType.MAP, ArsdkMessageCallbackType.LIST): + if cls.callback_type in ( + ArsdkMessageCallbackType.MAP, + ArsdkMessageCallbackType.LIST, + ): if "list_flags" not in args: args[name] = list_flags._bitfield_type_() return args @@ -1173,7 +1170,7 @@ def _encode_args(cls, args): # python -> ctypes -> struct_arsdk_value argv conversion encode_args_len = len(cls.arsdk_type_args) argv = (od.struct_arsdk_value * encode_args_len)() - for (i, arg, sdktype, value_attr, ctype) in zip( + for i, arg, sdktype, value_attr, ctype in zip( range(encode_args_len), encoded_args, cls.arsdk_type_args, @@ -1190,7 +1187,9 @@ def _decode_args(cls, message_buffer): Decode a ctypes message buffer into a list of python typed arguments. This also perform the necessary enum, bitfield and unicode conversions. """ - od.arsdk_cmd_dec.argtypes = od.arsdk_cmd_dec.argtypes[:2] + cls.decoded_args_type + od.arsdk_cmd_dec.argtypes = ( + od.arsdk_cmd_dec.argtypes[:2] + cls.decoded_args_type + ) res = od.arsdk_cmd_dec(message_buffer, cls.arsdk_desc, *cls.decoded_args) @@ -1201,7 +1200,8 @@ def _decode_args(cls, message_buffer): decoded_args[i] = arg = arg.contents.value else: decoded_args[i] = arg = (ctypes.c_char * arg.contents.len).from_address( - arg.contents.cdata) + arg.contents.cdata + ) # bytes utf-8 -> str conversion if isinstance(arg, bytes): decoded_args[i] = arg = str(arg, "utf-8") @@ -1331,7 +1331,6 @@ def _populate_messages(self): self._add_arsdk_proto_message(feature, message) def _add_arsdk_message(self, msgObj, name_path, id_path): - message = ArsdkMessageMeta.__new__( ArsdkMessageMeta, msgObj, name_path, id_path, self.enums ) @@ -1397,7 +1396,7 @@ def _add_arsdk_proto_message(self, feature, message_desc): path, service, svc_message_desc, - message_desc.doc, + message_desc, ) self._do_add_arsdk_proto_message( target_name_path, message, message_desc @@ -1414,7 +1413,7 @@ def _add_arsdk_proto_message(self, feature, message_desc): name_path, service, svc_message_desc, - message_desc.doc, + message_desc, ) self._do_add_arsdk_proto_message( name_path, message, message_desc @@ -1430,8 +1429,8 @@ def _add_arsdk_proto_message(self, feature, message_desc): self._root, path, None, + None, message_desc, - message_desc.doc, ) self._do_add_arsdk_proto_message(path, message, message_desc) @@ -1556,7 +1555,6 @@ def __len__(cls): class ArsdkProtoMessageMeta(type, ProtoNestedMixin): - _base = None def __new__(mcls, *args, **kwds): @@ -1573,7 +1571,7 @@ def __new__(mcls, *args, **kwds): mcls._base = cls return cls - root, name_path, service, message_desc, doc_desc = args + root, name_path, service, svc_message_desc, message_desc = args olympe_proto = ArsdkProto.get("olympe") olympe_messages = ArsdkMessages.get("olympe") olympe_enums = ArsdkEnums.get("olympe") @@ -1593,21 +1591,32 @@ def __new__(mcls, *args, **kwds): cls.args_map = OrderedDict() cls.dict_type = dict_type cls.service = service + cls.svc_message_desc = svc_message_desc cls.message_desc = message_desc cls.feature_name = name_path[0] cls.name = name_path[-1] - cls.doc = doc_desc - cls.field_name = getattr(message_desc, "field_name", None) - cls.service_proto = getattr(message_desc, "service", None) - cls.message_proto = message_desc.message + cls.doc = message_desc.doc + if svc_message_desc is not None: + cls.field_name = svc_message_desc.field_name + cls.service_proto = svc_message_desc.service + cls.message_proto = svc_message_desc.message + cls.number = svc_message_desc.number + cls.fields = svc_message_desc.fields + service_type = svc_message_desc.service_type + else: + cls.field_name = None + cls.service_proto = None + cls.message_proto = message_desc.message + cls.number = None + cls.fields = message_desc.fields + service_type = None cls.fullName = fullName cls.prefix = name_path[:-1] - cls.number = getattr(message_desc, "number", None) cls._recipient_id = None cls.loglevel = logging.INFO cls.buffer_type = ArsdkMessageBufferType.ACK cls.callback_type = ArsdkMessageCallbackType.STANDARD - if not cls.number or cls.number < 16: + if cls.number is None or cls.number < 16: cls.loglevel = logging.DEBUG cls.buffer_type = ArsdkMessageBufferType.NON_ACK @@ -1623,7 +1632,6 @@ def __new__(mcls, *args, **kwds): cls.args_name = list( name for name in cls.real_args_name if name != "selected_fields" ) - service_type = getattr(message_desc, "service_type", None) if service_type is None: cls.message_type = None cls._expectation = None @@ -1719,8 +1727,10 @@ def recipient_id(cls, id_): cls._recipient_id = id_ def _resolve_expectations(cls, messages, module): + if cls.svc_message_desc is None: + return expectation = cls._parse_expectation( - getattr(cls.message_desc, "on_success", None) or "None", module + cls.svc_message_desc.on_success or "None", module ) if cls.message_type == ArsdkMessageType.CMD: cls._expectation = ArsdkProtoCommandExpectation( @@ -1823,6 +1833,7 @@ def __call__(self, *args, **kwds): def _validate_args(self, args): assert isinstance(args, Mapping) args = self._map_enum_to_int(args) + args = self._map_float_specials(args) args = remove_from_collection(args, callable) for name, field in self.message_proto.DESCRIPTOR.fields_by_name.items(): if name not in args: @@ -1832,7 +1843,9 @@ def _validate_args(self, args): args[name] = self.args_message[name]._validate_args(args[name]) elif field.label == ProtoFieldLabel.Repeated._value_: assert isinstance(args[name], Iterable) - args[name] = list(map(self.args_message[name]._validate_args, args[name])) + args[name] = list( + map(self.args_message[name]._validate_args, args[name]) + ) elif field.message_type is not None and ( isinstance(args[name], Mapping) and field.message_type.file.package == "google.protobuf" @@ -1854,10 +1867,14 @@ def _map_google_protobuf(self, args, unwrap=False): ) elif field.label == ProtoFieldLabel.Repeated._value_: assert isinstance(args[name], Iterable) - args[name] = list(map( - lambda a: self.args_message[name]._map_google_protobuf( - a, unwrap=unwrap), args[name] - )) + args[name] = list( + map( + lambda a: self.args_message[name]._map_google_protobuf( + a, unwrap=unwrap + ), + args[name], + ) + ) elif field.message_type is not None and ( field.message_type.file.package == "google.protobuf" and field.message_type.fields @@ -1894,7 +1911,8 @@ def _expect_args(self, *args, **kwds): unknown_args = tuple(filter(lambda a: a not in self.args_name, args.keys())) if unknown_args: raise ValueError( - f"Unknown {unknown_args} parameter(s) passed to {self.fullName}") + f"Unknown {unknown_args} parameter(s) passed to {self.fullName}" + ) # filter out None value args = remove_from_collection(args, lambda a: a is None) self._validate_args(args) @@ -1928,9 +1946,10 @@ def _expect(self, *args, **kwds): args_to_validate = [] if self.message_type == ArsdkMessageType.CMD: for expectation in expectations: - if hasattr(expectation, "expected_args") and ( - hasattr(expectation, "expected_message")) and ( - not no_expect + if ( + hasattr(expectation, "expected_args") + and (hasattr(expectation, "expected_message")) + and (not no_expect) ): args_to_validate.append( ( @@ -1939,17 +1958,20 @@ def _expect(self, *args, **kwds): ) ) else: - unknown_args = tuple(filter(lambda a: a not in self.args_name, args.keys())) + unknown_args = tuple( + filter(lambda a: a not in self.args_name, args.keys()) + ) if unknown_args: raise ValueError( - f"Unknown {unknown_args} parameter(s) passed to {self.fullName}") + f"Unknown {unknown_args} parameter(s) passed to {self.fullName}" + ) args = expectations.expected_args args_to_validate.append((args, self)) # Use protobuf_json_format to validate protobuf message format # filter out lambda ArdkProtoThis lambdas before validation for expected_args, message in args_to_validate: - expected_args = self._validate_args(expected_args) + expected_args = message._validate_args(expected_args) if ( policy == ExpectPolicy.check_wait @@ -1969,9 +1991,7 @@ def _expect(self, *args, **kwds): return expectations def _reverse_expect(self, *args, **kwds): - args, policy, float_tol, no_expect, timeout = self._expect_args( - *args, **kwds - ) + args, policy, float_tol, no_expect, timeout = self._expect_args(*args, **kwds) if self.service is None: # Non-service messages are just equivalent to mapping object @@ -1987,15 +2007,13 @@ def _reverse_expect(self, *args, **kwds): args_to_validate = [] if self.message_type == ArsdkMessageType.EVT: for expectation in expectations: - if hasattr(expectation, "expected_args") and ( - hasattr(expectation, "expected_message")) and ( - not no_expect + if ( + hasattr(expectation, "expected_args") + and (hasattr(expectation, "expected_message")) + and (not no_expect) ): args_to_validate.append( - ( - expectation.expected_args, - expectation.expected_message - ) + (expectation.expected_args, expectation.expected_message) ) else: args = expectations.expected_args @@ -2055,8 +2073,11 @@ def _encode_args(self, args): args = self._map_set_selected_fields(self.message_proto, args) args = self._map_enum_to_str(args) args = self._map_google_protobuf(args) - if self.service_proto.DESCRIPTOR.fields_by_name[self.field_name].message_type is None: - proto = self.service_proto(**{self.field_name: args['value']}) + if ( + self.service_proto.DESCRIPTOR.fields_by_name[self.field_name].message_type + is None + ): + proto = self.service_proto(**{self.field_name: args["value"]}) else: proto = self.service_proto(**{self.field_name: args}) return bytearray(proto.SerializeToString(deterministic=True)) @@ -2082,6 +2103,55 @@ def _decode_payload(self, payload): args = OrderedDict(value=args) return args + def _map_float_specials(self, args): + if args is None or callable(args): + return args + args = args.copy() + + def _conv(value): + if math.isnan(value): + return "NaN" + elif math.isinf(value): + if value > 0.: + return "Infinity" + else: + return "-Infinity" + else: + return value + + for argname, argvalue in args.copy().items(): + if isinstance(argvalue, ArsdkProtoEnum): + args[argname] = int(argvalue._value_) + elif isinstance(argvalue, MutableMapping): + if argname == "selected_fields": + continue + if argname in self.args_message: + args[argname] = self.args_message[argname]._map_float_specials( + argvalue + ) + elif argname in self.args_map: + args[argname] = self.args_map[argname]._map_float_specials(argvalue) + for arg in args.keys(): + if isinstance(args[arg], float): + args[arg] = _conv(args[arg]) + elif isinstance(args[arg], Iterable) and ( + all(map(lambda a: isinstance(a, float), args[arg])) + ): + args[arg] = tuple(_conv(a) for a in args[arg]) + for nested_message_name, nested_message in self.args_message.items(): + if nested_message_name not in args: + continue + if isinstance(args[nested_message_name], (tuple, list)): + args[nested_message_name] = type(args[nested_message_name])( + nested_message._map_float_specials(a) + for a in args[nested_message_name] + ) + else: + args[nested_message_name] = nested_message._map_float_specials( + args[nested_message_name] + ) + return args + def _map_enum_type(self, args): if args is None or callable(args): return args @@ -2137,7 +2207,9 @@ def _map_enum_to_int(self, args): if argname == "selected_fields": continue if argname in self.args_message: - args[argname] = self.args_message[argname]._map_enum_to_int(argvalue) + args[argname] = self.args_message[argname]._map_enum_to_int( + argvalue + ) elif argname in self.args_map: args[argname] = self.args_map[argname]._map_enum_to_int(argvalue) for arg, enum in self.args_enum.items(): @@ -2157,7 +2229,8 @@ def _map_enum_to_int(self, args): if isinstance(v, int): continue args[nested_map_name][k] = int( - nested_map.args_enum["value"][v]._value_) + nested_map.args_enum["value"][v]._value_ + ) for nested_message_name, nested_message in self.args_message.items(): if nested_message_name not in args: continue @@ -2183,7 +2256,9 @@ def _map_enum_to_str(self, args): if argname == "selected_fields": continue if argname in self.args_message: - args[argname] = self.args_message[argname]._map_enum_to_str(argvalue) + args[argname] = self.args_message[argname]._map_enum_to_str( + argvalue + ) elif argname in self.args_map: args[argname] = self.args_map[argname]._map_enum_to_str(argvalue) for arg, enum in self.args_enum.items(): @@ -2332,6 +2407,30 @@ def last_events(self, key=None): if self._last_event[key] is not None: yield self._last_event[key] + def _update_oneofs_state(self, state, args): + for field in self.fields: + if field.name not in args: + continue + # If a oneof field is present in the received event: + # drop its associated oneofs fields state + for exclusive_with in field.exclusive_with: + state.pop(exclusive_with, None) + # Recursively update oneofs + for nested_message_name, nested_message in self.args_message.items(): + if nested_message_name not in args: + continue + if nested_message_name not in state: + continue + if isinstance(state[nested_message_name], (tuple, list)): + for s, a in zip( + state[nested_message_name], args[nested_message_name] + ): + nested_message._update_oneofs_state(s, a) + else: + nested_message._update_oneofs_state( + state[nested_message_name], args[nested_message_name] + ) + def _set_last_event(self, event): if event.id != self.id: raise RuntimeError( @@ -2340,6 +2439,9 @@ def _set_last_event(self, event): ) ) self._last_event = event + # Honor oneofs fields when before updating our cached states + self._update_oneofs_state(self._state, event.args) + # Recursively update our cached state update_mapping(self._state, event.args) @classmethod @@ -2372,13 +2474,17 @@ def _argsmap_from_args(cls, *args, **kwds): continue if isinstance(args[nested_message_name], ArsdkProtoThis): continue - elif isinstance(args[nested_message_name] , (tuple, list)): + elif isinstance(args[nested_message_name], (tuple, list)): args[nested_message_name] = type(args[nested_message_name])( - map(lambda a: nested_message._argsmap_from_args(a), args[nested_message_name]) + map( + lambda a: nested_message._argsmap_from_args(a), + args[nested_message_name], + ) ) else: args[nested_message_name] = nested_message._argsmap_from_args( - **args[nested_message_name]) + **args[nested_message_name] + ) return args @@ -2438,9 +2544,7 @@ def _supported_doc(cls): ) ) else: - ret.append( - f" :{device_str}: with an up to date firmware" - ) + ret.append(f" :{device_str}: with an up to date firmware") if not ret: return "\n**Unsupported message**\n" docstring = "\nSupported by:\n" @@ -2452,20 +2556,20 @@ def _supported_doc(cls): def _resolve_doc(cls, messages, module): if cls.doc is not None: cls.docstring = cls.doc.doc + "\n" - for field_doc in cls.doc.fields_doc: - if field_doc.name == "selected_fields": + for field in cls.message_desc.fields: + if field.name == "selected_fields": continue exclusive_with = "" - if field_doc.exclusive_with: - exclusive_with = ', '.join(field_doc.exclusive_with) + if field.exclusive_with: + exclusive_with = ", ".join(field.exclusive_with) exclusive_with = f" (mutually exclusive with: {exclusive_with})" - cls.docstring += f"\n:param {field_doc.name}: {field_doc.doc} {exclusive_with}\n" - if field_doc.label is ProtoFieldLabel.Repeated: - cls.docstring += ( - f"\n:type {field_doc.name}: list({field_doc.type})\n" - ) + cls.docstring += ( + f"\n:param {field.name}: {field.doc} {exclusive_with}\n" + ) + if field.label is ProtoFieldLabel.Repeated: + cls.docstring += f"\n:type {field.name}: list({field.type})\n" else: - cls.docstring += f"\n:type {field_doc.name}: {field_doc.type}\n" + cls.docstring += f"\n:type {field.name}: {field.type}\n" cls.docstring += "\n" cls.docstring += cls._supported_doc() diff --git a/src/olympe/arsdkng/proto.py b/src/olympe/arsdkng/proto.py index b9f865a..3456fdb 100644 --- a/src/olympe/arsdkng/proto.py +++ b/src/olympe/arsdkng/proto.py @@ -110,9 +110,9 @@ class ArsdkProtoEnum( pass -class ArsdkProtoFieldDoc( +class ArsdkProtoField( namedtuple( - "ArsdkProtoFieldDoc", + "ArsdkProtoField", ["name", "type", "label", "exclusive_with", "doc"], ) ): @@ -122,7 +122,7 @@ class ArsdkProtoFieldDoc( class ArsdkProtoMessageDoc( namedtuple( "ArsdkProtoMessageDoc", - ["doc", "fields_doc", "support"], + ["doc", "support"], ) ): pass @@ -131,7 +131,7 @@ class ArsdkProtoMessageDoc( class ArsdkProtoMessage( namedtuple( "ArsdkProtoMessage", - ["name", "path", "message", "feature_name", "doc"], + ["name", "path", "message", "feature_name", "doc", "fields"], ) ): pass @@ -149,6 +149,7 @@ class ArsdkProtoServiceMessage( "service_type", "service", "message", + "fields", "on_success", "on_failure", ], @@ -176,7 +177,6 @@ class ArsdkProtoFeature( class ArsdkProto: - _store = {} @classmethod @@ -269,7 +269,7 @@ def __init__(self, root, parent=None): self.field_doc_ext = self.parent.field_doc_ext self.support_ext = self.parent.support_ext - def create_service(self, feature_name, service_descriptor): + def create_service(self, feature_name, module_descriptor, service_descriptor): service_type = service_descriptor.name service_name = service_descriptor.full_name service_id = self.service_id(service_name) @@ -284,10 +284,11 @@ def create_service(self, feature_name, service_descriptor): field_type, field_enums, message, + fields, success, failure, ) in self.list_oneof_messages( - feature_name, service_descriptor.oneofs_by_name["id"] + feature_name, module_descriptor, service_descriptor.oneofs_by_name["id"] ): messages.append( ArsdkProtoServiceMessage( @@ -299,6 +300,7 @@ def create_service(self, feature_name, service_descriptor): service_type, service, message, + fields, success, failure, ) @@ -331,7 +333,7 @@ def message_type_from_field(self, feature_name, field_descriptor): ) return message_type, enum_types - def list_oneof_messages(self, feature_name, oneof_descriptor): + def list_oneof_messages(self, feature_name, module_descriptor, oneof_descriptor): for field in oneof_descriptor.fields: if field.message_type is None: message_type, enum_types = self.message_type_from_field( @@ -352,6 +354,7 @@ def list_oneof_messages(self, feature_name, oneof_descriptor): if self.on_success_ext is not None: success_exp = message_type.GetOptions().Extensions[self.on_success_ext] failure_exp = message_type.GetOptions().Extensions[self.on_failure_ext] + fields = self.message_fields(feature_name, module_descriptor, prototype) yield ( field.number, message_name, @@ -360,6 +363,7 @@ def list_oneof_messages(self, feature_name, oneof_descriptor): message_type, enum_types, prototype, + fields, success_exp, failure_exp, ) @@ -397,6 +401,33 @@ def _get_field_type(self, module_descriptor, feature_name, field_descriptor): FieldDescriptor.TYPE_SINT64: "i64", }[field_descriptor.type] + def message_fields(self, feature_name, module_descriptor, message): + fields = [] + for field in message.DESCRIPTOR.fields: + type_ = self._get_field_type(module_descriptor, feature_name, field) + label = None + if field.label: + label = ProtoFieldLabel(field.label) + exclusive_with = [] + if field.containing_oneof is not None: + exclusive_with = list( + filter( + lambda n: n != field.name, + map(lambda f: f.name, field.containing_oneof.fields), + ) + ) + + fields.append( + ArsdkProtoField( + field.name, + type_, + label, + exclusive_with, + field.GetOptions().Extensions[self.field_doc_ext], + ) + ) + return fields + def feature_messages( self, root, filename, feature_name, module_descriptor, services ): @@ -418,32 +449,11 @@ def feature_messages( message_doc = message.DESCRIPTOR.GetOptions().Extensions[ self.message_doc_ext ] - field_docs = [] - for field in message.DESCRIPTOR.fields: - type_ = self._get_field_type(module_descriptor, feature_name, field) - label = None - if field.label: - label = ProtoFieldLabel(field.label) - exclusive_with = [] - if field.containing_oneof is not None: - exclusive_with = list(filter( - lambda n: n != field.name, - map(lambda f: f.name, field.containing_oneof.fields) - )) - - field_docs.append( - ArsdkProtoFieldDoc( - field.name, - type_, - label, - exclusive_with, - field.GetOptions().Extensions[self.field_doc_ext], - ) - ) - doc = ArsdkProtoMessageDoc(message_doc, field_docs, support) + doc = ArsdkProtoMessageDoc(message_doc, support) + fields = self.message_fields(feature_name, module_descriptor, message) ret.append( ArsdkProtoMessage( - message.DESCRIPTOR.name, path, message, feature_name, doc + message.DESCRIPTOR.name, path, message, feature_name, doc, fields ) ) for service in services: @@ -462,6 +472,7 @@ def feature_messages( svc_message_desc.field_type, feature_name, None, + [], ) ) return ret @@ -671,10 +682,16 @@ def parse_proto( services = [] if hasattr(module, "Command"): services.append( - self.create_service(feature_name, module.Command.DESCRIPTOR) + self.create_service( + feature_name, module.DESCRIPTOR, module.Command.DESCRIPTOR + ) ) if hasattr(module, "Event"): - services.append(self.create_service(feature_name, module.Event.DESCRIPTOR)) + services.append( + self.create_service( + feature_name, module.DESCRIPTOR, module.Event.DESCRIPTOR + ) + ) feature = ArsdkProtoFeature( feature_name, module,