diff --git a/ansible/roles/vm_set/library/vm_topology.py b/ansible/roles/vm_set/library/vm_topology.py
index e794029cde9..ea21acafe53 100644
--- a/ansible/roles/vm_set/library/vm_topology.py
+++ b/ansible/roles/vm_set/library/vm_topology.py
@@ -1,12 +1,15 @@
 #!/usr/bin/python

+import functools
 import hashlib
 import json
+import multiprocessing
 import os.path
 import re
 import subprocess
 import shlex
 import sys
+import threading
 import time
 import traceback
 import logging
@@ -15,6 +18,7 @@
 import six

 from ansible.module_utils.basic import AnsibleModule
+from logging.handlers import MemoryHandler

 try:
     from ansible.module_utils.dualtor_utils import generate_mux_cable_facts
@@ -25,6 +29,12 @@

 from ansible.module_utils.debug_utils import config_module_logging

+if sys.version_info.major == 2:
+    from multiprocessing.pool import ThreadPool
+else:
+    from concurrent.futures import ThreadPoolExecutor as ThreadPool
+
+
 DOCUMENTATION = '''
 ---
 module: vm_topology
@@ -161,6 +171,9 @@

 RT_TABLE_FILEPATH = "/etc/iproute2/rt_tables"

+MIN_THREAD_WORKER_COUNT = 8
+LOG_SEPARATOR = "=" * 120
+

 def construct_log_filename(cmd, vm_set_name):
     log_filename = 'vm_topology'
@@ -213,7 +226,7 @@ def adaptive_temporary_interface(vm_set_name, interface_name, reserved_space=0):

 class VMTopology(object):

-    def __init__(self, vm_names, vm_properties, fp_mtu, max_fp_num, topo):
+    def __init__(self, vm_names, vm_properties, fp_mtu, max_fp_num, topo, worker):
         self.vm_names = vm_names
         self.vm_properties = vm_properties
         self.fp_mtu = fp_mtu
@@ -222,6 +235,7 @@ def __init__(self, vm_names, vm_properties, fp_mtu, max_fp_num, topo):
         self._host_interfaces = None
         self._disabled_host_interfaces = None
         self._host_interfaces_active_active = None
+        self.worker = worker
         return

     def init(self, vm_set_name, vm_base, duts_fp_ports, duts_name, ptf_exists=True, check_bridge=True):
@@ -846,6 +860,7 @@ def bind_fp_ports(self, disconnect_vm=False):
         +----------------------+

         """
+        bind_ovs_ports_args = []
         for attr in self.VMs.values():
             for idx, vlan in enumerate(attr['vlans']):
                 br_name = adaptive_name(
@@ -857,8 +872,11 @@ def bind_fp_ports(self, disconnect_vm=False):
                     INJECTED_INTERFACES_TEMPLATE, self.vm_set_name, ptf_index)
                 if len(self.duts_fp_ports[self.duts_name[dut_index]]) == 0:
                     continue
-                self.bind_ovs_ports(br_name, self.duts_fp_ports[self.duts_name[dut_index]][str(
-                    vlan_index)], injected_iface, vm_iface, disconnect_vm)
+                bind_ovs_ports_args.append(
+                    (br_name, self.duts_fp_ports[self.duts_name[dut_index]][str(vlan_index)],
+                     injected_iface, vm_iface, disconnect_vm)
+                )
+        self.worker.map(lambda args: self.bind_ovs_ports(*args), bind_ovs_ports_args)

         if self.topo and 'DUT' in self.topo and 'vs_chassis' in self.topo['DUT']:
             # We have a KVM based virtaul chassis, bind the midplane and inband ports
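The bind hunks above follow one pattern: collect each port's arguments into a list, then fan the list out through `worker.map` instead of calling `bind_ovs_ports` inline. A minimal sketch of that fan-out shape, using `ThreadPoolExecutor` directly and a hypothetical `bind_port` task in place of the real OVS calls:

```python
from concurrent.futures import ThreadPoolExecutor

def bind_port(br_name, dut_iface, injected_iface):
    # Hypothetical stand-in for VMTopology.bind_ovs_ports; the real work is OVS commands.
    print("bind %s + %s on bridge %s" % (dut_iface, injected_iface, br_name))

# First pass: gather the argument tuples (mirrors bind_ovs_ports_args).
args_list = [("br-VM0100-0", "eth4", "inj-vms1-0"),
             ("br-VM0100-1", "eth5", "inj-vms1-1")]

# Second pass: run all bindings concurrently, same shape as
# self.worker.map(lambda args: self.bind_ovs_ports(*args), bind_ovs_ports_args).
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(lambda args: bind_port(*args), args_list))
```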
@@ -869,13 +887,16 @@ def bind_fp_ports(self, disconnect_vm=False):

     def unbind_fp_ports(self):
         logging.info("=== unbind front panel ports ===")
+        unbind_ovs_ports_args = []
         for attr in self.VMs.values():
             for vlan_num, vlan in enumerate(attr['vlans']):
                 br_name = adaptive_name(
                     OVS_FP_BRIDGE_TEMPLATE, self.vm_names[self.vm_base_index + attr['vm_offset']], vlan_num)
                 vm_iface = OVS_FP_TAP_TEMPLATE % (
                     self.vm_names[self.vm_base_index + attr['vm_offset']], vlan_num)
-                self.unbind_ovs_ports(br_name, vm_iface)
+                unbind_ovs_ports_args.append((br_name, vm_iface))
+
+        self.worker.map(lambda args: self.unbind_ovs_ports(*args), unbind_ovs_ports_args)

         if self.topo and 'DUT' in self.topo and 'vs_chassis' in self.topo['DUT']:
             # We have a KVM based virtaul chassis, unbind the midplane and inband ports
@@ -1152,7 +1173,7 @@ def add_host_ports(self):
         for non-dual topo, inject the dut port into ptf docker.
         for dual-tor topo, create ovs port and add to ptf docker.
         """
-        for i, intf in enumerate(self.host_interfaces):
+        def _add_host_port(i, intf):
             if self._is_multi_duts and not self._is_cable:
                 if isinstance(intf, list):
                     # For dualtor interface: create veth link and inject one end into the ptf docker
@@ -1227,6 +1248,8 @@ def add_host_ports(self):
                     self.add_dut_vlan_subif_to_docker(
                         ptf_if, vlan_separator, vlan_id)

+        self.worker.map(lambda args: _add_host_port(*args), enumerate(self.host_interfaces))
+
     def enable_netns_loopback(self):
         """Enable loopback device in the netns."""
         VMTopology.cmd("ip netns exec %s ifconfig lo up" % self.netns)
@@ -1294,7 +1317,8 @@ def remove_host_ports(self):
         remove dut port from the ptf docker
         """
         logging.info("=== Remove host ports ===")
-        for i, intf in enumerate(self.host_interfaces):
+
+        def _remove_host_port(i, intf):
             if self._is_multi_duts:
                 if isinstance(intf, list):
                     host_ifindex = intf[0][2] if len(intf[0]) == 3 else i
@@ -1318,6 +1342,8 @@ def remove_host_ports(self):
                     self.remove_dut_vlan_subif_from_docker(
                         ptf_if, vlan_separator, vlan_id)

+        self.worker.map(lambda args: _remove_host_port(*args), enumerate(self.host_interfaces))
+
     def remove_veth_if_from_docker(self, ext_if, int_if, tmp_name):
         """
         Remove veth interface from docker
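`add_host_ports` and `remove_host_ports` take a slightly different route: the loop body is wrapped in a local closure (`_add_host_port` / `_remove_host_port`) and mapped over `enumerate(self.host_interfaces)`, so each task still receives the positional index `i` that the interface-naming logic depends on, even though tasks now complete out of order. A small illustrative sketch of that refactor shape (interface names here are invented):

```python
from concurrent.futures import ThreadPoolExecutor

host_interfaces = ["Ethernet0", "Ethernet4", ["Ethernet8", "Ethernet12"]]

def _handle_port(i, intf):
    # Former loop body: `i` preserves the original position even when
    # the tasks themselves finish in arbitrary order.
    kind = "dual-tor pair" if isinstance(intf, list) else "single"
    print("host interface %d: %s (%s)" % (i, intf, kind))

with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(lambda args: _handle_port(*args), enumerate(host_interfaces)))
```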
+ """ + self.get_current_thread_log_memory_handler().emit(record) + + def flush(self): + """Flush all log records to the target handler.""" + for handler in self.memory_handlers.values(): + handler.flush() + self.target.flush() + + def close(self): + """Close all log memory handlers.""" + for handler in self.memory_handlers.values(): + handler.close() + self.memory_handlers.clear() + self.target.close() + super(ThreadBufferHandler, self).close() + + +class VMTopologyWorker(object): + """VM Topology worker class.""" + + def __init__(self, use_thread_worker, thread_worker_count): + """ + Initialize the VMTopologyWorker object. + + Args: + use_thread_worker: use thread pool or not. + thread_worker_count: the thread worker count if use thread pool is enabled. + """ + logging.info("Init VM topology worker: use thread worker %s, thread worker count %s", + use_thread_worker, thread_worker_count) + self.thread_pool = None + self._map_helper = map + self._shutdown_helper = None + self.use_thread_worker = use_thread_worker + self.thread_worker_count = thread_worker_count + self.thread_buffer_handler = None + if use_thread_worker: + self.thread_pool = ThreadPool(thread_worker_count) + self._map_helper = self.thread_pool.map + if hasattr(self.thread_pool, "shutdown"): + self._shutdown_helper = \ + lambda: self.thread_pool.shutdown(wait=True, cancel_futures=True) + else: + self._shutdown_helper = \ + lambda: self.thread_pool.terminate() + + self._setup_thread_buffered_handler() + + def _setup_thread_buffered_handler(self): + """Setup the per-thread log batch handler with ThreadBufferHandler.""" + handlers = logging.getLogger().handlers + if not handlers: + raise ValueError("No logging handler is available in the default logging.") + handler = handlers[-1] + self.thread_buffer_handler = ThreadBufferHandler(target=handler) + + def map(self, func, iterable): + """Apply the function to every item of the iterable.""" + def _buffer_logs_helper(func, *args, **kwargs): + if self.use_thread_worker: + logging.debug(LOG_SEPARATOR) + logging.debug("Start task %s, arguments (%s, %s), worker %s", + func, args, kwargs, threading.current_thread().ident) + try: + func(*args, **kwargs) + finally: + if self.use_thread_worker: + logging.debug("Finish task %s, arguments (%s, %s), worker %s", + func, args, kwargs, threading.current_thread().ident) + logging.debug(LOG_SEPARATOR) + self.thread_buffer_handler.flush_current_thread_logs() + + # NOTE: replace the original handler with the thread buffer handler, so logs from + # one task will be buffered and flushed together. 
@@ -1778,7 +1947,11 @@ def main():
             fp_mtu=dict(required=False, type='int', default=DEFAULT_MTU),
             max_fp_num=dict(required=False, type='int',
                             default=NUM_FP_VLANS_PER_FP),
-            netns_mgmt_ip_addr=dict(required=False, type='str', default=None)
+            netns_mgmt_ip_addr=dict(required=False, type='str', default=None),
+            use_thread_worker=dict(required=False, type='bool', default=True),
+            thread_worker_count=dict(required=False, type='int',
+                                     default=max(MIN_THREAD_WORKER_COUNT,
+                                                 multiprocessing.cpu_count() // 8))
         ),
         supports_check_mode=False)

@@ -1788,6 +1961,8 @@ def main():
     fp_mtu = module.params['fp_mtu']
     max_fp_num = module.params['max_fp_num']
     vm_properties = module.params['vm_properties']
+    use_thread_worker = module.params['use_thread_worker']
+    thread_worker_count = module.params['thread_worker_count']

     config_module_logging(construct_log_filename(cmd, vm_set_name))

@@ -1797,7 +1972,8 @@ def main():

     try:
         topo = module.params['topo']
-        net = VMTopology(vm_names, vm_properties, fp_mtu, max_fp_num, topo)
+        worker = VMTopologyWorker(use_thread_worker, thread_worker_count)
+        net = VMTopology(vm_names, vm_properties, fp_mtu, max_fp_num, topo, worker)

         if cmd == 'create':
             net.create_bridges()
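Last, the module grows two opt-out parameters: `use_thread_worker` (default on) and `thread_worker_count`, whose default is `max(MIN_THREAD_WORKER_COUNT, multiprocessing.cpu_count() // 8)`, i.e. a floor of 8 workers that only grows on hosts with more than 64 CPUs. A quick check of that arithmetic:

```python
import multiprocessing

MIN_THREAD_WORKER_COUNT = 8

def default_worker_count(cpu_count):
    # Mirrors the new module default: floor of 8; cpu_count // 8 beyond 64 CPUs.
    return max(MIN_THREAD_WORKER_COUNT, cpu_count // 8)

assert default_worker_count(16) == 8     # small testbed host: floor applies
assert default_worker_count(64) == 8     # 64 // 8 == 8, still the floor
assert default_worker_count(128) == 16   # large host: 128 // 8 workers

print(default_worker_count(multiprocessing.cpu_count()))
```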