diff --git a/arbiter/master_chief/sa_recon.py b/arbiter/master_chief/sa_recon.py index e07491e..6a7b96e 100644 --- a/arbiter/master_chief/sa_recon.py +++ b/arbiter/master_chief/sa_recon.py @@ -81,29 +81,32 @@ def _is_ret(self, arglist): def sinks(self): return self.map.keys() - def _check_callees(self, func, target, callee): + def _check_callees(self, func, target, sinks): + """ + Check the `func` in the `target` for any call to any function name in set `sinks` + """ for site in sorted(func.get_call_sites()): name = self._callee_name(func, site) - if callee not in name: - continue - arglist = self.map[callee] - - if self._is_ret(arglist): - logger.debug("Finding ret block for %s @ 0x%x" % (callee, site)) - site = self._find_ret_block(func) - if site is None: - # No ret instruction + for callee in sinks: + if callee not in name: continue + arglist = self.map[callee] + + if self._is_ret(arglist): + logger.debug("Finding ret block for %s @ 0x%x" % (callee, site)) + site = self._find_ret_block(func) + if site is None: + # No ret instruction + continue - target.add_node(site, None, self._cfg, arglist) + target.add_node(site, None, self._cfg, arglist) def _check_sinks(self, func): target = SA1_Target(func) - for x in self.sinks: - self._check_callees(func, target, x) + self._check_callees(func, target, self.sinks) if target.node_count > 0: self._targets.append(target)