From a67fc839496495a1d787ef63d0dfb72c0f1c2fcb Mon Sep 17 00:00:00 2001 From: Skylot Date: Fri, 1 Jul 2022 16:41:58 +0100 Subject: [PATCH] fix: better dominators algorithms --- .../java/jadx/core/dex/nodes/MethodNode.java | 4 + .../blocks/BlockExceptionHandler.java | 2 +- .../dex/visitors/blocks/BlockProcessor.java | 148 +-------------- .../dex/visitors/blocks/DominatorTree.java | 175 ++++++++++++++++++ .../main/java/jadx/core/utils/DebugUtils.java | 25 +++ .../integration/others/TestMoveInline.java | 1 - 6 files changed, 209 insertions(+), 146 deletions(-) create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/blocks/DominatorTree.java diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java index 899d64521..7d8c87810 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java @@ -316,6 +316,10 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails, return blocks; } + public void setBasicBlocks(List blocks) { + this.blocks = blocks; + } + public BlockNode getEnterBlock() { return enterBlock; } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockExceptionHandler.java b/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockExceptionHandler.java index ccab8b091..bff1cff5e 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockExceptionHandler.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockExceptionHandler.java @@ -50,7 +50,7 @@ public class BlockExceptionHandler { return false; } BlockProcessor.updateCleanSuccessors(mth); - BlockProcessor.computeDominanceFrontier(mth); + DominatorTree.computeDominanceFrontier(mth); processCatchAttr(mth); initExcHandlers(mth); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockProcessor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockProcessor.java index f54b78a78..6628ea312 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockProcessor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/BlockProcessor.java @@ -1,9 +1,6 @@ package jadx.core.dex.visitors.blocks; import java.util.ArrayList; -import java.util.BitSet; -import java.util.Collections; -import java.util.Deque; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; @@ -31,7 +28,6 @@ import jadx.core.utils.Utils; import jadx.core.utils.exceptions.JadxRuntimeException; import static jadx.core.dex.visitors.blocks.BlockSplitter.connect; -import static jadx.core.utils.EmptyBitSet.EMPTY; public class BlockProcessor extends AbstractVisitor { private static final Logger LOG = LoggerFactory.getLogger(BlockProcessor.class); @@ -50,29 +46,23 @@ public class BlockProcessor extends AbstractVisitor { computeDominators(mth); if (independentBlockTreeMod(mth)) { checkForUnreachableBlocks(mth); - clearBlocksState(mth); computeDominators(mth); } if (FixMultiEntryLoops.process(mth)) { - clearBlocksState(mth); computeDominators(mth); } updateCleanSuccessors(mth); int i = 0; while (modifyBlocksTree(mth)) { - // revert calculations - clearBlocksState(mth); - // recalculate dominators tree computeDominators(mth); - if (i++ > 100) { throw new JadxRuntimeException("CFG modification limit reached, blocks count: " + mth.getBasicBlocks().size()); } } checkForUnreachableBlocks(mth); - computeDominanceFrontier(mth); + DominatorTree.computeDominanceFrontier(mth); registerLoops(mth); processNestedLoops(mth); @@ -209,139 +199,9 @@ public class BlockProcessor extends AbstractVisitor { } private static void computeDominators(MethodNode mth) { - List basicBlocks = mth.getBasicBlocks(); - int nBlocks = basicBlocks.size(); - for (int i = 0; i < nBlocks; i++) { - BlockNode block = basicBlocks.get(i); - block.setId(i); - block.setDoms(new BitSet(nBlocks)); - block.getDoms().set(0, nBlocks); - } - - BlockNode entryBlock = mth.getEnterBlock(); - calcDominators(basicBlocks, entryBlock); + clearBlocksState(mth); + DominatorTree.compute(mth); markLoops(mth); - - // clear self dominance - basicBlocks.forEach(block -> { - block.getDoms().clear(block.getId()); - if (block.getDoms().isEmpty()) { - block.setDoms(EMPTY); - } - }); - - calcImmediateDominators(mth, basicBlocks, entryBlock); - } - - private static void calcDominators(List basicBlocks, BlockNode entryBlock) { - entryBlock.getDoms().clear(); - entryBlock.getDoms().set(entryBlock.getId()); - - BitSet domSet = new BitSet(basicBlocks.size()); - boolean changed; - do { - changed = false; - for (BlockNode block : basicBlocks) { - if (block == entryBlock) { - continue; - } - BitSet d = block.getDoms(); - if (!changed) { - domSet.clear(); - domSet.or(d); - } - for (BlockNode pred : block.getPredecessors()) { - d.and(pred.getDoms()); - } - d.set(block.getId()); - if (!changed && !d.equals(domSet)) { - changed = true; - } - } - } while (changed); - } - - private static void calcImmediateDominators(MethodNode mth, List basicBlocks, BlockNode entryBlock) { - for (BlockNode block : basicBlocks) { - if (block == entryBlock) { - continue; - } - BlockNode idom; - List preds = block.getPredecessors(); - if (preds.size() == 1) { - idom = preds.get(0); - } else { - BitSet bs = new BitSet(block.getDoms().length()); - bs.or(block.getDoms()); - for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) { - BlockNode dom = basicBlocks.get(i); - bs.andNot(dom.getDoms()); - } - if (bs.cardinality() != 1) { - throw new JadxRuntimeException("Can't find immediate dominator for block " + block - + " in " + bs + " preds:" + preds); - } - idom = basicBlocks.get(bs.nextSetBit(0)); - } - block.setIDom(idom); - idom.addDominatesOn(block); - } - } - - static void computeDominanceFrontier(MethodNode mth) { - mth.getExitBlock().setDomFrontier(EMPTY); - List domSortedBlocks = new ArrayList<>(mth.getBasicBlocks().size()); - Deque stack = new LinkedList<>(); - stack.push(mth.getEnterBlock()); - while (!stack.isEmpty()) { - BlockNode node = stack.pop(); - for (BlockNode dominated : node.getDominatesOn()) { - stack.push(dominated); - } - domSortedBlocks.add(node); - } - Collections.reverse(domSortedBlocks); - for (BlockNode block : domSortedBlocks) { - try { - computeBlockDF(mth, block); - } catch (Exception e) { - throw new JadxRuntimeException("Failed compute block dominance frontier", e); - } - } - } - - private static void computeBlockDF(MethodNode mth, BlockNode block) { - if (block.getDomFrontier() != null) { - return; - } - List blocks = mth.getBasicBlocks(); - BitSet domFrontier = null; - for (BlockNode s : block.getSuccessors()) { - if (s.getIDom() != block) { - if (domFrontier == null) { - domFrontier = new BitSet(blocks.size()); - } - domFrontier.set(s.getId()); - } - } - for (BlockNode c : block.getDominatesOn()) { - BitSet frontier = c.getDomFrontier(); - if (frontier == null) { - throw new JadxRuntimeException("Dominance frontier not calculated for dominated block: " + c + ", from: " + block); - } - for (int p = frontier.nextSetBit(0); p >= 0; p = frontier.nextSetBit(p + 1)) { - if (blocks.get(p).getIDom() != block) { - if (domFrontier == null) { - domFrontier = new BitSet(blocks.size()); - } - domFrontier.set(p); - } - } - } - if (domFrontier == null || domFrontier.isEmpty()) { - domFrontier = EMPTY; - } - block.setDomFrontier(domFrontier); } private static void markLoops(MethodNode mth) { @@ -349,7 +209,7 @@ public class BlockProcessor extends AbstractVisitor { // Every successor that dominates its predecessor is a header of a loop, // block -> successor is a back edge. block.getSuccessors().forEach(successor -> { - if (block.getDoms().get(successor.getId())) { + if (block.getDoms().get(successor.getId()) || block == successor) { successor.add(AFlag.LOOP_START); block.add(AFlag.LOOP_END); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/DominatorTree.java b/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/DominatorTree.java new file mode 100644 index 000000000..dd60d5790 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/blocks/DominatorTree.java @@ -0,0 +1,175 @@ +package jadx.core.dex.visitors.blocks; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; + +import org.jetbrains.annotations.NotNull; + +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.EmptyBitSet; +import jadx.core.utils.exceptions.JadxRuntimeException; + +/** + * Build dominator tree based on the algorithm described in paper: + * Cooper, Keith D.; Harvey, Timothy J; Kennedy, Ken (2001). + * "A Simple, Fast Dominance Algorithm" + * http://www.hipersoft.rice.edu/grads/publications/dom14.pdf + */ +@SuppressWarnings("JavadocLinkAsPlainText") +public class DominatorTree { + + public static void compute(MethodNode mth) { + List sorted = sortBlocks(mth); + BlockNode[] doms = build(sorted); + apply(sorted, doms); + mth.setBasicBlocks(sorted); + } + + private static List sortBlocks(MethodNode mth) { + int blocksCount = mth.getBasicBlocks().size(); + BitSet reachSet = new BitSet(blocksCount); + List sorted = new ArrayList<>(blocksCount); + BlockUtils.dfsVisit(mth, b -> { + sorted.add(b); + reachSet.set(b.getId()); + }); + if (reachSet.cardinality() != blocksCount) { + throw new JadxRuntimeException("Found unreachable blocks"); + } + for (int i = 0; i < blocksCount; i++) { + sorted.get(i).setId(i); + } + return sorted; + } + + @NotNull + private static BlockNode[] build(List sorted) { + int blocksCount = sorted.size(); + BlockNode[] doms = new BlockNode[blocksCount]; + doms[0] = sorted.get(0); + boolean changed = true; + while (changed) { + changed = false; + for (int blockId = 1; blockId < blocksCount; blockId++) { + BlockNode b = sorted.get(blockId); + List preds = b.getPredecessors(); + int pickedPred = -1; + BlockNode newIDom = null; + for (BlockNode pred : preds) { + int id = pred.getId(); + if (doms[id] != null) { + newIDom = pred; + pickedPred = id; + break; + } + } + if (newIDom == null) { + throw new JadxRuntimeException("No predecessors for block: " + b); + } + for (BlockNode predBlock : preds) { + int predId = predBlock.getId(); + if (predId == pickedPred) { + continue; + } + if (doms[predId] != null) { + newIDom = intersect(sorted, doms, predBlock, newIDom); + } + } + if (doms[blockId] != newIDom) { + doms[blockId] = newIDom; + changed = true; + } + } + } + return doms; + } + + private static BlockNode intersect(List sorted, BlockNode[] doms, BlockNode b1, BlockNode b2) { + int f1 = b1.getId(); + int f2 = b2.getId(); + while (f1 != f2) { + while (f1 > f2) { + f1 = doms[f1].getId(); + } + while (f2 > f1) { + f2 = doms[f2].getId(); + } + } + return sorted.get(f1); + } + + private static void apply(List sorted, BlockNode[] doms) { + BlockNode enterBlock = sorted.get(0); + enterBlock.setDoms(EmptyBitSet.EMPTY); + enterBlock.setIDom(null); + int blocksCount = sorted.size(); + for (int i = 1; i < blocksCount; i++) { + BlockNode block = sorted.get(i); + BlockNode idom = doms[i]; + block.setIDom(idom); + idom.addDominatesOn(block); + BitSet domBS = collectDoms(doms, idom); + domBS.clear(i); + block.setDoms(domBS); + } + } + + private static BitSet collectDoms(BlockNode[] doms, BlockNode idom) { + BitSet domBS = new BitSet(doms.length); + BlockNode nextIDom = idom; + while (true) { + int id = nextIDom.getId(); + if (domBS.get(id)) { + break; + } + domBS.set(id); + BitSet curDoms = nextIDom.getDoms(); + if (curDoms != null) { + // use already collected set + domBS.or(curDoms); + break; + } + nextIDom = doms[id]; + } + return domBS; + } + + public static void computeDominanceFrontier(MethodNode mth) { + List blocks = mth.getBasicBlocks(); + for (BlockNode block : blocks) { + block.setDomFrontier(null); + } + int blocksCount = blocks.size(); + for (BlockNode block : blocks) { + List preds = block.getPredecessors(); + if (preds.size() >= 2) { + BlockNode idom = block.getIDom(); + for (BlockNode pred : preds) { + BlockNode runner = pred; + while (runner != idom) { + addToDF(runner, block, blocksCount); + runner = runner.getIDom(); + } + } + } + } + for (BlockNode block : blocks) { + BitSet df = block.getDomFrontier(); + if (df == null || df.isEmpty()) { + block.setDomFrontier(EmptyBitSet.EMPTY); + } + } + } + + private static void addToDF(BlockNode block, BlockNode dfBlock, int blocksCount) { + BitSet df = block.getDomFrontier(); + if (df == null) { + df = new BitSet(blocksCount); + block.setDomFrontier(df); + } + df.set(dfBlock.getId()); + } +} diff --git a/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java b/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java index de6dee4d4..aa251b8f8 100644 --- a/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java @@ -246,4 +246,29 @@ public class DebugUtils { Set seen = ConcurrentHashMap.newKeySet(); return t -> seen.add(keyExtractor.apply(t)); } + + private static Map execTimes; + + public static void initExecTimes() { + execTimes = new ConcurrentHashMap<>(); + } + + public static void mergeExecTimeFromStart(String tag, long startTimeMillis) { + mergeExecTime(tag, System.currentTimeMillis() - startTimeMillis); + } + + public static void mergeExecTime(String tag, long execTimeMillis) { + execTimes.merge(tag, execTimeMillis, Long::sum); + } + + public static void printExecTimes() { + System.out.println("Exec times:"); + execTimes.forEach((tag, time) -> System.out.println(" " + tag + ": " + time + "ms")); + } + + public static void printExecTimesWithTotal(long totalMillis) { + System.out.println("Exec times: total " + totalMillis + "ms"); + execTimes.forEach((tag, time) -> System.out.println(" " + tag + ": " + time + "ms" + + String.format(" (%.2f%%)", time * 100. / (double) totalMillis))); + } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestMoveInline.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestMoveInline.java index a679d0d41..b3972d1be 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/others/TestMoveInline.java +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestMoveInline.java @@ -26,7 +26,6 @@ public class TestMoveInline extends SmaliTest { @Test public void test() { - getArgs().setRawCFGOutput(true); assertThat(getClassNodeFromSmali()) .code() // check operations order